semparallel.nim 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. #
  2. #
  3. # The Nim Compiler
  4. # (c) Copyright 2015 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  ## Semantic checking for 'parallel'.

  # Implementation status -- items marked (+) are already handled, the others
  # are still open:
  # - codegen needs to support mSlice (+)
  # - lowerings must not perform unnecessary copies (+)
  # - slices should become "nocopy" to openArray (+)
  # - need to perform bound checks (+)
  #
  # - parallel needs to insert a barrier (+)
  # - passed arguments need to be ensured to be "const"
  # - what about 'f(a)'? --> f shouldn't have side effects anyway
  # - passed arrays need to be ensured not to alias
  # - passed slices need to be ensured to be disjoint (+)
  # - output slices need special logic (+)
  21. import
  22. ast, astalgo, idents, lowerings, magicsys, guards, msgs,
  23. renderer, types, modulegraphs, options, spawn, lineinfos
  24. from trees import getMagic, isTrue, getRoot
  25. from strutils import `%`
  # Design note, kept as a discarded string literal. It explains why
  # `checkLocal` rejects any use of a counter inside a slice once a stride
  # (an increment) has been recorded for that counter.
  discard """
  one major problem:
    spawn f(a[i])
    inc i
    spawn f(a[i])
  is valid, but
    spawn f(a[i])
    spawn f(a[i])
    inc i
  is not! However,
    spawn f(a[i])
    if guard: inc i
    spawn f(a[i])
  is not valid either! --> We need a flow dependent analysis here.
  However:
    while foo:
      spawn f(a[i])
      inc i
      spawn f(a[i])
  Is not valid either! --> We should really restrict 'inc' to loop endings?
  The heuristic that we implement here (that has no false positives) is: Usage
  of 'i' in a slice *after* we determined the stride is invalid!
  """
  type
    TDirection = enum
      ascending, descending
    # Per-variable record for a local counter inside a 'parallel' section.
    MonotonicVar = object
      # `v` is the tracked local. `alias` is a single-assignment local bound
      # to it (to support the ordinary 'countup' iterator we need to detect
      # aliases); `lookupSlot` finds the slot by either name.
      v, alias: PSym
      # `lower`: initial value taken from the var/let section.
      # `stride`: smallest positive literal increment seen for the counter.
      # `upper` is not assigned by the code in this file.
      lower, upper, stride: PNode
      dir: TDirection # not assigned by the code in this file
      blacklisted: bool # set on a direct assignment; such variables are not
                        # treated as monotonic
    AnalysisCtx = object
      locals: seq[MonotonicVar]
      # Array `x` accessed in the (canonicalized) range a..b, tagged with the
      # id of the spawn it belongs to and whether it occurred inside a loop.
      slices: seq[tuple[x,a,b: PNode, spawnId: int, inLoop: bool]]
      guards: TModel # nested guards
      args: seq[PSym] # args must be deeply immutable
      spawns: int # we can check that at least 1 spawn is used in
                  # the 'parallel' section
      currentSpawnId: int # id of the spawn being analysed; 0 outside a spawn
      inLoop: int # nesting depth of the 'while' loops being analysed
      graph: ModuleGraph
  proc initAnalysisCtx(g: ModuleGraph): AnalysisCtx =
    ## Creates an empty analysis context bound to the module graph `g`; the
    ## guard model shares that same graph.
    result.graph = g
    result.guards.g = g
    result.guards.s = @[]
    result.args = @[]
    result.slices = @[]
    result.locals = @[]
  proc lookupSlot(c: AnalysisCtx; s: PSym): int =
    ## Index of the slot tracking `s`, either as the variable itself or as its
    ## alias; -1 when `s` is not tracked.
    result = -1
    for idx in countup(0, c.locals.high):
      if s == c.locals[idx].v or s == c.locals[idx].alias:
        return idx
  proc getSlot(c: var AnalysisCtx; v: PSym): ptr MonotonicVar =
    ## Pointer to the slot tracking `v`. A fresh slot for `v` is appended when
    ## none exists yet. The pointer refers into `c.locals`, so it must not be
    ## kept across further growth of that seq.
    var idx = lookupSlot(c, v)
    if idx < 0:
      c.locals.add MonotonicVar(v: v)
      idx = c.locals.high
    result = addr(c.locals[idx])
  proc gatherArgs(c: var AnalysisCtx; n: PNode) =
    ## Recursively records the root symbol of every sub-expression of `n` in
    ## `c.args`, without duplicates.
    for i in 0..<n.safeLen:
      let child = n[i]
      let root = getRoot(child)
      if root != nil and root notin c.args:
        c.args.add root
      gatherArgs(c, child)
  proc isSingleAssignable(n: PNode): bool =
    ## True for a symbol node naming a temp, for-variable or `let` that is
    ## neither global nor has its address taken.
    if n.kind != nkSym:
      return false
    let s = n.sym
    result = s.kind in {skTemp, skForVar, skLet} and
      {sfAddrTaken, sfGlobal} * s.flags == {}
  proc isLocal(n: PNode): bool =
    ## True for a symbol node naming a `result`, temp, for-variable, `var` or
    ## `let` that is neither global nor has its address taken.
    if n.kind != nkSym:
      return false
    let s = n.sym
    result = s.kind in {skResult, skTemp, skForVar, skVar, skLet} and
      {sfAddrTaken, sfGlobal} * s.flags == {}
  proc checkLocal(c: AnalysisCtx; n: PNode) =
    ## Reports an error for every tracked local in `n` that already has a
    ## recorded stride, i.e. a counter that is used after its increment.
    if not isLocal(n):
      for i in 0..<n.safeLen:
        checkLocal(c, n[i])
      return
    let idx = lookupSlot(c, n.sym)
    if idx >= 0 and not c.locals[idx].stride.isNil:
      localError(c.graph.config, n.info, "invalid usage of counter after increment")
  # Shorthand for diagnostics: renders the AST `x` as text.
  template `?`(x): untyped = renderTree(x)
  proc checkLe(c: AnalysisCtx; a, b: PNode) =
    ## Tries to prove `a <= b` under the current guards. Emits a
    ## `warnStaticIndexCheck` message when the proof fails or when the
    ## opposite can be proven.
    let conf = c.graph.config
    case proveLe(c.guards, a, b)
    of impYes:
      discard
    of impNo:
      message(conf, a.info, warnStaticIndexCheck,
        "can prove: " & ?a & " > " & ?b)
    of impUnknown:
      message(conf, a.info, warnStaticIndexCheck,
        "cannot prove: " & ?a & " <= " & ?b)
  proc checkBounds(c: AnalysisCtx; arr, idx: PNode) =
    ## Statically checks `low(arr) <= idx` and then `idx <= high(arr)`.
    let conf = c.graph.config
    let lo = lowBound(conf, arr)
    c.checkLe(lo, idx)
    let hi = highBound(conf, arr, c.graph.operators)
    c.checkLe(idx, hi)
  proc addLowerBoundAsFacts(c: var AnalysisCtx) =
    ## Adds the fact `lower <= v` to the guard model for every tracked local
    ## that is not blacklisted and whose lower bound is known.
    for v in c.locals:
      # `getSlot` also creates slots for counters that are only incremented
      # or aliased inside the section but declared outside it; those never
      # get a `lower`. There is no fact to add for them, and handing a nil
      # node to `addFactLe` is not valid (NOTE(review): confirm in guards.nim).
      if not v.blacklisted and v.lower != nil:
        c.guards.addFactLe(v.lower, newSymNode(v.v))
  proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
    ## Records that `x[le..ri]` is accessed by the current spawn. It first
    ## rejects counters in `n` that are used after their increment. It then
    ## statically bounds-checks both canonicalized ends, with the locals'
    ## lower bounds temporarily added to the guards.
    checkLocal(c, n)
    let lo = canon(le, c.graph.operators)
    let hi = canon(ri, c.graph.operators)
    # perform static bounds checking here; and not later!
    let savedFacts = c.guards.s.len
    c.addLowerBoundAsFacts()
    c.checkBounds(x, lo)
    c.checkBounds(x, hi)
    c.guards.s.setLen(savedFacts)
    c.slices.add((x: x, a: lo, b: hi, spawnId: c.currentSpawnId,
                  inLoop: c.inLoop > 0))
  proc overlap(m: TModel; conf: ConfigRef; x,y,c,d: PNode) =
    ## Emits a `warnStaticIndexCheck` message unless the ranges `x..y` and
    ## `c..d` can be proven disjoint under the facts in `m`.
    # X..Y and C..D overlap iff (X <= D and C <= Y), so disjointness needs a
    # proof of either C > Y or X > D.
    case proveLe(m, c, y)
    of impUnknown:
      case proveLe(m, x, d)
      of impNo: discard
      of impUnknown, impYes:
        message(conf, x.info, warnStaticIndexCheck,
          "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
            [?c, ?y, ?x, ?y, ?c, ?d])
    of impYes:
      case proveLe(m, x, d)
      of impUnknown:
        message(conf, x.info, warnStaticIndexCheck,
          "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
            [?x, ?d, ?x, ?y, ?c, ?d])
      of impYes:
        # Both C <= Y and X <= D hold: the ranges provably overlap. The
        # message has four placeholders, so it must receive the two ranges
        # in order (the old argument list rendered "(c)..(y)" and "(x)..(y)").
        message(conf, x.info, warnStaticIndexCheck,
          "($#)..($#) not disjoint from ($#)..($#)" % [?x, ?y, ?c, ?d])
      of impNo: discard
    of impNo: discard
  proc stride(c: AnalysisCtx; n: PNode): BiggestInt =
    ## Sum of the recorded strides of all tracked locals occurring in `n`.
    ## Locals without a stride, and any other leaf, contribute 0.
    result = 0
    if not isLocal(n):
      for i in 0..<n.safeLen:
        result += stride(c, n[i])
      return
    let idx = lookupSlot(c, n.sym)
    if idx >= 0:
      let step = c.locals[idx].stride
      if not step.isNil:
        result = step.intVal
  proc subStride(c: AnalysisCtx; n: PNode): PNode =
    ## Returns `n` with every tracked local that has a stride `s` replaced by
    ## `local + s`, i.e. its value after one more increment. Inner nodes are
    ## shallow-copied; leaves are returned unchanged.
    if isLocal(n):
      let idx = lookupSlot(c, n.sym)
      if idx < 0 or c.locals[idx].stride == nil:
        return n
      return buildAdd(n, c.locals[idx].stride.intVal, c.graph.operators)
    if n.safeLen == 0:
      return n
    result = shallowCopy(n)
    for i in 0..<n.len:
      result[i] = subStride(c, n[i])
  proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
    ## Checks the recorded slices after the whole section has been analysed:
    ## a slice used inside a loop must be disjoint from itself shifted by the
    ## counters' strides, and slices of the same array that belong to
    ## different spawns must be disjoint from each other. Unproven cases are
    ## reported as warnings (via `overlap`) or errors (`localError`).
    # this is the only thing that we need to perform after we have traversed
    # the whole tree so that the strides are available.
    # First we need to add all the computed lower bounds:
    addLowerBoundAsFacts(c)
    # Every slice used in a loop needs to be disjoint with itself (shifted by
    # one increment of its counters):
    for x,a,b,id,inLoop in items(c.slices):
      if inLoop: overlap(c.guards, c.graph.config, a,b, c.subStride(a), c.subStride(b))
    # Another tricky example is:
    #   while true:
    #     spawn f(a[i])
    #     spawn f(a[i+1])
    #     inc i  # inc i, 2 would be correct here
    #
    # Or even worse:
    #   while true:
    #     spawn f(a[i+1..i+3])
    #     spawn f(a[i+4..i+5])
    #     inc i, 4
    # Prove that i*k*stride + 3 != i*k'*stride + 5
    # For the correct example this amounts to
    #   i*k*2 != i*k'*2 + 1
    # which is true.
    # For now, we don't try to prove things like that at all, even though it'd
    # be feasible for many useful examples. Instead we attach the slice to
    # a spawn and if the attached spawns differ, we bail out:
    for i in 0..high(c.slices):
      for j in i+1..high(c.slices):
        let x = c.slices[i]
        let y = c.slices[j]
        # only slices of the same array coming from different spawns conflict
        if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x):
          if not x.inLoop or not y.inLoop:
            # XXX strictly speaking, 'or' is not correct here and it needs to
            # be 'and'. However this prevents too many obviously correct programs
            # like f(a[0..x]); for i in x+1..a.high: f(a[i])
            overlap(c.guards, c.graph.config, x.a, x.b, y.a, y.b)
          elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b);
                k >= 0 and m >= 0):
            # ah I cannot resist the temptation and add another sweet heuristic:
            # if both slices have the form (i+k)..(i+k) and (i+m)..(i+m) we
            # check they are disjoint and k < stride and m < stride:
            overlap(c.guards, c.graph.config, x.a, x.b, y.a, y.b)
            let stride = min(c.stride(x.a), c.stride(y.a))
            if k < stride and m < stride:
              discard
            else:
              localError(c.graph.config, x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
                [?x.a, ?x.b, ?y.a, ?y.b])
          else:
            # both slices are inside loops and do not match the heuristic
            localError(c.graph.config, x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
              [?x.a, ?x.b, ?y.a, ?y.b])
  proc analyse(c: var AnalysisCtx; n: PNode)

  proc analyseSons(c: var AnalysisCtx; n: PNode) =
    ## Runs `analyse` on every child of `n`; does nothing for leaf nodes.
    for i in countup(0, n.safeLen - 1):
      analyse(c, n[i])
  proc min(a, b: PNode): PNode =
    ## Returns whichever integer-literal node has the smaller `intVal`. A nil
    ## `a` (no stride recorded yet) yields `b`, and so does a tie.
    result = b
    if not a.isNil and a.intVal < b.intVal:
      result = a
  # Runs `body` under a fresh spawn id (the new value of `c.spawns`) and
  # restores the previous id afterwards. Being `dirty`, it injects the
  # `savedSpawnId` local into the caller's scope.
  template pushSpawnId(c, body) {.dirty.} =
    c.spawns += 1
    let savedSpawnId = c.currentSpawnId
    c.currentSpawnId = c.spawns
    body
    c.currentSpawnId = savedSpawnId
  proc analyseCall(c: var AnalysisCtx; n: PNode; op: PSym) =
    ## Analyses the direct call `n` of routine `op`:
    ## - `spawn`: gathers the argument roots and analyses the call under a
    ##   fresh spawn id;
    ## - `inc` / system `+=` of a local by a positive integer literal: keeps
    ##   the smallest such literal as the local's stride;
    ## - system `[]` / `[]=` with an `a..b` argument: registers the slice;
    ## - anything else: analyses the children.
    if op.magic == mSpawn:
      pushSpawnId(c):
        gatherArgs(c, n[1])
        analyseSons(c, n)
    elif op.magic == mInc or (op.name.s == "+=" and op.fromSystem):
      if n[1].isLocal:
        let incr = n[2].skipConv
        if incr.kind in {nkCharLit..nkUInt32Lit} and incr.intVal > 0:
          let slot = c.getSlot(n[1].sym)
          slot.stride = min(slot.stride, incr)
      analyseSons(c, n)
    elif (op.name.s == "[]" or op.name.s == "[]=") and op.fromSystem:
      let slice = n[2].skipStmtList
      if slice.kind in nkCallKinds and slice.len == 3:
        c.addSlice(n, n[1], slice[1], slice[2])
      else:
        # Not an explicit `a..b` construction (e.g. a slice stored in a
        # variable, or a BackwardsIndex). Its bounds cannot be read off the
        # tree; reading `slice[1]`/`slice[2]` here used to access children
        # that do not exist. Reject it instead of skipping the disjointness
        # analysis.
        localError(c.graph.config, n.info,
          "cannot analyse the slice in '" & ?n & "' for 'parallel'")
      analyseSons(c, n)
    else:
      analyseSons(c, n)
  proc analyseCase(c: var AnalysisCtx; n: PNode) =
    ## Analyses a `case` statement. The selector is analysed first; each
    ## branch is then analysed with only that branch's facts (from
    ## `addCaseBranchFacts`) added to the guards. All added facts are dropped
    ## afterwards.
    analyse(c, n[0])
    let factsBefore = c.guards.s.len
    for branchIdx in 1..<n.len:
      setLen(c.guards.s, factsBefore)
      addCaseBranchFacts(c.guards, n, branchIdx)
      let branch = n[branchIdx]
      for j in 0..<branch.len:
        analyse(c, branch[j])
    setLen(c.guards.s, factsBefore)
  proc analyseIf(c: var AnalysisCtx; n: PNode) =
    ## Analyses an `if`/`when` statement or an `if` expression. The first
    ## branch's body is analysed assuming its condition. Every later branch
    ## assumes the negation of all earlier conditions, plus its own condition
    ## when it has one (`elif`). All added facts are dropped afterwards.
    let firstBranch = n[0]
    analyse(c, firstBranch[0])
    let factsBefore = c.guards.s.len
    addFact(c.guards, canon(firstBranch[0], c.graph.operators))
    analyse(c, firstBranch[1])
    for i in 1..<n.len:
      let branch = n[i]
      setLen(c.guards.s, factsBefore)
      for prev in 0..<i:
        addFactNeg(c.guards, canon(n[prev][0], c.graph.operators))
      let hasCondition = branch.len > 1
      if hasCondition:
        addFact(c.guards, canon(branch[0], c.graph.operators))
      for j in 0..<branch.len:
        analyse(c, branch[j])
    setLen(c.guards.s, factsBefore)
  proc analyse(c: var AnalysisCtx; n: PNode) =
    ## First pass over a 'parallel' body. Records counters (lower bound,
    ## stride, aliases, blacklisting), slices together with the spawn they
    ## belong to, the arguments of spawns, and the guard facts that hold at
    ## each point. Reports `return`, `raise` and `try` as invalid control flow.
    case n.kind
    of nkAsgn, nkFastAsgn, nkSinkAsgn:
      let y = n[1].skipConv
      if n[0].isSingleAssignable and y.isLocal:
        # `x = counter` with a single-assignment `x`: record `x` as an alias
        # of the counter.
        let slot = c.getSlot(y.sym)
        slot.alias = n[0].sym
      elif n[0].isLocal:
        # since we already ensure sfAddrTaken is not in s.flags, we only need to
        # prevent direct assignments to the monotonic variable:
        let slot = c.getSlot(n[0].sym)
        slot.blacklisted = true
      invalidateFacts(c.guards, n[0])
      let value = n[1]
      if getMagic(value) == mSpawn:
        # `dest = spawn f(...)`: the arguments and the destination are
        # analysed under the same spawn id.
        pushSpawnId(c):
          gatherArgs(c, value[1])
          analyseSons(c, value[1])
          analyse(c, n[0])
      else:
        analyseSons(c, n)
      addAsgnFact(c.guards, n[0], y)
    of nkCallKinds:
      # direct call:
      if n[0].kind == nkSym: analyseCall(c, n, n[0].sym)
      else: analyseSons(c, n)
    of nkBracketExpr:
      # indexing (except tuple field access) is recorded as the one-element
      # slice x[i..i]
      if n[0].typ != nil and skipTypes(n[0].typ, abstractVar).kind != tyTuple:
        c.addSlice(n, n[0], n[1], n[1])
      analyseSons(c, n)
    of nkReturnStmt, nkRaiseStmt, nkTryStmt, nkHiddenTryStmt:
      localError(c.graph.config, n.info, "invalid control flow for 'parallel'")
      # 'break' that leaves the 'parallel' section is not valid either
      # or maybe we should generate a 'try' XXX
    of nkVarSection, nkLetSection:
      for it in n:
        let value = it.lastSon
        let isSpawned = getMagic(value) == mSpawn
        if isSpawned:
          pushSpawnId(c):
            gatherArgs(c, value[1])
            analyseSons(c, value[1])
        if value.kind != nkEmpty:
          # the initial value becomes the lower bound of each declared local
          for j in 0..<it.len-2:
            if it[j].isLocal:
              let slot = c.getSlot(it[j].sym)
              if slot.lower.isNil: slot.lower = value
              else: internalError(c.graph.config, it.info, "slot already has a lower bound")
          if not isSpawned: analyse(c, value)
    of nkCaseStmt: analyseCase(c, n)
    of nkWhen, nkIfStmt, nkIfExpr: analyseIf(c, n)
    of nkWhileStmt:
      analyse(c, n[0])
      # 'while true' loop?
      inc c.inLoop
      if isTrue(n[0]):
        analyseSons(c, n[1])
      else:
        # loop may never execute: locals first seen inside the body and the
        # facts added for it are discarded afterwards
        let oldState = c.locals.len
        let oldFacts = c.guards.s.len
        addFact(c.guards, canon(n[0], c.graph.operators))
        analyse(c, n[1])
        setLen(c.locals, oldState)
        setLen(c.guards.s, oldFacts)
        # we know after the loop the negation holds (unless a 'break' may
        # leave it early):
        if not hasSubnodeWith(n[1], nkBreakStmt):
          addFactNeg(c.guards, canon(n[0], c.graph.operators))
      dec c.inLoop
    of nkTypeSection, nkProcDef, nkConverterDef, nkMethodDef, nkIteratorDef,
        nkMacroDef, nkTemplateDef, nkConstSection, nkPragma, nkFuncDef,
        nkMixinStmt, nkBindStmt, nkExportStmt:
      # declarations contain nothing to analyse
      discard
    else:
      analyseSons(c, n)
  proc transformSlices(g: ModuleGraph; idgen: IdGenerator; n: PNode): PNode =
    ## Rewrites every system `[]` call with an explicit `a..b` argument inside
    ## `n` into a call of the `slice` magic (mSlice) that takes the array and
    ## both bounds. Other nodes are shallow-copied recursively, and leaves are
    ## returned unchanged.
    if n.kind in nkCallKinds and n[0].kind == nkSym:
      let op = n[0].sym
      if op.name.s == "[]" and op.fromSystem:
        let slice = n[2].skipStmtList
        # Only an `a..b` construction exposes both bounds. Any other index
        # argument (a slice held in a variable, a BackwardsIndex, ...) lacks
        # `slice[1]`/`slice[2]`, so it is kept as an ordinary call below.
        if slice.kind in nkCallKinds and slice.len == 3:
          result = copyNode(n)
          let opSlice = newSymNode(createMagic(g, idgen, "slice", mSlice))
          opSlice.typ = getSysType(g, n.info, tyInt)
          result.add opSlice
          result.add n[1]
          result.add slice[1]
          result.add slice[2]
          return result
    if n.safeLen > 0:
      result = shallowCopy(n)
      for i in 0..<n.len:
        result[i] = transformSlices(g, idgen, n[i])
    else:
      result = n
  proc transformSpawn(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n, barrier: PNode): PNode

  proc transformSpawnSons(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n, barrier: PNode): PNode =
    ## Shallow copy of `n` in which each child has been passed through
    ## `transformSpawn`.
    result = shallowCopy(n)
    for i in countup(0, n.len - 1):
      result[i] = transformSpawn(g, idgen, owner, n[i], barrier)
  proc transformSpawn(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n, barrier: PNode): PNode =
    ## Second pass: returns `n` with every `spawn` replaced by the code from
    ## `wrapProcForSpawn`, after applying `transformSlices` to the spawned call.
    ## - In a var/let section, a spawn whose result is passed by var
    ##   (`srByVar`) has its initial value emptied, and the wrapped spawn is
    ##   appended after the section. Other spawns are wrapped in place. The
    ##   section node itself is modified.
    ## - In an `srByVar` assignment, the whole assignment is replaced.
    case n.kind
    of nkVarSection, nkLetSection:
      result = nil
      for it in n:
        let b = it.lastSon
        if getMagic(b) == mSpawn:
          if it.len != 3: localError(g.config, it.info, "invalid context for 'spawn'")
          let m = transformSlices(g, idgen, b)
          if result.isNil:
            # first spawn in this section: wrap the section in a statement
            # list so that code can be appended after it
            result = newNodeI(nkStmtList, n.info)
            result.add n
          let t = b[1][0].typ[0]
          if spawnResult(t, true) == srByVar:
            result.add wrapProcForSpawn(g, idgen, owner, m, b.typ, barrier, it[0])
            it[^1] = newNodeI(nkEmpty, it.info)
          else:
            it[^1] = wrapProcForSpawn(g, idgen, owner, m, b.typ, barrier, nil)
      if result.isNil: result = n
    of nkAsgn, nkFastAsgn, nkSinkAsgn:
      let b = n[1]
      if getMagic(b) == mSpawn and (let t = b[1][0].typ[0];
          spawnResult(t, true) == srByVar):
        let m = transformSlices(g, idgen, b)
        return wrapProcForSpawn(g, idgen, owner, m, b.typ, barrier, n[0])
      result = transformSpawnSons(g, idgen, owner, n, barrier)
    of nkCallKinds:
      if getMagic(n) == mSpawn:
        result = transformSlices(g, idgen, n)
        return wrapProcForSpawn(g, idgen, owner, result, n.typ, barrier, nil)
      result = transformSpawnSons(g, idgen, owner, n, barrier)
    elif n.safeLen > 0:
      result = transformSpawnSons(g, idgen, owner, n, barrier)
    else:
      result = n
  proc checkArgs(a: var AnalysisCtx; n: PNode) =
    ## Placeholder, currently a no-op. Presumably meant to verify that the
    ## spawn arguments gathered in `a.args` are deeply immutable.
    ## TODO: not implemented.
    discard "to implement"
  proc generateAliasChecks(a: AnalysisCtx; result: PNode) =
    ## Placeholder, currently a no-op. Presumably meant to add checks to
    ## `result` ensuring that arrays passed to spawns do not alias.
    ## TODO: not implemented.
    discard "to implement"
  proc liftParallel*(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n: PNode): PNode =
    ## Lowers the 'parallel' section `n`, whose body is `n.lastSon`; this must
    ## run after the 'for' loop elimination. The body is first analysed for
    ## counters, slices and spawn arguments, and diagnostics are reported (no
    ## spawn, slices not provably disjoint). The result is a statement list
    ## that:
    ## - declares a `Barrier` temporary;
    ## - opens the barrier;
    ## - runs the body with every `spawn` wrapped;
    ## - closes the barrier.
    let body = n.lastSon
    var ctx = initAnalysisCtx(g)
    analyse(ctx, body)
    if ctx.spawns == 0:
      localError(g.config, n.info, "'parallel' section without 'spawn'")
    checkSlicesAreDisjoint(ctx)
    checkArgs(ctx, body)

    # The barrier is a compiler-generated temporary; spawns receive its address.
    let varSection = newNodeI(nkVarSection, n.info)
    let barrierSym = newSym(skTemp, getIdent(g.cache, "barrier"), nextSymId(idgen),
                            owner, n.info)
    barrierSym.typ = magicsys.getCompilerProc(g, "Barrier").typ
    barrierSym.flags.incl sfFromGeneric
    let barrierNode = newSymNode(barrierSym)
    varSection.addVar barrierNode
    let barrier = genAddrOf(barrierNode, idgen)

    result = newNodeI(nkStmtList, n.info)
    generateAliasChecks(ctx, result)
    result.add varSection
    result.add callCodegenProc(g, "openBarrier", barrier.info, barrier)
    result.add transformSpawn(g, idgen, owner, body, barrier)
    result.add callCodegenProc(g, "closeBarrier", barrier.info, barrier)