|
- # this is a copy paste implementation of github.com/krux02/ast_pattern_matching
- # Please provide bugfixes upstream first before adding them here.
- import macros, strutils, tables
- export macros
template debug(args: varargs[untyped]): untyped =
  ## Trace helper used while generating matching code.
  ##
  ## Echoes its arguments when this file is compiled as the main module
  ## and expands to `discard` otherwise, so the arguments are never
  ## evaluated when the module is merely imported.
  when isMainModule:
    echo args
  else:
    discard
# Ranges of literal node kinds, used to decide which value accessor
# (`intVal`, `strVal`, `floatVal`) a `name = value` pattern may use.
const nnkIntLiterals* = nnkCharLit..nnkUInt64Lit
  ## Character and integer literal kinds.
const nnkStringLiterals* = nnkStrLit..nnkTripleStrLit
  ## Plain, raw and triple-quoted string literal kinds.
const nnkFloatLiterals* = nnkFloatLit..nnkFloat64Lit
  ## Floating point literal kinds up to 64 bit.
proc newLit[T: enum](arg: T): NimNode =
  ## Renders an enum value as an identifier node carrying the value's name.
  result = newIdentNode($arg)

proc newLit[T](arg: set[T]): NimNode =
  ## Renders a set as a curly constructor node with one child per member.
  ##
  ## Does not work for empty sets.
  result = newNimNode(nnkCurly)
  for member in items(arg):
    result.add newLit(member)
type SomeFloat = float | float32 | float64
  ## Type class covering the floating point types accepted as pattern values.

proc len[T](arg: set[T]): int =
  ## Number of elements contained in the set `arg`.
  arg.card
type
  # Reason why a node did not match a pattern. `NoError` is the first
  # member, so a default-initialised `MatchingError` reports success.
  MatchingErrorKind* = enum
    NoError              ## the node matched
    WrongKindLength      ## unexpected node kind and/or number of children
    WrongKindValue       ## unexpected node kind and/or literal value
    WrongIdent           ## the node is not the expected identifier
    WrongCustomCondition ## a user supplied `|=` condition was false

  # Outcome of a single matching step on `node`.
  MatchingError = object
    node*: NimNode                  ## the node that was inspected
    expectedKind*: set[NimNodeKind] ## accepted kinds, empty set means any
    case kind*: MatchingErrorKind
    of NoError: discard
    of WrongKindLength:
      expectedLength*: int          ## required child count, negative means any
    of WrongKindValue:
      expectedValue*: NimNode       ## literal node holding the required value
    of WrongIdent, WrongCustomCondition:
      strVal*: string               ## expected identifier or condition source
proc `$`*(arg: MatchingError): string =
  ## Renders `arg` as a human readable message.
  ##
  ## Each error kind produces its own wording; `NoError` yields
  ## `"no error"`. The message describes both what the pattern expected
  ## and what `arg.node` actually is.
  let n = arg.node
  case arg.kind
  of NoError:
    "no error"
  of WrongKindLength:
    let k = arg.expectedKind
    let l = arg.expectedLength
    var msg = "expected "
    if k.card == 0:
      msg.add "any node"
    elif k.card == 1:
      for el in k: # only one element but there is no index op for sets
        msg.add $el
    else:
      # The separating space was missing, gluing "in" to the set.
      msg.add "a node in " & $k
    # A negative expected length means the child count was not checked.
    if l >= 0:
      msg.add " with " & $l & " child(ren)"
    msg.add ", but got " & $n.kind
    if l >= 0:
      msg.add " with " & $n.len & " child(ren)"
    msg
  of WrongKindValue:
    let k = $arg.expectedKind
    let v = arg.expectedValue.repr
    var msg = "expected " & k & " with value " & v & " but got " & n.lispRepr
    if n.kind in {nnkOpenSymChoice, nnkClosedSymChoice}:
      msg = msg & " (a sym-choice does not have a strVal member, maybe you should match with `ident`)"
    msg
  of WrongIdent:
    let prefix = "expected ident `" & arg.strVal & "` but got "
    case n.kind
    of nnkIdent, nnkSym:
      prefix & "`" & n.strVal & "`"
    of nnkOpenSymChoice, nnkClosedSymChoice:
      # A sym-choice has no `strVal` of its own (see the WrongKindValue
      # hint above), so render the node instead of reading that field.
      prefix & "`" & n.repr & "`"
    else:
      prefix & $n.kind & " with " & $n.len & " child(ren)"
  of WrongCustomCondition:
    "custom condition check failed: " & arg.strVal
