patterns.nim 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. #
  2. #
  3. # The Nim Compiler
  4. # (c) Copyright 2012 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  9. ## This module implements the pattern matching features for term rewriting
  10. ## macro support.
  11. import
  12. ast, types, semdata, sigmatch, idents, aliases, parampatterns, trees
  13. type
  14. TPatternContext = object
  15. owner: PSym
  16. mapping: seq[PNode] # maps formal parameters to nodes
  17. formals: int
  18. c: PContext
  19. subMatch: bool # subnode matches are special
  20. mappingIsFull: bool
  21. PPatternContext = var TPatternContext
  22. proc getLazy(c: PPatternContext, sym: PSym): PNode =
  23. if c.mappingIsFull:
  24. result = c.mapping[sym.position]
  25. proc putLazy(c: PPatternContext, sym: PSym, n: PNode) =
  26. if not c.mappingIsFull:
  27. newSeq(c.mapping, c.formals)
  28. c.mappingIsFull = true
  29. c.mapping[sym.position] = n
  30. proc matches(c: PPatternContext, p, n: PNode): bool
  31. proc canonKind(n: PNode): TNodeKind =
  32. ## nodekind canonicalization for pattern matching
  33. result = n.kind
  34. case result
  35. of nkCallKinds: result = nkCall
  36. of nkStrLit..nkTripleStrLit: result = nkStrLit
  37. of nkFastAsgn: result = nkAsgn
  38. else: discard
  39. proc sameKinds(a, b: PNode): bool {.inline.} =
  40. result = a.kind == b.kind or a.canonKind == b.canonKind
  41. proc sameTrees*(a, b: PNode): bool =
  42. if sameKinds(a, b):
  43. case a.kind
  44. of nkSym: result = a.sym == b.sym
  45. of nkIdent: result = a.ident.id == b.ident.id
  46. of nkCharLit..nkInt64Lit: result = a.intVal == b.intVal
  47. of nkFloatLit..nkFloat64Lit: result = a.floatVal == b.floatVal
  48. of nkStrLit..nkTripleStrLit: result = a.strVal == b.strVal
  49. of nkEmpty, nkNilLit: result = true
  50. of nkType: result = sameTypeOrNil(a.typ, b.typ)
  51. else:
  52. if a.len == b.len:
  53. for i in 0..<a.len:
  54. if not sameTrees(a[i], b[i]): return
  55. result = true
  56. proc inSymChoice(sc, x: PNode): bool =
  57. if sc.kind == nkClosedSymChoice:
  58. for i in 0..<sc.len:
  59. if sc[i].sym == x.sym: return true
  60. elif sc.kind == nkOpenSymChoice:
  61. # same name suffices for open sym choices!
  62. result = sc[0].sym.name.id == x.sym.name.id
  63. proc checkTypes(c: PPatternContext, p: PSym, n: PNode): bool =
  64. # check param constraints first here as this is quite optimized:
  65. if p.constraint != nil:
  66. result = matchNodeKinds(p.constraint, n)
  67. if not result: return
  68. if isNil(n.typ):
  69. result = p.typ.kind in {tyVoid, tyTyped}
  70. else:
  71. result = sigmatch.argtypeMatches(c.c, p.typ, n.typ, fromHlo = true)
  72. proc isPatternParam(c: PPatternContext, p: PNode): bool {.inline.} =
  73. result = p.kind == nkSym and p.sym.kind == skParam and p.sym.owner == c.owner
  74. proc matchChoice(c: PPatternContext, p, n: PNode): bool =
  75. for i in 1..<p.len:
  76. if matches(c, p[i], n): return true
  77. proc bindOrCheck(c: PPatternContext, param: PSym, n: PNode): bool =
  78. var pp = getLazy(c, param)
  79. if pp != nil:
  80. # check if we got the same pattern (already unified):
  81. result = sameTrees(pp, n) #matches(c, pp, n)
  82. elif n.kind == nkArgList or checkTypes(c, param, n):
  83. putLazy(c, param, n)
  84. result = true
  85. proc gather(c: PPatternContext, param: PSym, n: PNode) =
  86. var pp = getLazy(c, param)
  87. if pp != nil and pp.kind == nkArgList:
  88. pp.add(n)
  89. else:
  90. pp = newNodeI(nkArgList, n.info, 1)
  91. pp[0] = n
  92. putLazy(c, param, pp)
  93. proc matchNested(c: PPatternContext, p, n: PNode, rpn: bool): bool =
  94. # match ``op * param`` or ``op *| param``
  95. proc matchStarAux(c: PPatternContext, op, n, arglist: PNode,
  96. rpn: bool): bool =
  97. result = true
  98. if n.kind in nkCallKinds and matches(c, op[1], n[0]):
  99. for i in 1..<n.len:
  100. if not matchStarAux(c, op, n[i], arglist, rpn): return false
  101. if rpn: arglist.add(n[0])
  102. elif n.kind == nkHiddenStdConv and n[1].kind == nkBracket:
  103. let n = n[1]
  104. for i in 0..<n.len:
  105. if not matchStarAux(c, op, n[i], arglist, rpn): return false
  106. elif checkTypes(c, p[2].sym, n):
  107. arglist.add(n)
  108. else:
  109. result = false
  110. if n.kind notin nkCallKinds: return false
  111. if matches(c, p[1], n[0]):
  112. var arglist = newNodeI(nkArgList, n.info)
  113. if matchStarAux(c, p, n, arglist, rpn):
  114. result = bindOrCheck(c, p[2].sym, arglist)
