semparallel.nim 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. #
  2. #
  3. # The Nim Compiler
  4. # (c) Copyright 2015 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  9. ## Semantic checking for 'parallel'.
  10. # - codegen needs to support mSlice (+)
  11. # - lowerings must not perform unnecessary copies (+)
  12. # - slices should become "nocopy" to openArray (+)
  13. # - need to perform bound checks (+)
  14. #
  15. # - parallel needs to insert a barrier (+)
  16. # - passed arguments need to be ensured to be "const"
  17. # - what about 'f(a)'? --> f shouldn't have side effects anyway
  18. # - passed arrays need to be ensured not to alias
  19. # - passed slices need to be ensured to be disjoint (+)
  20. # - output slices need special logic (+)
  21. import
  22. ast, astalgo, idents, lowerings, magicsys, guards, sempass2, msgs,
  23. renderer, types, modulegraphs, options
  24. from trees import getMagic
  25. from strutils import `%`
# Design notes on the counter/stride analysis implemented below. They are
# kept in a discarded string literal, so they have no effect on the
# compiled code.
discard """
one major problem:
spawn f(a[i])
inc i
spawn f(a[i])
is valid, but
spawn f(a[i])
spawn f(a[i])
inc i
is not! However,
spawn f(a[i])
if guard: inc i
spawn f(a[i])
is not valid either! --> We need a flow dependent analysis here.
However:
while foo:
spawn f(a[i])
inc i
spawn f(a[i])
Is not valid either! --> We should really restrict 'inc' to loop endings?
The heuristic that we implement here (that has no false positives) is: Usage
of 'i' in a slice *after* we determined the stride is invalid!
"""
  49. type
  50. TDirection = enum
  51. ascending, descending
  52. MonotonicVar = object
  53. v, alias: PSym # to support the ordinary 'countup' iterator
  54. # we need to detect aliases
  55. lower, upper, stride: PNode
  56. dir: TDirection
  57. blacklisted: bool # blacklisted variables that are not monotonic
  58. AnalysisCtx = object
  59. locals: seq[MonotonicVar]
  60. slices: seq[tuple[x,a,b: PNode, spawnId: int, inLoop: bool]]
  61. guards: TModel # nested guards
  62. args: seq[PSym] # args must be deeply immutable
  63. spawns: int # we can check that at last 1 spawn is used in
  64. # the 'parallel' section
  65. currentSpawnId: int
  66. inLoop: int
  67. graph: ModuleGraph
  68. proc initAnalysisCtx(g: ModuleGraph): AnalysisCtx =
  69. result.locals = @[]
  70. result.slices = @[]
  71. result.args = @[]
  72. result.guards.s = @[]
  73. result.guards.o = initOperators(g)
  74. result.graph = g
  75. proc lookupSlot(c: AnalysisCtx; s: PSym): int =
  76. for i in 0..<c.locals.len:
  77. if c.locals[i].v == s or c.locals[i].alias == s: return i
  78. return -1
  79. proc getSlot(c: var AnalysisCtx; v: PSym): ptr MonotonicVar =
  80. let s = lookupSlot(c, v)
  81. if s >= 0: return addr(c.locals[s])
  82. let L = c.locals.len
  83. c.locals.setLen(L+1)
  84. c.locals[L].v = v
  85. return addr(c.locals[L])
  86. proc gatherArgs(c: var AnalysisCtx; n: PNode) =
  87. for i in 0..<n.safeLen:
  88. let root = getRoot n[i]
  89. if root != nil:
  90. block addRoot:
  91. for r in items(c.args):
  92. if r == root: break addRoot
  93. c.args.add root
  94. gatherArgs(c, n[i])
  95. proc isSingleAssignable(n: PNode): bool =
  96. n.kind == nkSym and (let s = n.sym;
  97. s.kind in {skTemp, skForVar, skLet} and
  98. {sfAddrTaken, sfGlobal} * s.flags == {})
  99. proc isLocal(n: PNode): bool =
  100. n.kind == nkSym and (let s = n.sym;
  101. s.kind in {skResult, skTemp, skForVar, skVar, skLet} and
  102. {sfAddrTaken, sfGlobal} * s.flags == {})
  103. proc checkLocal(c: AnalysisCtx; n: PNode) =
  104. if isLocal(n):
  105. let s = c.lookupSlot(n.sym)
  106. if s >= 0 and c.locals[s].stride != nil:
  107. localError(c.graph.config, n.info, "invalid usage of counter after increment")
  108. else:
  109. for i in 0 ..< n.safeLen: checkLocal(c, n.sons[i])
  110. template `?`(x): untyped = x.renderTree
  111. proc checkLe(c: AnalysisCtx; a, b: PNode) =
  112. case proveLe(c.guards, a, b)
  113. of impUnknown:
  114. localError(c.graph.config, a.info, "cannot prove: " & ?a & " <= " & ?b & " (bounds check)")
  115. of impYes: discard
  116. of impNo:
  117. localError(c.graph.config, a.info, "can prove: " & ?a & " > " & ?b & " (bounds check)")
  118. proc checkBounds(c: AnalysisCtx; arr, idx: PNode) =
  119. checkLe(c, lowBound(c.graph.config, arr), idx)
  120. checkLe(c, idx, highBound(c.graph.config, arr, c.guards.o))
  121. proc addLowerBoundAsFacts(c: var AnalysisCtx) =
  122. for v in c.locals:
  123. if not v.blacklisted:
  124. c.guards.addFactLe(v.lower, newSymNode(v.v))
  125. proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
  126. checkLocal(c, n)
  127. let le = le.canon(c.guards.o)
  128. let ri = ri.canon(c.guards.o)
  129. # perform static bounds checking here; and not later!
  130. let oldState = c.guards.s.len
  131. addLowerBoundAsFacts(c)
  132. c.checkBounds(x, le)
  133. c.checkBounds(x, ri)
  134. c.guards.s.setLen(oldState)
  135. c.slices.add((x, le, ri, c.currentSpawnId, c.inLoop > 0))
  136. proc overlap(m: TModel; conf: ConfigRef; x,y,c,d: PNode) =
  137. # X..Y and C..D overlap iff (X <= D and C <= Y)
  138. case proveLe(m, c, y)
  139. of impUnknown:
  140. case proveLe(m, x, d)
  141. of impNo: discard
  142. of impUnknown, impYes:
  143. localError(conf, x.info,
  144. "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
  145. [?c, ?y, ?x, ?y, ?c, ?d])
  146. of impYes:
  147. case proveLe(m, x, d)
  148. of impUnknown:
  149. localError(conf, x.info,
  150. "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
  151. [?x, ?d, ?x, ?y, ?c, ?d])
  152. of impYes:
  153. localError(conf, x.info, "($#)..($#) not disjoint from ($#)..($#)" %
  154. [?c, ?y, ?x, ?y, ?c, ?d])
  155. of impNo: discard
  156. of impNo: discard
  157. proc stride(c: AnalysisCtx; n: PNode): BiggestInt =
  158. if isLocal(n):
  159. let s = c.lookupSlot(n.sym)
  160. if s >= 0 and c.locals[s].stride != nil:
  161. result = c.locals[s].stride.intVal
  162. else:
  163. for i in 0 ..< n.safeLen: result += stride(c, n.sons[i])
  164. proc subStride(c: AnalysisCtx; n: PNode): PNode =
  165. # substitute with stride:
  166. if isLocal(n):
  167. let s = c.lookupSlot(n.sym)
  168. if s >= 0 and c.locals[s].stride != nil:
  169. result = buildAdd(n, c.locals[s].stride.intVal, c.guards.o)
  170. else:
  171. result = n
  172. elif n.safeLen > 0:
  173. result = shallowCopy(n)
  174. for i in 0 ..< n.len: result.sons[i] = subStride(c, n.sons[i])
  175. else:
  176. result = n
  177. proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
  178. # this is the only thing that we need to perform after we have traversed
  179. # the whole tree so that the strides are available.
  180. # First we need to add all the computed lower bounds:
  181. addLowerBoundAsFacts(c)
  182. # Every slice used in a loop needs to be disjoint with itself:
  183. for x,a,b,id,inLoop in items(c.slices):
  184. if inLoop: overlap(c.guards, c.graph.config, a,b, c.subStride(a), c.subStride(b))
  185. # Another tricky example is:
  186. # while true:
  187. # spawn f(a[i])
  188. # spawn f(a[i+1])
  189. # inc i # inc i, 2 would be correct here
  190. #
  191. # Or even worse:
  192. # while true:
  193. # spawn f(a[i+1 .. i+3])
  194. # spawn f(a[i+4 .. i+5])
  195. # inc i, 4
  196. # Prove that i*k*stride + 3 != i*k'*stride + 5
  197. # For the correct example this amounts to
  198. # i*k*2 != i*k'*2 + 1
  199. # which is true.
  200. # For now, we don't try to prove things like that at all, even though it'd
  201. # be feasible for many useful examples. Instead we attach the slice to
  202. # a spawn and if the attached spawns differ, we bail out:
  203. for i in 0 .. high(c.slices):
  204. for j in i+1 .. high(c.slices):
  205. let x = c.slices[i]
  206. let y = c.slices[j]
  207. if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x):
  208. if not x.inLoop or not y.inLoop:
  209. # XXX strictly speaking, 'or' is not correct here and it needs to
  210. # be 'and'. However this prevents too many obviously correct programs
  211. # like f(a[0..x]); for i in x+1 .. a.high: f(a[i])
  212. overlap(c.guards, c.graph.config, x.a, x.b, y.a, y.b)
  213. elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b);
  214. k >= 0 and m >= 0):
  215. # ah I cannot resist the temptation and add another sweet heuristic:
  216. # if both slices have the form (i+k)..(i+k) and (i+m)..(i+m) we
  217. # check they are disjoint and k < stride and m < stride:
  218. overlap(c.guards, c.graph.config, x.a, x.b, y.a, y.b)
  219. let stride = min(c.stride(x.a), c.stride(y.a))
  220. if k < stride and m < stride:
  221. discard
  222. else:
  223. localError(c.graph.config, x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
  224. [?x.a, ?x.b, ?y.a, ?y.b])
  225. else:
  226. localError(c.graph.config, x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
  227. [?x.a, ?x.b, ?y.a, ?y.b])
  228. proc analyse(c: var AnalysisCtx; n: PNode)
  229. proc analyseSons(c: var AnalysisCtx; n: PNode) =
  230. for i in 0 ..< safeLen(n): analyse(c, n[i])
  231. proc min(a, b: PNode): PNode =
  232. if a.isNil: result = b
  233. elif a.intVal < b.intVal: result = a
  234. else: result = b
  235. proc fromSystem(op: PSym): bool = sfSystemModule in getModule(op).flags
  236. template pushSpawnId(c, body) {.dirty.} =
  237. inc c.spawns
  238. let oldSpawnId = c.currentSpawnId
  239. c.currentSpawnId = c.spawns
  240. body
  241. c.currentSpawnId = oldSpawnId
  242. proc analyseCall(c: var AnalysisCtx; n: PNode; op: PSym) =
  243. if op.magic == mSpawn:
  244. pushSpawnId(c):
  245. gatherArgs(c, n[1])
  246. analyseSons(c, n)
  247. elif op.magic == mInc or (op.name.s == "+=" and op.fromSystem):
  248. if n[1].isLocal:
  249. let incr = n[2].skipConv
  250. if incr.kind in {nkCharLit..nkUInt32Lit} and incr.intVal > 0:
  251. let slot = c.getSlot(n[1].sym)
  252. slot.stride = min(slot.stride, incr)
  253. analyseSons(c, n)
  254. elif op.name.s == "[]" and op.fromSystem:
  255. let slice = n[2].skipStmtList
  256. c.addSlice(n, n[1], slice[1], slice[2])
  257. analyseSons(c, n)
  258. elif op.name.s == "[]=" and op.fromSystem:
  259. let slice = n[2].skipStmtList
  260. c.addSlice(n, n[1], slice[1], slice[2])
  261. analyseSons(c, n)
  262. else:
  263. analyseSons(c, n)
  264. proc analyseCase(c: var AnalysisCtx; n: PNode) =
  265. analyse(c, n.sons[0])
  266. let oldFacts = c.guards.s.len
  267. for i in 1..<n.len:
  268. let branch = n.sons[i]
  269. setLen(c.guards.s, oldFacts)
  270. addCaseBranchFacts(c.guards, n, i)
  271. for i in 0 ..< branch.len:
  272. analyse(c, branch.sons[i])
  273. setLen(c.guards.s, oldFacts)
  274. proc analyseIf(c: var AnalysisCtx; n: PNode) =
  275. analyse(c, n.sons[0].sons[0])
  276. let oldFacts = c.guards.s.len
  277. addFact(c.guards, canon(n.sons[0].sons[0], c.guards.o))
  278. analyse(c, n.sons[0].sons[1])
  279. for i in 1..<n.len:
  280. let branch = n.sons[i]
  281. setLen(c.guards.s, oldFacts)
  282. for j in 0..i-1:
  283. addFactNeg(c.guards, canon(n.sons[j].sons[0], c.guards.o))
  284. if branch.len > 1:
  285. addFact(c.guards, canon(branch.sons[0], c.guards.o))
  286. for i in 0 ..< branch.len:
  287. analyse(c, branch.sons[i])
  288. setLen(c.guards.s, oldFacts)
  289. proc analyse(c: var AnalysisCtx; n: PNode) =
  290. case n.kind
  291. of nkAsgn, nkFastAsgn:
  292. let y = n[1].skipConv
  293. if n[0].isSingleAssignable and y.isLocal:
  294. let slot = c.getSlot(y.sym)
  295. slot.alias = n[0].sym
  296. elif n[0].isLocal:
  297. # since we already ensure sfAddrTaken is not in s.flags, we only need to
  298. # prevent direct assignments to the monotonic variable:
  299. let slot = c.getSlot(n[0].sym)
  300. slot.blacklisted = true
  301. invalidateFacts(c.guards, n[0])
  302. let value = n[1]
  303. if getMagic(value) == mSpawn:
  304. pushSpawnId(c):
  305. gatherArgs(c, value[1])
  306. analyseSons(c, value[1])
  307. analyse(c, n[0])
  308. else:
  309. analyseSons(c, n)
  310. addAsgnFact(c.guards, n[0], y)
  311. of nkCallKinds:
  312. # direct call:
  313. if n[0].kind == nkSym: analyseCall(c, n, n[0].sym)
  314. else: analyseSons(c, n)
  315. of nkBracketExpr:
  316. if n[0].typ != nil and skipTypes(n[0].typ, abstractVar).kind != tyTuple:
  317. c.addSlice(n, n[0], n[1], n[1])
  318. analyseSons(c, n)
  319. of nkReturnStmt, nkRaiseStmt, nkTryStmt:
  320. localError(c.graph.config, n.info, "invalid control flow for 'parallel'")
  321. # 'break' that leaves the 'parallel' section is not valid either
  322. # or maybe we should generate a 'try' XXX
  323. of nkVarSection, nkLetSection:
  324. for it in n:
  325. let value = it.lastSon
  326. let isSpawned = getMagic(value) == mSpawn
  327. if isSpawned:
  328. pushSpawnId(c):
  329. gatherArgs(c, value[1])
  330. analyseSons(c, value[1])
  331. if value.kind != nkEmpty:
  332. for j in 0 .. it.len-3:
  333. if it[j].isLocal:
  334. let slot = c.getSlot(it[j].sym)
  335. if slot.lower.isNil: slot.lower = value
  336. else: internalError(c.graph.config, it.info, "slot already has a lower bound")
  337. if not isSpawned: analyse(c, value)
  338. of nkCaseStmt: analyseCase(c, n)
  339. of nkWhen, nkIfStmt, nkIfExpr: analyseIf(c, n)
  340. of nkWhileStmt:
  341. analyse(c, n.sons[0])
  342. # 'while true' loop?
  343. inc c.inLoop
  344. if isTrue(n.sons[0]):
  345. analyseSons(c, n.sons[1])
  346. else:
  347. # loop may never execute:
  348. let oldState = c.locals.len
  349. let oldFacts = c.guards.s.len
  350. addFact(c.guards, canon(n.sons[0], c.guards.o))
  351. analyse(c, n.sons[1])
  352. setLen(c.locals, oldState)
  353. setLen(c.guards.s, oldFacts)
  354. # we know after the loop the negation holds:
  355. if not hasSubnodeWith(n.sons[1], nkBreakStmt):
  356. addFactNeg(c.guards, canon(n.sons[0], c.guards.o))
  357. dec c.inLoop
  358. of nkTypeSection, nkProcDef, nkConverterDef, nkMethodDef, nkIteratorDef,
  359. nkMacroDef, nkTemplateDef, nkConstSection, nkPragma, nkFuncDef:
  360. discard
  361. else:
  362. analyseSons(c, n)
  363. proc transformSlices(g: ModuleGraph; n: PNode): PNode =
  364. if n.kind in nkCallKinds and n[0].kind == nkSym:
  365. let op = n[0].sym
  366. if op.name.s == "[]" and op.fromSystem:
  367. result = copyNode(n)
  368. let opSlice = newSymNode(createMagic(g, "slice", mSlice))
  369. opSlice.typ = getSysType(g, n.info, tyInt)
  370. result.add opSlice
  371. result.add n[1]
  372. let slice = n[2].skipStmtList
  373. result.add slice[1]
  374. result.add slice[2]
  375. return result
  376. if n.safeLen > 0:
  377. result = shallowCopy(n)
  378. for i in 0 ..< n.len:
  379. result.sons[i] = transformSlices(g, n.sons[i])
  380. else:
  381. result = n
  382. proc transformSpawn(g: ModuleGraph; owner: PSym; n, barrier: PNode): PNode
  383. proc transformSpawnSons(g: ModuleGraph; owner: PSym; n, barrier: PNode): PNode =
  384. result = shallowCopy(n)
  385. for i in 0 ..< n.len:
  386. result.sons[i] = transformSpawn(g, owner, n.sons[i], barrier)
  387. proc transformSpawn(g: ModuleGraph; owner: PSym; n, barrier: PNode): PNode =
  388. case n.kind
  389. of nkVarSection, nkLetSection:
  390. result = nil
  391. for it in n:
  392. let b = it.lastSon
  393. if getMagic(b) == mSpawn:
  394. if it.len != 3: localError(g.config, it.info, "invalid context for 'spawn'")
  395. let m = transformSlices(g, b)
  396. if result.isNil:
  397. result = newNodeI(nkStmtList, n.info)
  398. result.add n
  399. let t = b[1][0].typ.sons[0]
  400. if spawnResult(t, true) == srByVar:
  401. result.add wrapProcForSpawn(g, owner, m, b.typ, barrier, it[0])
  402. it.sons[it.len-1] = newNodeI(nkEmpty, it.info)
  403. else:
  404. it.sons[it.len-1] = wrapProcForSpawn(g, owner, m, b.typ, barrier, nil)
  405. if result.isNil: result = n
  406. of nkAsgn, nkFastAsgn:
  407. let b = n[1]
  408. if getMagic(b) == mSpawn and (let t = b[1][0].typ.sons[0];
  409. spawnResult(t, true) == srByVar):
  410. let m = transformSlices(g, b)
  411. return wrapProcForSpawn(g, owner, m, b.typ, barrier, n[0])
  412. result = transformSpawnSons(g, owner, n, barrier)
  413. of nkCallKinds:
  414. if getMagic(n) == mSpawn:
  415. result = transformSlices(g, n)
  416. return wrapProcForSpawn(g, owner, result, n.typ, barrier, nil)
  417. result = transformSpawnSons(g, owner, n, barrier)
  418. elif n.safeLen > 0:
  419. result = transformSpawnSons(g, owner, n, barrier)
  420. else:
  421. result = n
  422. proc checkArgs(a: var AnalysisCtx; n: PNode) =
  423. discard "too implement"
  424. proc generateAliasChecks(a: AnalysisCtx; result: PNode) =
  425. discard "too implement"
  426. proc liftParallel*(g: ModuleGraph; owner: PSym; n: PNode): PNode =
  427. # this needs to be called after the 'for' loop elimination
  428. # first pass:
  429. # - detect monotonic local integer variables
  430. # - detect used slices
  431. # - detect used arguments
  432. #echo "PAR ", renderTree(n)
  433. var a = initAnalysisCtx(g)
  434. let body = n.lastSon
  435. analyse(a, body)
  436. if a.spawns == 0:
  437. localError(g.config, n.info, "'parallel' section without 'spawn'")
  438. checkSlicesAreDisjoint(a)
  439. checkArgs(a, body)
  440. var varSection = newNodeI(nkVarSection, n.info)
  441. var temp = newSym(skTemp, getIdent(g.cache, "barrier"), owner, n.info)
  442. temp.typ = magicsys.getCompilerProc(g, "Barrier").typ
  443. incl(temp.flags, sfFromGeneric)
  444. let tempNode = newSymNode(temp)
  445. varSection.addVar tempNode
  446. let barrier = genAddrOf(tempNode)
  447. result = newNodeI(nkStmtList, n.info)
  448. generateAliasChecks(a, result)
  449. result.add varSection
  450. result.add callCodegenProc(g, "openBarrier", barrier.info, barrier)
  451. result.add transformSpawn(g, owner, body, barrier)
  452. result.add callCodegenProc(g, "closeBarrier", barrier.info, barrier)