proc failWithMatchingError*(arg: MatchingError): void {.compileTime, noReturn.} =
  ## Aborts compilation with the rendered form of `arg`, reported at the
  ## position of the node that failed to match.
  let message = $arg
  error(message, arg.node)
proc expectValue(arg: NimNode; value: SomeInteger): void {.compileTime.} =
  ## Fails compilation unless `arg` is a literal whose `intVal` equals `value`.
  arg.expectKind nnkLiterals
  if arg.intVal != int(value):
    error("expected value " & $value & " but got " & arg.repr, arg)

proc expectValue(arg: NimNode; value: SomeFloat): void {.compileTime.} =
  ## Fails compilation unless `arg` is a literal whose `floatVal` equals `value`.
  arg.expectKind nnkLiterals
  if arg.floatVal != float(value):
    error("expected value " & $value & " but got " & arg.repr, arg)

proc expectValue(arg: NimNode; value: string): void {.compileTime.} =
  ## Fails compilation unless `arg` is a literal whose `strVal` equals `value`.
  arg.expectKind nnkLiterals
  if arg.strVal != value:
    error("expected value " & value & " but got " & arg.repr, arg)

proc expectValue(arg: NimNode; value: pointer): void {.compileTime.} =
  ## Fails compilation unless `arg` is a `nil` literal; `value` must be `nil`.
  ##
  ## This overload used to declare a generic parameter that appeared in
  ## no argument, so it could not be inferred and a plain
  ## `expectValue(node, nil)` call did not compile.
  arg.expectKind nnkLiterals
  if value != nil:
    # Report at the offending node, like the other overloads do.
    error("Expect Value for pointers works only on `nil` when the argument is a pointer.", arg)
  arg.expectKind nnkNilLit
proc expectIdent(arg: NimNode; strVal: string): void {.compileTime.} =
  ## Fails compilation unless `arg` is the identifier `strVal`, compared
  ## with `eqIdent`.
  if not arg.eqIdent(strVal):
    # Pass the node along so the diagnostic points at the offending
    # source position instead of carrying no line information.
    error("Expect ident `" & strVal & "` but got " & arg.repr, arg)
proc matchLengthKind*(arg: NimNode; kind: set[NimNodeKind]; length: int): MatchingError {.compileTime.} =
  ## Checks the kind and the number of children of `arg`.
  ##
  ## An empty `kind` set accepts every node kind and a negative `length`
  ## accepts every child count. The result has kind `NoError` when both
  ## checks pass and `WrongKindLength` otherwise.
  let kindOk = kind.card == 0 or arg.kind in kind
  let lengthOk = length < 0 or length == arg.len
  if not (kindOk and lengthOk):
    result = MatchingError(
      node: arg,
      expectedKind: kind,
      kind: WrongKindLength,
      expectedLength: length)

proc matchLengthKind*(arg: NimNode; kind: NimNodeKind; length: int): MatchingError {.compileTime.} =
  ## Convenience overload for a single accepted node kind.
  result = matchLengthKind(arg, {kind}, length)
proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeInteger): MatchingError {.compileTime.} =
  ## Checks that `arg` has one of the kinds in `kind` (an empty set
  ## accepts any kind) and that its integer value equals `value`.
  ##
  ## Returns `NoError` on success and `WrongKindValue` otherwise.
  let kindOk = kind.card == 0 or arg.kind in kind
  # `intVal` used to be read unconditionally, so a node of the wrong
  # kind stopped compilation inside the accessor instead of producing
  # a mismatch that lets the next branch be tried. `and` short-circuits,
  # so the accessor now only runs on nodes that passed the kind checks.
  # NOTE(review): nnkSym is let through because `intVal` is documented
  # to also accept enum field symbols - confirm against the macros docs.
  let valueOk = kindOk and
    arg.kind in {nnkCharLit..nnkUInt64Lit, nnkSym} and
    arg.intVal == int(value)
  if not valueOk:
    result.node = arg
    result.kind = WrongKindValue
    result.expectedKind = kind
    result.expectedValue = newLit(value)

proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeInteger): MatchingError {.compileTime.} =
  ## Convenience overload for a single accepted node kind.
  matchValue(arg, {kind}, value)
proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeFloat): MatchingError {.compileTime.} =
  ## Checks that `arg` has one of the kinds in `kind` (an empty set
  ## accepts any kind) and that its float value equals `value`.
  ##
  ## Returns `NoError` on success and `WrongKindValue` otherwise.
  let kindOk = kind.card == 0 or arg.kind in kind
  # `floatVal` used to be read unconditionally, so a node that is not a
  # float literal stopped compilation inside the accessor instead of
  # producing a mismatch. `and` short-circuits, so the accessor now only
  # runs on float literal nodes that passed the kind check.
  let valueOk = kindOk and
    arg.kind in {nnkFloatLit..nnkFloat128Lit} and
    arg.floatVal == float(value)
  if not valueOk:
    result.node = arg
    result.kind = WrongKindValue
    result.expectedKind = kind
    result.expectedValue = newLit(value)

proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeFloat): MatchingError {.compileTime.} =
  ## Convenience overload for a single accepted node kind.
  matchValue(arg, {kind}, value)
const nnkStrValKinds = {nnkStrLit, nnkRStrLit, nnkTripleStrLit, nnkIdent, nnkSym}
  ## Node kinds whose `strVal` is read when matching a string value.

proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: string): MatchingError {.compileTime.} =
  ## Checks that `arg` has one of the kinds in `kind` (an empty set
  ## accepts any kind in `nnkStrValKinds`) and that its `strVal` equals
  ## `value`.
  ##
  ## Returns `NoError` on success and `WrongKindValue` otherwise.
  # if kind * nnkStringLiterals TODO do something that ensures that here is only checked for string literals
  #
  # An empty kind set used to skip the value comparison entirely, so
  # wildcard patterns such as `_(strVal = "x")` and bare string literal
  # patterns matched every node. The wildcard now means "any kind that
  # carries a strVal" and the value is always compared.
  let accepted =
    if kind.card == 0: nnkStrValKinds
    else: kind * nnkStrValKinds
  # `or` short-circuits: `strVal` is only read from accepted kinds.
  if arg.kind notin accepted or arg.strVal != value:
    result.node = arg
    result.kind = WrongKindValue
    result.expectedKind = kind
    result.expectedValue = newLit(value)

proc matchValue(arg: NimNode; kind: NimNodeKind; value: string): MatchingError {.compileTime.} =
  ## Convenience overload for a single accepted node kind.
  matchValue(arg, {kind}, value)
proc matchValue(arg: NimNode; value: pointer): MatchingError {.compileTime.} =
  ## Checks that `arg` is a `nil` literal; `value` itself must be `nil`.
  ##
  ## This overload used to declare a generic parameter that appeared in
  ## no argument, so it could not be inferred and a plain
  ## `matchValue(node, nil)` call did not compile.
  if value != nil:
    # Report at the offending node so the diagnostic has a location.
    error("Expect Value for pointers works only on `nil` when the argument is a pointer.", arg)
  arg.matchLengthKind(nnkNilLit, -1)
proc matchIdent*(arg:NimNode; value: string): MatchingError =
  ## Checks with `eqIdent` that `arg` is the identifier `value`.
  ##
  ## Returns `NoError` on success, otherwise a `WrongIdent` error that
  ## records the expected name in `strVal`.
  if arg.eqIdent(value):
    return
  result = MatchingError(node: arg, kind: WrongIdent, strVal: value)