proc matches(c: PPatternContext, p, n: PNode): bool =
  ## Tries to match the pattern `p` against the expression `n`.
  ##
  ## Pattern parameters met in `p` are bound in `c` (via `bindOrCheck` or
  ## `gather`), or unified with their earlier binding. Returns true if `n`
  ## matches. Bindings made before a failure are not undone by this proc.
  ## The branches below are tried in order; the first applicable one decides.
  let n = skipHidden(n)
  if nfNoRewrite in n.flags:
    # the node is explicitly excluded from term rewriting
    result = false
  elif isPatternParam(c, p):
    # a parameter of the rule: bind it or compare with its earlier binding
    result = bindOrCheck(c, p.sym, n)
  elif n.kind == nkSym and p.kind == nkIdent:
    # an identifier in the pattern matches any symbol with that name
    result = p.ident.id == n.sym.name.id
  elif n.kind == nkSym and inSymChoice(p, n):
    result = true
  elif n.kind == nkSym and n.sym.kind == skConst:
    # try both:
    if p.kind == nkSym: result = p.sym == n.sym
    elif matches(c, p, n.sym.ast): result = true
  elif p.kind == nkPattern:
    # pattern operators: | * ** ~
    let opr = p[0].ident.s
    case opr
    of "|": result = matchChoice(c, p, n)
    of "*": result = matchNested(c, p, n, rpn=false)
    of "**": result = matchNested(c, p, n, rpn=true)
    # NOTE(review): bindings made while trying p[1] are kept even though the
    # outcome is negated
    of "~": result = not matches(c, p[1], n)
    else: doAssert(false, "invalid pattern")
    # template {add(a, `&` * b)}(a: string{noalias}, b: varargs[string]) =
    #   a.add(b)
  elif p.kind == nkCurlyExpr:
    # p[0] must match `n`; the match is then recorded in a parameter
    if p[1].kind == nkPrefix:
      # prefix form: the parameter is p[1][1], and every match is appended
      # to the nkArgList bound to it
      if matches(c, p[0], n):
        gather(c, p[1][1].sym, n)
        result = true
    else:
      assert isPatternParam(c, p[1])
      if matches(c, p[0], n):
        result = bindOrCheck(c, p[1].sym, n)
  elif sameKinds(p, n):
    case p.kind
    of nkSym: result = p.sym == n.sym
    of nkIdent: result = p.ident.id == n.ident.id
    of nkCharLit..nkInt64Lit: result = p.intVal == n.intVal
    of nkFloatLit..nkFloat64Lit: result = p.floatVal == n.floatVal
    of nkStrLit..nkTripleStrLit: result = p.strVal == n.strVal
    of nkEmpty, nkNilLit, nkType:
      result = true
    else:
      # special rule for p(X) ~ f(...); this also works for stuff like
      # partial case statements, etc! - Not really ... :-/
      let v = lastSon(p)
      # the last child of the pattern is a varargs parameter: it absorbs all
      # trailing children of `n`
      if isPatternParam(c, v) and v.sym.typ.kind == tyVarargs:
        var arglist: PNode
        if p.len <= n.len:
          # leading children must match one-to-one
          for i in 0..<p.len - 1:
            if not matches(c, p[i], n[i]): return
          if p.len == n.len and lastSon(n).kind == nkHiddenStdConv and
              lastSon(n)[1].kind == nkBracket:
            # unpack varargs:
            let n = lastSon(n)[1]
            arglist = newNodeI(nkArgList, n.info, n.len)
            for i in 0..<n.len: arglist[i] = n[i]
          else:
            # the remaining n.len - p.len + 1 children become the arguments
            arglist = newNodeI(nkArgList, n.info, n.len - p.len + 1)
            # f(1, 2, 3)
            # p(X)
            for i in 0..n.len - p.len:
              arglist[i] = n[i + p.len - 1]
          return bindOrCheck(c, v.sym, arglist)
        elif p.len-1 == n.len:
          # no trailing children at all: bind an empty argument list
          for i in 0..<p.len - 1:
            if not matches(c, p[i], n[i]): return
          arglist = newNodeI(nkArgList, n.info)
          return bindOrCheck(c, v.sym, arglist)
      # ordinary inner node: same arity, children match pairwise
      if p.len == n.len:
        for i in 0..<p.len:
          if not matches(c, p[i], n[i]): return
        result = true
  189. proc matchStmtList(c: PPatternContext, p, n: PNode): PNode =
  190. proc matchRange(c: PPatternContext, p, n: PNode, i: int): bool =
  191. for j in 0..<p.len:
  192. if not matches(c, p[j], n[i+j]):
  193. # we need to undo any bindings:
  194. c.mapping = @[]
  195. c.mappingIsFull = false
  196. return false
  197. result = true
  198. if p.kind == nkStmtList and n.kind == p.kind and p.len < n.len:
  199. let n = flattenStmts(n)
  200. # no need to flatten 'p' here as that has already been done
  201. for i in 0..n.len - p.len:
  202. if matchRange(c, p, n, i):
  203. c.subMatch = true
  204. result = newNodeI(nkStmtList, n.info, 3)
  205. result[0] = extractRange(nkStmtList, n, 0, i-1)
  206. result[1] = extractRange(nkStmtList, n, i, i+p.len-1)
  207. result[2] = extractRange(nkStmtList, n, i+p.len, n.len-1)
  208. break
  209. elif matches(c, p, n):
  210. result = n
  211. proc aliasAnalysisRequested(params: PNode): bool =
  212. if params.len >= 2:
  213. for i in 1..<params.len:
  214. let param = params[i].sym
  215. if whichAlias(param) != aqNone: return true
  216. proc addToArgList(result, n: PNode) =
  217. if n.typ != nil and n.typ.kind != tyTyped:
  218. if n.kind != nkArgList: result.add(n)
  219. else:
  220. for i in 0..<n.len: result.add(n[i])
  221. proc applyRule*(c: PContext, s: PSym, n: PNode): PNode =
  222. ## returns a tree to semcheck if the rule triggered; nil otherwise
  223. var ctx: TPatternContext
  224. ctx.owner = s
  225. ctx.c = c
  226. ctx.formals = s.typ.len-1
  227. var m = matchStmtList(ctx, s.ast[patternPos], n)
  228. if isNil(m): return nil
  229. # each parameter should have been bound; we simply setup a call and
  230. # let semantic checking deal with the rest :-)
  231. result = newNodeI(nkCall, n.info)
  232. result.add(newSymNode(s, n.info))
  233. let params = s.typ.n
  234. let requiresAA = aliasAnalysisRequested(params)
  235. var args: PNode
  236. if requiresAA:
  237. args = newNodeI(nkArgList, n.info)
  238. for i in 1..<params.len:
  239. let param = params[i].sym
  240. let x = getLazy(ctx, param)
  241. # couldn't bind parameter:
  242. if isNil(x): return nil
  243. result.add(x)
  244. if requiresAA: addToArgList(args, x)
  245. # perform alias analysis here:
  246. if requiresAA:
  247. for i in 1..<params.len:
  248. var rs = result[i]
  249. let param = params[i].sym
  250. case whichAlias(param)
  251. of aqNone: discard
  252. of aqShouldAlias:
  253. # it suffices that it aliases for sure with *some* other param:
  254. var ok = false
  255. for arg in items(args):
  256. if arg != rs and aliases.isPartOf(rs, arg) == arYes:
  257. ok = true
  258. break
  259. # constraint not fulfilled:
  260. if not ok: return nil
  261. of aqNoAlias:
  262. # it MUST not alias with any other param:
  263. var ok = true
  264. for arg in items(args):
  265. if arg != rs and aliases.isPartOf(rs, arg) != arNo:
  266. ok = false
  267. break
  268. # constraint not fulfilled:
  269. if not ok: return nil
  270. markUsed(c, n.info, s)
  271. if ctx.subMatch:
  272. assert m.len == 3
  273. m[1] = result
  274. result = m