semparallel.nim 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. #
  2. #
  3. # The Nim Compiler
  4. # (c) Copyright 2015 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  9. ## Semantic checking for 'parallel'.
  10. # - codegen needs to support mSlice (+)
  11. # - lowerings must not perform unnecessary copies (+)
  12. # - slices should become "nocopy" to openArray (+)
  13. # - need to perform bound checks (+)
  14. #
  15. # - parallel needs to insert a barrier (+)
  16. # - passed arguments need to be ensured to be "const"
  17. # - what about 'f(a)'? --> f shouldn't have side effects anyway
  18. # - passed arrays need to be ensured not to alias
  19. # - passed slices need to be ensured to be disjoint (+)
  20. # - output slices need special logic (+)
  21. import
  22. ast, astalgo, idents, lowerings, magicsys, guards, sempass2, msgs,
  23. renderer, types
  24. from trees import getMagic
  25. from strutils import `%`
  discard """

  one major problem:
    spawn f(a[i])
    inc i
    spawn f(a[i])
  is valid, but
    spawn f(a[i])
    spawn f(a[i])
    inc i
  is not! However,
    spawn f(a[i])
    if guard: inc i
    spawn f(a[i])
  is not valid either! --> We need a flow dependent analysis here.

  However:
    while foo:
      spawn f(a[i])
      inc i
      spawn f(a[i])

  Is not valid either! --> We should really restrict 'inc' to loop endings?

  The heuristic that we implement here (that has no false positives) is: Usage
  of 'i' in a slice *after* we determined the stride is invalid!
  """
  type
    TDirection = enum
      ascending, descending

    # Bookkeeping for one local variable tracked by the analysis.
    MonotonicVar = object
      v: PSym               # the tracked variable
      alias: PSym           # to support the ordinary 'countup' iterator
                            # we need to detect aliases
      lower: PNode          # initial value taken from the var/let section
      upper: PNode
      stride: PNode         # smallest positive literal increment seen so far
      dir: TDirection
      blacklisted: bool     # blacklisted variables that are not monotonic

    # State accumulated while walking the body of a 'parallel' section.
    AnalysisCtx = object
      locals: seq[MonotonicVar]
      slices: seq[tuple[x, a, b: PNode, spawnId: int, inLoop: bool]]
      guards: TModel        # nested guards
      args: seq[PSym]       # args must be deeply immutable
      spawns: int           # we can check that at least 1 spawn is used in
                            # the 'parallel' section
      currentSpawnId: int
      inLoop: int           # incremented on entering a 'while', decremented
                            # on leaving it
  proc initAnalysisCtx(): AnalysisCtx =
    ## Returns a context whose sequence fields are empty; the integer
    ## counters keep their zero default.
    result = AnalysisCtx(locals: @[], slices: @[], args: @[], guards: @[])
  proc lookupSlot(c: AnalysisCtx; s: PSym): int =
    ## Index of the entry in `c.locals` whose `v` or `alias` is `s`,
    ## or -1 when no entry matches.
    result = -1
    for idx, entry in c.locals:
      if entry.v == s or entry.alias == s:
        return idx
  proc getSlot(c: var AnalysisCtx; v: PSym): ptr MonotonicVar =
    ## Returns a pointer to the slot tracking `v`, appending a fresh slot
    ## (with only `v` set) when none exists yet.
    ## The pointer refers to an element of `c.locals`, so it should be used
    ## before that sequence is resized again.
    let existing = lookupSlot(c, v)
    if existing >= 0:
      result = addr(c.locals[existing])
    else:
      let fresh = c.locals.len
      c.locals.setLen(fresh + 1)
      c.locals[fresh].v = v
      result = addr(c.locals[fresh])
  proc gatherArgs(c: var AnalysisCtx; n: PNode) =
    ## Walks `n` recursively and records the root symbol (as reported by
    ## `getRoot`) of every sub-node in `c.args`, without duplicates.
    for i in 0 ..< n.safeLen:
      let child = n[i]
      let root = getRoot(child)
      if root != nil and root notin c.args:
        c.args.add root
      gatherArgs(c, child)
  proc isSingleAssignable(n: PNode): bool =
    ## True when `n` is a symbol of kind temp / for-variable / let that is
    ## neither global nor has its address taken.
    if n.kind != nkSym: return false
    let sym = n.sym
    result = sym.kind in {skTemp, skForVar, skLet} and
             sym.flags * {sfAddrTaken, sfGlobal} == {}
  proc isLocal(n: PNode): bool =
    ## True when `n` is a symbol of kind result / temp / for-variable /
    ## var / let that is neither global nor has its address taken.
    if n.kind != nkSym: return false
    let sym = n.sym
    result = sym.kind in {skResult, skTemp, skForVar, skVar, skLet} and
             sym.flags * {sfAddrTaken, sfGlobal} == {}
  proc checkLocal(c: AnalysisCtx; n: PNode) =
    ## Reports an error for every local inside `n` that already has a
    ## stride recorded, i.e. that is used after its increment was seen.
    if not isLocal(n):
      for i in 0 ..< n.safeLen: checkLocal(c, n[i])
      return
    let slot = c.lookupSlot(n.sym)
    if slot >= 0 and c.locals[slot].stride != nil:
      localError(n.info, "invalid usage of counter after increment")
  # Shorthand used when building diagnostics: renders a node as source text.
  template `?`(x): untyped = renderTree(x)
  proc checkLe(c: AnalysisCtx; a, b: PNode) =
    ## Asks the guard model whether `a <= b` holds and reports an error
    ## unless it can be proven.
    let verdict = proveLe(c.guards, a, b)
    if verdict == impUnknown:
      localError(a.info, "cannot prove: " & ?a & " <= " & ?b & " (bounds check)")
    elif verdict == impNo:
      localError(a.info, "can prove: " & ?a & " > " & ?b & " (bounds check)")
  proc checkBounds(c: AnalysisCtx; arr, idx: PNode) =
    ## Checks `lowBound(arr) <= idx` and `idx <= highBound(arr)`.
    c.checkLe(arr.lowBound, idx)
    c.checkLe(idx, arr.highBound)
  proc addLowerBoundAsFacts(c: var AnalysisCtx) =
    ## Adds the fact `lower <= v` to `c.guards` for every tracked local that
    ## is not blacklisted and whose lower bound is known.
    ##
    ## Slots are also created by `getSlot` when a local is merely aliased or
    ## incremented; those have `lower == nil` because no var/let section
    ## supplied an initial value. No fact can be stated for them, so they
    ## are skipped rather than passing a nil node into the guard model.
    for v in c.locals:
      if not v.blacklisted and v.lower != nil:
        c.guards.addFactLe(v.lower, newSymNode(v.v))
  proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
    ## Records the access `x[le..ri]` (found in node `n`) for the later
    ## disjointness check, after bounds-checking both ends.
    checkLocal(c, n)
    let lo = le.canon
    let hi = ri.canon
    # perform static bounds checking here; and not later!
    # The lower-bound facts are only needed for these two checks, so the
    # guard model is truncated back to its previous length afterwards.
    let mark = c.guards.len
    addLowerBoundAsFacts(c)
    c.checkBounds(x, lo)
    c.checkBounds(x, hi)
    c.guards.setLen(mark)
    c.slices.add((x, lo, hi, c.currentSpawnId, c.inLoop > 0))
  proc overlap(m: TModel; x,y,c,d: PNode) =
    ## Reports an error unless the ranges `x..y` and `c..d` can be proven
    ## disjoint under the facts in `m`.
    ##
    ## X..Y and C..D overlap iff (X <= D and C <= Y), so disproving either
    ## inequality is enough to accept the pair.
    case proveLe(m, c, y)
    of impUnknown:
      case proveLe(m, x, d)
      of impNo: discard
      of impUnknown, impYes:
        localError(x.info,
          "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
              [?c, ?y, ?x, ?y, ?c, ?d])
    of impYes:
      case proveLe(m, x, d)
      of impUnknown:
        localError(x.info,
          "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
            [?x, ?d, ?x, ?y, ?c, ?d])
      of impYes:
        # Both inequalities hold, so the ranges definitely overlap. The
        # message has four placeholders: the two ranges, in order.
        localError(x.info, "($#)..($#) not disjoint from ($#)..($#)" %
                  [?x, ?y, ?c, ?d])
      of impNo: discard
    of impNo: discard
  proc stride(c: AnalysisCtx; n: PNode): BiggestInt =
    ## Sum of the recorded strides of all locals occurring in `n`;
    ## locals without a recorded stride contribute 0.
    if not isLocal(n):
      for i in 0 ..< n.safeLen: result += stride(c, n[i])
    else:
      let slot = c.lookupSlot(n.sym)
      if slot >= 0 and c.locals[slot].stride != nil:
        result = c.locals[slot].stride.intVal
  proc subStride(c: AnalysisCtx; n: PNode): PNode =
    ## Returns `n` with every local that has a recorded stride replaced by
    ## `local +@ stride`; nodes without such locals are returned unchanged.
    result = n
    if isLocal(n):
      let slot = c.lookupSlot(n.sym)
      if slot >= 0 and c.locals[slot].stride != nil:
        result = n +@ c.locals[slot].stride.intVal
    elif n.safeLen > 0:
      # substitute with stride in every child of a shallow copy:
      result = shallowCopy(n)
      for i in 0 ..< n.len: result.sons[i] = subStride(c, n[i])
  proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
    ## Final pass over all recorded slices; reports an error for every pair
    ## that cannot be shown to be disjoint.
    # this is the only thing that we need to perform after we have traversed
    # the whole tree so that the strides are available.
    # First we need to add all the computed lower bounds:
    addLowerBoundAsFacts(c)
    # Every slice used in a loop needs to be disjoint with itself, i.e. with
    # the same slice after the counters advanced by their stride:
    for s in c.slices:
      if s.inLoop:
        overlap(c.guards, s.a, s.b, c.subStride(s.a), c.subStride(s.b))
    # Another tricky example is:
    #   while true:
    #     spawn f(a[i])
    #     spawn f(a[i+1])
    #     inc i  # inc i, 2  would be correct here
    #
    # Or even worse:
    #   while true:
    #     spawn f(a[i+1 .. i+3])
    #     spawn f(a[i+4 .. i+5])
    #     inc i, 4
    # Prove that i*k*stride + 3 != i*k'*stride + 5
    # For the correct example this amounts to
    #   i*k*2 != i*k'*2 + 1
    # which is true.
    # For now, we don't try to prove things like that at all, even though it'd
    # be feasible for many useful examples. Instead we attach the slice to
    # a spawn and if the attached spawns differ, we bail out:
    for i in 0 .. high(c.slices):
      let x = c.slices[i]
      for j in i+1 .. high(c.slices):
        let y = c.slices[j]
        # only slices of the same array that belong to different spawns
        # need to be compared:
        if x.spawnId == y.spawnId or not guards.sameTree(x.x, y.x):
          continue
        if not (x.inLoop and y.inLoop):
          # XXX strictly speaking, 'or' is not correct here and it needs to
          # be 'and'. However this prevents too many obviously correct programs
          # like f(a[0..x]); for i in x+1 .. a.high: f(a[i])
          overlap(c.guards, x.a, x.b, y.a, y.b)
          continue
        let k = simpleSlice(x.a, x.b)
        let m = simpleSlice(y.a, y.b)
        if k >= 0 and m >= 0:
          # ah I cannot resist the temptation and add another sweet heuristic:
          # if both slices have the form (i+k)..(i+k) and (i+m)..(i+m) we
          # check they are disjoint and k < stride and m < stride:
          overlap(c.guards, x.a, x.b, y.a, y.b)
          let step = min(c.stride(x.a), c.stride(y.a))
          if not (k < step and m < step):
            localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
              [?x.a, ?x.b, ?y.a, ?y.b])
        else:
          localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
            [?x.a, ?x.b, ?y.a, ?y.b])
  proc analyse(c: var AnalysisCtx; n: PNode)

  proc analyseSons(c: var AnalysisCtx; n: PNode) =
    ## Runs `analyse` on every child of `n`, in order.
    for i in 0 ..< n.safeLen:
      analyse(c, n[i])
  proc min(a, b: PNode): PNode =
    ## Returns whichever literal node has the smaller `intVal`; a nil `a`
    ## yields `b`, and `b` also wins ties.
    if not a.isNil and a.intVal < b.intVal: a
    else: b
  proc fromSystem(op: PSym): bool =
    ## True when the module owning `op` carries the `sfSystemModule` flag.
    let owningModule = getModule(op)
    result = sfSystemModule in owningModule.flags
  # Runs `body` with `c.currentSpawnId` set to a freshly allocated id and
  # restores the previous id afterwards. The template is dirty, so the
  # temporary it declares lives in the caller's scope.
  template pushSpawnId(c, body) {.dirty.} =
    inc(c.spawns)
    let savedSpawnId = c.currentSpawnId
    c.currentSpawnId = c.spawns
    body
    c.currentSpawnId = savedSpawnId
  proc analyseCall(c: var AnalysisCtx; n: PNode; op: PSym) =
    ## Analyses the direct call `n` whose callee symbol is `op`.
    if op.magic == mSpawn:
      # everything below a 'spawn' is attributed to a fresh spawn id
      pushSpawnId(c):
        gatherArgs(c, n[1])
        analyseSons(c, n)
      return

    if op.magic == mInc or (op.name.s == "+=" and op.fromSystem):
      # a positive literal increment of a local defines its stride; the
      # smallest one seen is kept
      if n[1].isLocal:
        let step = n[2].skipConv
        if step.kind in {nkCharLit..nkUInt32Lit} and step.intVal > 0:
          let slot = c.getSlot(n[1].sym)
          slot.stride = min(slot.stride, step)
    elif (op.name.s == "[]" or op.name.s == "[]=") and op.fromSystem:
      # slice read or slice write via system's operators
      let slice = n[2].skipStmtList
      c.addSlice(n, n[1], slice[1], slice[2])
    analyseSons(c, n)
  proc analyseCase(c: var AnalysisCtx; n: PNode) =
    ## Analyses a 'case' statement: the selector first, then every branch
    ## under the facts that `addCaseBranchFacts` derives for it.
    analyse(c, n[0])
    let baseline = c.guards.len
    for branchIdx in 1 ..< n.len:
      # each branch starts from the facts known before the 'case'
      setLen(c.guards, baseline)
      addCaseBranchFacts(c.guards, n, branchIdx)
      for son in n[branchIdx]:
        analyse(c, son)
    setLen(c.guards, baseline)
  proc analyseIf(c: var AnalysisCtx; n: PNode) =
    ## Analyses an if/when construct, tracking which conditions hold in
    ## each branch.
    let first = n[0]
    analyse(c, first[0])
    let baseline = c.guards.len
    addFact(c.guards, canon(first[0]))
    analyse(c, first[1])

    for i in 1 ..< n.len:
      let branch = n[i]
      setLen(c.guards, baseline)
      # all earlier conditions are known to be false here:
      for j in 0 ..< i:
        addFactNeg(c.guards, canon(n[j][0]))
      # a branch with more than one child carries its own condition:
      if branch.len > 1:
        addFact(c.guards, canon(branch[0]))
      for son in branch:
        analyse(c, son)
    setLen(c.guards, baseline)
  proc analyse(c: var AnalysisCtx; n: PNode) =
    ## Main tree walker of the first pass: records aliases, strides, lower
    ## bounds, slices and spawn arguments in `c` and keeps `c.guards` in
    ## sync with the control flow.
    case n.kind
    of nkAsgn, nkFastAsgn:
      let target = n[0]
      let source = n[1]
      let y = source.skipConv
      if target.isSingleAssignable and y.isLocal:
        let slot = c.getSlot(y.sym)
        slot.alias = target.sym
      elif target.isLocal:
        # since we already ensure sfAddrTaken is not in s.flags, we only need to
        # prevent direct assignments to the monotonic variable:
        let slot = c.getSlot(target.sym)
        slot.blacklisted = true
      invalidateFacts(c.guards, target)
      if getMagic(source) == mSpawn:
        pushSpawnId(c):
          gatherArgs(c, source[1])
          analyseSons(c, source[1])
          analyse(c, target)
      else:
        analyseSons(c, n)
      addAsgnFact(c.guards, target, y)
    of nkCallKinds:
      # direct call:
      if n[0].kind != nkSym: analyseSons(c, n)
      else: analyseCall(c, n, n[0].sym)
    of nkBracketExpr:
      # a plain index 'a[i]' is the one-element slice i..i
      c.addSlice(n, n[0], n[1], n[1])
      analyseSons(c, n)
    of nkReturnStmt, nkRaiseStmt, nkTryStmt:
      localError(n.info, "invalid control flow for 'parallel'")
      # 'break' that leaves the 'parallel' section is not valid either
      # or maybe we should generate a 'try' XXX
    of nkVarSection, nkLetSection:
      for entry in n:
        let init = entry.lastSon
        let spawned = getMagic(init) == mSpawn
        if spawned:
          pushSpawnId(c):
            gatherArgs(c, init[1])
            analyseSons(c, init[1])
        if init.kind != nkEmpty:
          # the initial value is the lower bound of every declared local
          for j in 0 .. entry.len-3:
            let declared = entry[j]
            if declared.isLocal:
              let slot = c.getSlot(declared.sym)
              if not slot.lower.isNil:
                internalError(entry.info, "slot already has a lower bound")
              else:
                slot.lower = init
          if not spawned: analyse(c, init)
    of nkCaseStmt:
      analyseCase(c, n)
    of nkWhen, nkIfStmt, nkIfExpr:
      analyseIf(c, n)
    of nkWhileStmt:
      let cond = n[0]
      let loopBody = n[1]
      analyse(c, cond)
      inc c.inLoop
      # 'while true' loop?
      if isTrue(cond):
        analyseSons(c, loopBody)
      else:
        # loop may never execute:
        let localsMark = c.locals.len
        let factsMark = c.guards.len
        addFact(c.guards, canon(cond))
        analyse(c, loopBody)
        setLen(c.locals, localsMark)
        setLen(c.guards, factsMark)
        # we know after the loop the negation holds:
        if not hasSubnodeWith(loopBody, nkBreakStmt):
          addFactNeg(c.guards, canon(cond))
      dec c.inLoop
    of nkTypeSection, nkProcDef, nkConverterDef, nkMethodDef, nkIteratorDef,
        nkMacroDef, nkTemplateDef, nkConstSection, nkPragma, nkFuncDef:
      # declarations are not analysed
      discard
    else:
      analyseSons(c, n)
  proc transformSlices(n: PNode): PNode =
    ## Rewrites every call to system's `[]` inside `n` into a call of the
    ## `mSlice` magic taking (container, first, last); all other nodes are
    ## copied shallowly with their children transformed.
    if n.kind in nkCallKinds and n[0].kind == nkSym:
      let callee = n[0].sym
      if callee.name.s == "[]" and callee.fromSystem:
        result = copyNode(n)
        let sliceSym = newSymNode(createMagic("slice", mSlice))
        sliceSym.typ = getSysType(tyInt)
        result.add sliceSym
        result.add n[1]
        let bounds = n[2].skipStmtList
        result.add bounds[1]
        result.add bounds[2]
        return

    if n.safeLen == 0:
      return n
    result = shallowCopy(n)
    for i in 0 ..< n.len:
      result.sons[i] = transformSlices(n[i])
  proc transformSpawn(owner: PSym; n, barrier: PNode): PNode

  proc transformSpawnSons(owner: PSym; n, barrier: PNode): PNode =
    ## Shallow copy of `n` in which every child went through
    ## `transformSpawn`.
    result = shallowCopy(n)
    for idx in 0 ..< n.len:
      result.sons[idx] = transformSpawn(owner, n[idx], barrier)
  proc transformSpawn(owner: PSym; n, barrier: PNode): PNode =
    ## Second pass: replaces every 'spawn' in `n` by the wrapper produced by
    ## `wrapProcForSpawn`, passing `barrier` along.
    case n.kind
    of nkVarSection, nkLetSection:
      result = nil
      for entry in n:
        let init = entry.lastSon
        if getMagic(init) != mSpawn: continue
        if entry.len != 3: localError(entry.info, "invalid context for 'spawn'")
        let sliced = transformSlices(init)
        if result.isNil:
          # the section itself comes first; by-var spawns are appended
          result = newNodeI(nkStmtList, n.info)
          result.add n
        let retType = init[1][0].typ.sons[0]
        let last = entry.len - 1
        if spawnResult(retType, true) == srByVar:
          # the spawn writes into the declared variable, so the
          # declaration loses its initializer
          result.add wrapProcForSpawn(owner, sliced, init.typ, barrier,
                                      entry[0])
          entry.sons[last] = emptyNode
        else:
          entry.sons[last] = wrapProcForSpawn(owner, sliced, init.typ,
                                              barrier, nil)
      if result.isNil: result = n
    of nkAsgn, nkFastAsgn:
      let rhs = n[1]
      if getMagic(rhs) == mSpawn:
        let retType = rhs[1][0].typ.sons[0]
        if spawnResult(retType, true) == srByVar:
          let sliced = transformSlices(rhs)
          return wrapProcForSpawn(owner, sliced, rhs.typ, barrier, n[0])
      result = transformSpawnSons(owner, n, barrier)
    of nkCallKinds:
      if getMagic(n) == mSpawn:
        result = wrapProcForSpawn(owner, transformSlices(n), n.typ,
                                  barrier, nil)
      else:
        result = transformSpawnSons(owner, n, barrier)
    elif n.safeLen > 0:
      result = transformSpawnSons(owner, n, barrier)
    else:
      result = n
  proc checkArgs(a: var AnalysisCtx; n: PNode) =
    ## Placeholder: currently performs no check.
    discard "too implement"
  proc generateAliasChecks(a: AnalysisCtx; result: PNode) =
    ## Placeholder: currently adds nothing to `result`.
    discard "too implement"
  proc liftParallel*(owner: PSym; n: PNode): PNode =
    ## Checks the body of the 'parallel' statement `n` and returns a
    ## statement list that declares a barrier, opens it, runs the
    ## transformed body and closes the barrier again.
    # this needs to be called after the 'for' loop elimination

    # first pass:
    # - detect monotonic local integer variables
    # - detect used slices
    # - detect used arguments
    #echo "PAR ", renderTree(n)
    let body = n.lastSon
    var ctx = initAnalysisCtx()
    analyse(ctx, body)
    if ctx.spawns == 0:
      localError(n.info, "'parallel' section without 'spawn'")
    checkSlicesAreDisjoint(ctx)
    checkArgs(ctx, body)

    # declare the temporary that holds the barrier:
    let barrierSym = newSym(skTemp, getIdent"barrier", owner, n.info)
    barrierSym.typ = magicsys.getCompilerProc("Barrier").typ
    incl(barrierSym.flags, sfFromGeneric)
    let barrierVar = newSymNode(barrierSym)
    let declaration = newNodeI(nkVarSection, n.info)
    declaration.addVar barrierVar

    let barrier = genAddrOf(barrierVar)
    result = newNodeI(nkStmtList, n.info)
    generateAliasChecks(ctx, result)
    result.add declaration
    result.add callCodegenProc("openBarrier", barrier)
    result.add transformSpawn(owner, body, barrier)
    result.add callCodegenProc("closeBarrier", barrier)