proc checkCustomExpr*(arg: NimNode; cond: bool, exprstr: string): MatchingError =
  ## Turns the outcome `cond` of a user supplied condition into a
  ## matching result for `arg`.
  ##
  ## Returns `NoError` when `cond` holds, otherwise a
  ## `WrongCustomCondition` error carrying `exprstr`, the source text of
  ## the condition.
  if cond:
    return
  result = MatchingError(node: arg, kind: WrongCustomCondition, strVal: exprstr)
static:
  # One identifier node per literal kind, in iteration order of
  # `nnkLiterals`.
  var literals: array[19, NimNode]
  var slot = 0
  for litKind in nnkLiterals:
    literals[slot] = ident($litKind)
    inc slot

  # Maps every node kind name, stripped of its three character prefix,
  # back to the kind.
  var nameToKind = newTable[string, NimNodeKind]()
  for nodeKind in NimNodeKind:
    let fullName = $nodeKind
    nameToKind[fullName[3..^1]] = nodeKind

  # Set constructor node listing the kinds that can stand for an identifier.
  let identifierKinds = newLit({nnkSym, nnkIdent, nnkOpenSymChoice, nnkClosedSymChoice})
proc generateMatchingCode(astSym: NimNode, pattern: NimNode, depth: int, blockLabel, errorSym, localsArraySym: NimNode; dest: NimNode): int =
  ## Appends to `dest` the statements that test the node held in
  ## `astSym` against `pattern`.
  ##
  ## Every generated check stores its outcome in `errorSym` and leaves
  ## the block `blockLabel` on failure. Child nodes are kept in slots of
  ## the array `localsArraySym`. Returns the number of slots used.
  var usedSlots = 0

  proc nodeVisiting(astSym: NimNode, pattern: NimNode, depth: int): void =
    let ind = " ".repeat(depth) # indentation of the debug trace
    let patKind = pattern.kind

    proc genMatchLogic(matchProc, argSym1, argSym2: NimNode): void =
      # Call `matchProc` with two arguments on the current node.
      dest.add quote do:
        `errorSym` = `astSym`.`matchProc`(`argSym1`, `argSym2`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    proc genIdentMatchLogic(identValueLit: NimNode): void =
      # Require the current node to be the given identifier.
      dest.add quote do:
        `errorSym` = `astSym`.matchIdent(`identValueLit`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    proc genCustomMatchLogic(conditionExpr: NimNode): void =
      # Evaluate a user condition; its source text becomes the message.
      let exprStr = newLit(conditionExpr.repr)
      dest.add quote do:
        `errorSym` = `astSym`.checkCustomExpr(`conditionExpr`, `exprStr`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    # generate recursively a matching expression
    if patKind == nnkCall:
      pattern.expectMinLen(1)
      debug ind, pattern[0].repr, "("

      # `_` is the wildcard kind, expressed as an empty set constructor.
      let kindExpr =
        if pattern[0].eqIdent("_"): nnkCurly.newTree
        else: pattern[0]

      if pattern.len == 2 and pattern[1].kind == nnkExprEqExpr:
        # `kind(field = value)`: compare the node's value.
        let fieldName = pattern[1][0]
        let fieldValue = pattern[1][1]
        if fieldValue.kind in nnkStringLiterals:
          fieldName.expectIdent("strVal")
        elif fieldValue.kind in nnkIntLiterals:
          fieldName.expectIdent("intVal")
        elif fieldValue.kind in nnkFloatLiterals:
          fieldName.expectIdent("floatVal")
        genMatchLogic(bindSym"matchValue", kindExpr, fieldValue)
      else:
        # `kind(child, ...)`: check kind and arity, then every child.
        let childCount = pattern.len - 1
        genMatchLogic(bindSym"matchLengthKind", kindExpr, newLit(childCount))
        for childPos in 0 ..< childCount:
          let childSym = nnkBracketExpr.newTree(localsArraySym, newLit(usedSlots))
          inc usedSlots
          let indexLit = newLit(childPos)
          dest.add quote do:
            `childSym` = `astSym`[`indexLit`]
          nodeVisiting(childSym, pattern[childPos + 1], depth + 1)
      debug ind, ")"
    elif patKind == nnkCallStrLit and pattern[0].eqIdent("ident"):
      genIdentMatchLogic(pattern[1])
    elif patKind == nnkPar and pattern.len == 1:
      nodeVisiting(astSym, pattern[0], depth)
    elif patKind == nnkPrefix:
      error("prefix patterns not implemented", pattern)
    elif patKind == nnkAccQuoted:
      # A backticked name binds the current node to that name.
      debug ind, pattern.repr
      let matchedExpr = pattern[0]
      matchedExpr.expectKind nnkIdent
      dest.add quote do:
        let `matchedExpr` = `astSym`
    elif patKind == nnkInfix and pattern[0].eqIdent("@"):
      # Bind the node to the backticked name, then match the sub pattern.
      pattern[1].expectKind nnkAccQuoted
      let matchedExpr = pattern[1][0]
      matchedExpr.expectKind nnkIdent
      dest.add quote do:
        let `matchedExpr` = `astSym`
      debug ind, pattern[1].repr, " = "
      nodeVisiting(matchedExpr, pattern[2], depth + 1)
    elif patKind == nnkInfix and pattern[0].eqIdent("|="):
      # Match the left pattern first, then check the custom condition.
      nodeVisiting(astSym, pattern[1], depth + 1)
      genCustomMatchLogic(pattern[2])
    elif patKind in nnkCallKinds:
      error("only boring call syntax allowed, this is " & $pattern.kind & ".", pattern)
    elif patKind in nnkLiterals:
      genMatchLogic(bindSym"matchValue", nnkCurly.newTree, pattern)
    elif not pattern.eqIdent("_"):
      # When it is not one of the other branches, it is simply treated
      # as an expression for the node kind, without checking child
      # nodes.
      debug ind, pattern.repr
      genMatchLogic(bindSym"matchLengthKind", pattern, newLit(-1))

  nodeVisiting(astSym, pattern, depth)
  result = usedSlots
macro matchAst*(astExpr: NimNode; args: varargs[untyped]): untyped =
  ## Matches `astExpr` against the patterns of the given `of` branches
  ## and runs the body of the first branch that matches.
  ##
  ## `args` may start with an identifier; when an `else` branch is also
  ## present, that identifier names the sequence of per-branch errors
  ## inside the `else` body. Without an `else` branch a failed match is
  ## a compile-time error.
  let astSym = genSym(nskLet, "ast")

  let hasErrorName = args[0].kind == nnkIdent
  let hasElse = args[^1].kind == nnkElse
  let firstBranch = if hasErrorName: 1 else: 0
  let branchEnd = if hasElse: args.len - 1 else: args.len

  for idx in firstBranch ..< branchEnd:
    args[idx].expectKind nnkOfBranch

  let outerErrorSym: NimNode =
    if hasErrorName: args[0]
    else: nil
  let elseBranch: NimNode =
    if hasElse: args[^1][0]
    else: nil

  let outerBlockLabel = genSym(nskLabel, "matchingSection")
  let outerStmtList = newStmtList()
  let errorSymbols = nnkBracket.newTree

  ## the vm only allows 255 local variables. This sucks a lot and I
  ## have to work around it. So instead of creating a lot of local
  ## variables, I just create one array of local variables. This is
  ## just annoying.
  let localsArraySym = genSym(nskVar, "locals")
  var localsArrayLen: int = 0

  for idx in firstBranch ..< branchEnd:
    let branch = args[idx]
    branch.expectKind(nnkOfBranch)
    branch.expectLen(2)
    let branchPattern = branch[0]
    let branchBody = branch[1]
    branchBody.expectKind nnkStmtList

    let branchStmts = newStmtList()
    let blockLabel = genSym(nskLabel, "matchingBranch")
    let errorSym = genSym(nskVar, "branchError")
    errorSymbols.add errorSym

    let numLocalsUsed = generateMatchingCode(astSym, branchPattern, 0, blockLabel, errorSym, localsArraySym, branchStmts)
    localsArrayLen = max(localsArrayLen, numLocalsUsed)

    branchStmts.add branchBody
    # maybe there is a better mechanism disable errors for statement after return
    if branchBody[^1].kind != nnkReturnStmt:
      branchStmts.add nnkBreakStmt.newTree(outerBlockLabel)

    outerStmtList.add quote do:
      var `errorSym`: MatchingError
      block `blockLabel`:
        `branchStmts`

  if elseBranch != nil:
    if outerErrorSym != nil:
      # Expose the collected branch errors under the user chosen name.
      outerStmtList.add quote do:
        let `outerErrorSym` = @`errorSymbols`
        `elseBranch`
    else:
      outerStmtList.add elseBranch
  elif errorSymbols.len == 1:
    # there is only one of branch and no else branch
    # the error message can be very precise here.
    let onlyError = errorSymbols[0]
    outerStmtList.add quote do:
      failWithMatchingError(`onlyError`)
  else:
    # Several branches failed: list every pattern that was tried.
    var patterns: string = ""
    for idx in firstBranch ..< branchEnd:
      patterns.add args[idx][0].repr
      patterns.add "\n"
    let patternsLit = newLit(patterns)
    outerStmtList.add quote do:
      error("Ast pattern mismatch: got " & `astSym`.lispRepr & "\nbut expected one of:\n" & `patternsLit`, `astSym`)

  let lengthLit = newLit(localsArrayLen)
  result = quote do:
    block `outerBlockLabel`:
      let `astSym` = `astExpr`
      var `localsArraySym`: array[`lengthLit`, NimNode]
      `outerStmtList`

  debug result.repr
proc recursiveNodeVisiting*(arg: NimNode, callback: proc(arg: NimNode): bool) =
  ## Calls `callback` on `arg` and, as long as it returns true, on the
  ## children of `arg` in order, depth first. A false result stops the
  ## descent below that node only.
  if not callback(arg):
    return
  for childIdx in 0 ..< arg.len:
    recursiveNodeVisiting(arg[childIdx], callback)
macro matchAstRecursive*(ast: NimNode; args: varargs[untyped]): untyped =
  ## Visits `ast` and its descendants and runs, for every visited node,
  ## the body of the first `of` branch whose pattern matches it.
  ##
  ## Does not recurse further on matched nodes. An `else` branch is
  ## rejected.
  if args[^1].kind == nnkElse:
    error("Recursive matching with an else branch is pointless.", args[^1])

  let visitor = genSym(nskProc, "visitor")
  let visitorArg = genSym(nskParam, "arg")
  let visitorStmtList = newStmtList()
  let matchingSection = genSym(nskLabel, "matchingSection")
  let localsArraySym = genSym(nskVar, "locals")
  let branchError = genSym(nskVar, "branchError")
  var localsArrayLen = 0

  for branch in args:
    branch.expectKind(nnkOfBranch)
    branch.expectLen(2)
    let branchPattern = branch[0]
    let branchBody = branch[1]
    branchBody.expectKind(nnkStmtList)

    let branchStmts = newStmtList()
    let matchingBranch = genSym(nskLabel, "matchingBranch")
    let numLocalsUsed = generateMatchingCode(visitorArg, branchPattern, 0, matchingBranch, branchError, localsArraySym, branchStmts)
    localsArrayLen = max(localsArrayLen, numLocalsUsed)

    # A matching branch runs its body and then leaves the whole section.
    branchStmts.add branchBody
    branchStmts.add nnkBreakStmt.newTree(matchingSection)

    visitorStmtList.add quote do:
      `branchError`.kind = NoError
      block `matchingBranch`:
        `branchStmts`

  let resultIdent = ident"result"
  let visitingProc = bindSym"recursiveNodeVisiting"
  let lengthLit = newLit(localsArrayLen)

  # `result = true` is only reached when no branch matched, which is
  # what tells the visiting proc to descend into the children.
  result = quote do:
    proc `visitor`(`visitorArg`: NimNode): bool =
      block `matchingSection`:
        var `localsArraySym`: array[`lengthLit`, NimNode]
        var `branchError`: MatchingError
        `visitorStmtList`
        `resultIdent` = true

    `visitingProc`(`ast`, `visitor`)

  debug result.repr
- ################################################################################
- ################################# Example Code #################################
- ################################################################################
when isMainModule:
  static:
    let mykinds = {nnkIdent, nnkCall}

  # Two branches with a named error symbol and no else branch.
  macro foo(arg: untyped): untyped =
    matchAst(arg, matchError):
    of nnkStmtList(nnkIdent, nnkIdent, nnkIdent):
      echo(88*88+33*33)
    of nnkStmtList(
      _(
        nnkIdentDefs(
          ident"a",
          nnkEmpty, nnkIntLit(intVal = 123)
        )
      ),
      _,
      nnkForStmt(
        nnkIdent(strVal = "i"),
        nnkInfix,
        `mysym` @ nnkStmtList
      )
    ):
      echo "The AST did match!!!"
      echo "The matched sub tree is the following:"
      echo mysym.lispRepr
    #else:
    #  echo "sadly the AST did not match :("
    #  echo arg.treeRepr
    #  failWithMatchingError(matchError[1])

  foo:
    let a = 123
    let b = 342
    for i in a ..< b:
      echo "Hallo", i

  static:
    var ast = quote do:
      type
        A[T: static[int]] = object

    ast = ast[0]
    ast.matchAst(err): # this is a sub ast for this a findAst or something like that is useful
    of nnkTypeDef(_, nnkGenericParams( nnkIdentDefs( nnkIdent(strVal = "T"), `staticTy`, nnkEmpty )), _):
      echo "`", staticTy.repr, "` used to be of nnkStaticTy, now it is ", staticTy.kind, " with ", staticTy[0].repr

    ast = quote do:
      if cond1: expr1 elif cond2: expr2 else: expr3

    # Sets of kinds may be used wherever a single kind is accepted.
    ast.matchAst:
    of {nnkIfExpr, nnkIfStmt}(
      {nnkElifExpr, nnkElifBranch}(`cond1`, `expr1`),
      {nnkElifExpr, nnkElifBranch}(`cond2`, `expr2`),
      {nnkElseExpr, nnkElse}(`expr3`)
    ):
      echo "ok"

    # A bare literal in child position matches by value.
    let ast2 = nnkStmtList.newTree( newLit(1) )
    ast2.matchAst:
    of nnkIntLit( 1 ):
      echo "fail"
    of nnkStmtList( 1 ):
      echo "ok"

    ast = bindSym"[]"
    ast.matchAst(errors):
    of nnkClosedSymChoice(strVal = "[]"):
      echo "fail, this is the wrong syntax, a sym choice does not have a `strVal` member."
    of ident"[]":
      echo "ok"

    # The expected value may be a constant instead of a literal.
    const myConst = 123
    ast = newLit(123)
    ast.matchAst:
    of _(intVal = myConst):
      echo "ok"

  # Recursive matching: every node of the tree is tried against the branches.
  macro testRecCase(ast: untyped): untyped =
    ast.matchAstRecursive:
    of nnkIdentDefs(`a`,`b`,`c`):
      echo "got ident defs a: ", a.repr, " b: ", b.repr, " c: ", c.repr
    of ident"m":
      echo "got the ident m"

  testRecCase:
    type Obj[T] = object {.inheritable.}
      name: string
      case isFat: bool
      of true:
        m: array[100_000, T]
      of false:
        m: array[10, T]

  # Custom conditions: collect only the integer literals greater than 5.
  macro testIfCondition(ast: untyped): untyped =
    let literals = nnkBracket.newTree
    ast.matchAstRecursive:
    of `intLit` @ nnkIntLit |= intLit.intVal > 5:
      literals.add intLit

    let literals2 = quote do:
      [6,7,8,9]

    doAssert literals2 == literals

  testIfCondition([1,6,2,7,3,8,4,9,5,0,"123"])
|