semparallel.nim 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. #
  2. #
  3. # The Nim Compiler
  4. # (c) Copyright 2015 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  9. ## Semantic checking for 'parallel'.
  10. # - codegen needs to support mSlice (+)
  11. # - lowerings must not perform unnecessary copies (+)
  12. # - slices should become "nocopy" to openArray (+)
  13. # - need to perform bound checks (+)
  14. #
  15. # - parallel needs to insert a barrier (+)
  16. # - passed arguments need to be ensured to be "const"
  17. # - what about 'f(a)'? --> f shouldn't have side effects anyway
  18. # - passed arrays need to be ensured not to alias
  19. # - passed slices need to be ensured to be disjoint (+)
  20. # - output slices need special logic (+)
  21. import
  22. ast, astalgo, idents, lowerings, magicsys, guards, msgs,
  23. renderer, types, modulegraphs, options, spawn, lineinfos
  24. from trees import getMagic, getRoot
  25. from std/strutils import `%`
# Design note kept as a discarded string literal (it has no runtime effect).
# It motivates the rule enforced by `checkLocal`: a counter may not appear in
# a slice once its stride has been recorded.
discard """
one major problem:
spawn f(a[i])
inc i
spawn f(a[i])
is valid, but
spawn f(a[i])
spawn f(a[i])
inc i
is not! However,
spawn f(a[i])
if guard: inc i
spawn f(a[i])
is not valid either! --> We need a flow dependent analysis here.
However:
while foo:
spawn f(a[i])
inc i
spawn f(a[i])
Is not valid either! --> We should really restrict 'inc' to loop endings?
The heuristic that we implement here (that has no false positives) is: Usage
of 'i' in a slice *after* we determined the stride is invalid!
"""
type
  # Direction of a monotonic variable.
  # NOTE(review): no code in this section reads or writes `dir`.
  TDirection = enum
    ascending, descending
  # Tracking record for a local that may act as a loop counter.
  MonotonicVar = object
    v, alias: PSym # `v` is the tracked local; `alias` is a single-assignable
                   # local assigned from `v` (to support the ordinary
                   # 'countup' iterator we need to detect aliases)
    lower, upper, stride: PNode # `lower`: initial value from a var/let
                                # section; `stride`: smallest positive literal
                                # increment seen. `upper` is never set here.
    dir: TDirection
    blacklisted: bool # blacklisted variables that are not monotonic
                      # (set when the local is assigned directly)
  # State collected while walking the body of a 'parallel' section.
  AnalysisCtx = object
    locals: seq[MonotonicVar]
    # every recorded slice: container `x`, bounds `a..b`, the spawn id that
    # was current when it was recorded, and whether it was inside a loop
    slices: seq[tuple[x,a,b: PNode, spawnId: int, inLoop: bool]]
    guards: TModel # nested guards
    args: seq[PSym] # args must be deeply immutable
    spawns: int # we can check that at least 1 spawn is used in
                # the 'parallel' section; also the source of spawn ids
    currentSpawnId: int
    inLoop: int # nesting depth of the `while` loops being analysed
    graph: ModuleGraph
  68. proc initAnalysisCtx(g: ModuleGraph): AnalysisCtx =
  69. result = AnalysisCtx(locals: @[],
  70. slices: @[],
  71. args: @[],
  72. graph: g)
  73. result.guards.s = @[]
  74. result.guards.g = g
  75. proc lookupSlot(c: AnalysisCtx; s: PSym): int =
  76. for i in 0..<c.locals.len:
  77. if c.locals[i].v == s or c.locals[i].alias == s: return i
  78. return -1
  79. proc getSlot(c: var AnalysisCtx; v: PSym): ptr MonotonicVar =
  80. let s = lookupSlot(c, v)
  81. if s >= 0: return addr(c.locals[s])
  82. c.locals.setLen(c.locals.len+1)
  83. c.locals[^1].v = v
  84. return addr(c.locals[^1])
  85. proc gatherArgs(c: var AnalysisCtx; n: PNode) =
  86. for i in 0..<n.safeLen:
  87. let root = getRoot n[i]
  88. if root != nil:
  89. block addRoot:
  90. for r in items(c.args):
  91. if r == root: break addRoot
  92. c.args.add root
  93. gatherArgs(c, n[i])
  94. proc isSingleAssignable(n: PNode): bool =
  95. n.kind == nkSym and (let s = n.sym;
  96. s.kind in {skTemp, skForVar, skLet} and
  97. {sfAddrTaken, sfGlobal} * s.flags == {})
  98. proc isLocal(n: PNode): bool =
  99. n.kind == nkSym and (let s = n.sym;
  100. s.kind in {skResult, skTemp, skForVar, skVar, skLet} and
  101. {sfAddrTaken, sfGlobal} * s.flags == {})
  102. proc checkLocal(c: AnalysisCtx; n: PNode) =
  103. if isLocal(n):
  104. let s = c.lookupSlot(n.sym)
  105. if s >= 0 and c.locals[s].stride != nil:
  106. localError(c.graph.config, n.info, "invalid usage of counter after increment")
  107. else:
  108. for i in 0..<n.safeLen: checkLocal(c, n[i])
  109. template `?`(x): untyped = x.renderTree
  110. proc checkLe(c: AnalysisCtx; a, b: PNode) =
  111. case proveLe(c.guards, a, b)
  112. of impUnknown:
  113. message(c.graph.config, a.info, warnStaticIndexCheck,
  114. "cannot prove: " & ?a & " <= " & ?b)
  115. of impYes: discard
  116. of impNo:
  117. message(c.graph.config, a.info, warnStaticIndexCheck,
  118. "can prove: " & ?a & " > " & ?b)
  119. proc checkBounds(c: AnalysisCtx; arr, idx: PNode) =
  120. checkLe(c, lowBound(c.graph.config, arr), idx)
  121. checkLe(c, idx, highBound(c.graph.config, arr, c.graph.operators))
  122. proc addLowerBoundAsFacts(c: var AnalysisCtx) =
  123. for v in c.locals:
  124. if not v.blacklisted:
  125. c.guards.addFactLe(v.lower, newSymNode(v.v))
  126. proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
  127. checkLocal(c, n)
  128. let le = le.canon(c.graph.operators)
  129. let ri = ri.canon(c.graph.operators)
  130. # perform static bounds checking here; and not later!
  131. let oldState = c.guards.s.len
  132. addLowerBoundAsFacts(c)
  133. c.checkBounds(x, le)
  134. c.checkBounds(x, ri)
  135. c.guards.s.setLen(oldState)
  136. c.slices.add((x, le, ri, c.currentSpawnId, c.inLoop > 0))
  137. proc overlap(m: TModel; conf: ConfigRef; x,y,c,d: PNode) =
  138. # X..Y and C..D overlap iff (X <= D and C <= Y)
  139. case proveLe(m, c, y)
  140. of impUnknown:
  141. case proveLe(m, x, d)
  142. of impNo: discard
  143. of impUnknown, impYes:
  144. message(conf, x.info, warnStaticIndexCheck,
  145. "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
  146. [?c, ?y, ?x, ?y, ?c, ?d])
  147. of impYes:
  148. case proveLe(m, x, d)
  149. of impUnknown:
  150. message(conf, x.info, warnStaticIndexCheck,
  151. "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
  152. [?x, ?d, ?x, ?y, ?c, ?d])
  153. of impYes:
  154. message(conf, x.info, warnStaticIndexCheck, "($#)..($#) not disjoint from ($#)..($#)" %
  155. [?c, ?y, ?x, ?y, ?c, ?d])
  156. of impNo: discard
  157. of impNo: discard
  158. proc stride(c: AnalysisCtx; n: PNode): BiggestInt =
  159. if isLocal(n):
  160. let s = c.lookupSlot(n.sym)
  161. if s >= 0 and c.locals[s].stride != nil:
  162. result = c.locals[s].stride.intVal
  163. else:
  164. result = 0
  165. else:
  166. result = 0
  167. for i in 0..<n.safeLen: result += stride(c, n[i])
  168. proc subStride(c: AnalysisCtx; n: PNode): PNode =
  169. # substitute with stride:
  170. if isLocal(n):
  171. let s = c.lookupSlot(n.sym)
  172. if s >= 0 and c.locals[s].stride != nil:
  173. result = buildAdd(n, c.locals[s].stride.intVal, c.graph.operators)
  174. else:
  175. result = n
  176. elif n.safeLen > 0:
  177. result = shallowCopy(n)
  178. for i in 0..<n.len: result[i] = subStride(c, n[i])
  179. else:
  180. result = n
proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
  ## Final check over all slices recorded by `analyse`.
  ##
  ## A slice inside a loop must not overlap its own range advanced by one
  ## stride. Two slices of structurally equal containers (`sameTree`) that
  ## belong to different spawns must be disjoint. Problems are reported as
  ## warnings through `overlap` or as errors through `localError`.
  # this is the only thing that we need to perform after we have traversed
  # the whole tree so that the strides are available.
  # First we need to add all the computed lower bounds:
  addLowerBoundAsFacts(c)
  # Every slice used in a loop needs to be disjoint with itself:
  # (compare a..b with the same range advanced by the recorded strides)
  for x,a,b,id,inLoop in items(c.slices):
    if inLoop: overlap(c.guards, c.graph.config, a,b, c.subStride(a), c.subStride(b))
  # Another tricky example is:
  #   while true:
  #     spawn f(a[i])
  #     spawn f(a[i+1])
  #     inc i  # inc i, 2 would be correct here
  #
  # Or even worse:
  #   while true:
  #     spawn f(a[i+1..i+3])
  #     spawn f(a[i+4..i+5])
  #     inc i, 4
  # Prove that i*k*stride + 3 != i*k'*stride + 5
  # For the correct example this amounts to
  #   i*k*2 != i*k'*2 + 1
  # which is true.
  # For now, we don't try to prove things like that at all, even though it'd
  # be feasible for many useful examples. Instead we attach the slice to
  # a spawn and if the attached spawns differ, we bail out:
  for i in 0..high(c.slices):
    for j in i+1..high(c.slices):
      let x = c.slices[i]
      let y = c.slices[j]
      # only pairs over the same container from different spawns are checked
      if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x):
        if not x.inLoop or not y.inLoop:
          # XXX strictly speaking, 'or' is not correct here and it needs to
          # be 'and'. However this prevents too many obviously correct programs
          # like f(a[0..x]); for i in x+1..a.high: f(a[i])
          overlap(c.guards, c.graph.config, x.a, x.b, y.a, y.b)
        elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b);
              k >= 0 and m >= 0):
          # ah I cannot resist the temptation and add another sweet heuristic:
          # if both slices have the form (i+k)..(i+k) and (i+m)..(i+m) we
          # check they are disjoint and k < stride and m < stride:
          overlap(c.guards, c.graph.config, x.a, x.b, y.a, y.b)
          # `min` here is the integer `min` on the two summed strides
          let stride = min(c.stride(x.a), c.stride(y.a))
          if k < stride and m < stride:
            discard
          else:
            localError(c.graph.config, x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
              [?x.a, ?x.b, ?y.a, ?y.b])
        else:
          localError(c.graph.config, x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
            [?x.a, ?x.b, ?y.a, ?y.b])
  232. proc analyse(c: var AnalysisCtx; n: PNode)
  233. proc analyseSons(c: var AnalysisCtx; n: PNode) =
  234. for i in 0..<n.safeLen: analyse(c, n[i])
  235. proc min(a, b: PNode): PNode =
  236. if a.isNil: result = b
  237. elif a.intVal < b.intVal: result = a
  238. else: result = b
template pushSpawnId(c, body) {.dirty.} =
  ## Runs `body` with `c.currentSpawnId` set to a fresh id: the incremented
  ## value of `c.spawns`. The previous id is restored afterwards.
  ## Being `{.dirty.}`, the template injects `oldSpawnId` into the
  ## caller's scope.
  inc c.spawns
  let oldSpawnId = c.currentSpawnId
  c.currentSpawnId = c.spawns
  body
  # not in a `finally`: the old id is not restored if `body` raises
  c.currentSpawnId = oldSpawnId
  245. proc analyseCall(c: var AnalysisCtx; n: PNode; op: PSym) =
  246. if op.magic == mSpawn:
  247. pushSpawnId(c):
  248. gatherArgs(c, n[1])
  249. analyseSons(c, n)
  250. elif op.magic == mInc or (op.name.s == "+=" and op.fromSystem):
  251. if n[1].isLocal:
  252. let incr = n[2].skipConv
  253. if incr.kind in {nkCharLit..nkUInt32Lit} and incr.intVal > 0:
  254. let slot = c.getSlot(n[1].sym)
  255. slot.stride = min(slot.stride, incr)
  256. analyseSons(c, n)
  257. elif op.name.s == "[]" and op.fromSystem:
  258. let slice = n[2].skipStmtList
  259. c.addSlice(n, n[1], slice[1], slice[2])
  260. analyseSons(c, n)
  261. elif op.name.s == "[]=" and op.fromSystem:
  262. let slice = n[2].skipStmtList
  263. c.addSlice(n, n[1], slice[1], slice[2])
  264. analyseSons(c, n)
  265. else:
  266. analyseSons(c, n)
  267. proc analyseCase(c: var AnalysisCtx; n: PNode) =
  268. analyse(c, n[0])
  269. let oldFacts = c.guards.s.len
  270. for i in 1..<n.len:
  271. let branch = n[i]
  272. setLen(c.guards.s, oldFacts)
  273. addCaseBranchFacts(c.guards, n, i)
  274. for i in 0..<branch.len:
  275. analyse(c, branch[i])
  276. setLen(c.guards.s, oldFacts)
  277. proc analyseIf(c: var AnalysisCtx; n: PNode) =
  278. analyse(c, n[0][0])
  279. let oldFacts = c.guards.s.len
  280. addFact(c.guards, canon(n[0][0], c.graph.operators))
  281. analyse(c, n[0][1])
  282. for i in 1..<n.len:
  283. let branch = n[i]
  284. setLen(c.guards.s, oldFacts)
  285. for j in 0..i-1:
  286. addFactNeg(c.guards, canon(n[j][0], c.graph.operators))
  287. if branch.len > 1:
  288. addFact(c.guards, canon(branch[0], c.graph.operators))
  289. for i in 0..<branch.len:
  290. analyse(c, branch[i])
  291. setLen(c.guards.s, oldFacts)
proc analyse(c: var AnalysisCtx; n: PNode) =
  ## First pass over the body of a 'parallel' section.
  ##
  ## It records counters (initial values, strides, aliases, blacklisting),
  ## slices and spawns, and keeps the guard model in sync with conditions,
  ## assignments and loops. Disallowed control flow is reported as an error.
  case n.kind
  of nkAsgn, nkFastAsgn, nkSinkAsgn:
    let y = n[1].skipConv
    if n[0].isSingleAssignable and y.isLocal:
      # `t = i`: remember `t` as an alias of the counter `i`
      let slot = c.getSlot(y.sym)
      slot.alias = n[0].sym
    elif n[0].isLocal:
      # since we already ensure sfAddrTaken is not in s.flags, we only need to
      # prevent direct assignments to the monotonic variable:
      let slot = c.getSlot(n[0].sym)
      slot.blacklisted = true
    invalidateFacts(c.guards, n[0])
    let value = n[1]
    if getMagic(value) == mSpawn:
      # the spawned call is analysed under a fresh spawn id
      pushSpawnId(c):
        gatherArgs(c, value[1])
        analyseSons(c, value[1])
        analyse(c, n[0])
    else:
      analyseSons(c, n)
    addAsgnFact(c.guards, n[0], y)
  of nkCallKinds:
    # direct call:
    if n[0].kind == nkSym: analyseCall(c, n, n[0].sym)
    else: analyseSons(c, n)
  of nkBracketExpr:
    # tuple field access is not an array index
    if n[0].typ != nil and skipTypes(n[0].typ, abstractVar).kind != tyTuple:
      c.addSlice(n, n[0], n[1], n[1])
    analyseSons(c, n)
  of nkReturnStmt, nkRaiseStmt, nkTryStmt, nkHiddenTryStmt:
    localError(c.graph.config, n.info, "invalid control flow for 'parallel'")
    # 'break' that leaves the 'parallel' section is not valid either
    # or maybe we should generate a 'try' XXX
  of nkVarSection, nkLetSection:
    for it in n:
      let value = it.lastSon
      let isSpawned = getMagic(value) == mSpawn
      if isSpawned:
        pushSpawnId(c):
          gatherArgs(c, value[1])
          analyseSons(c, value[1])
      if value.kind != nkEmpty:
        # the initializer becomes the lower bound of each declared local
        for j in 0..<it.len-2:
          if it[j].isLocal:
            let slot = c.getSlot(it[j].sym)
            if slot.lower.isNil: slot.lower = value
            else: internalError(c.graph.config, it.info, "slot already has a lower bound")
        if not isSpawned: analyse(c, value)
  of nkCaseStmt: analyseCase(c, n)
  of nkWhen, nkIfStmt, nkIfExpr: analyseIf(c, n)
  of nkWhileStmt:
    analyse(c, n[0])
    # 'while true' loop?
    inc c.inLoop
    if isTrue(n[0]):
      analyseSons(c, n[1])
    else:
      # loop may never execute:
      # locals first seen and facts added inside the body are discarded
      let oldState = c.locals.len
      let oldFacts = c.guards.s.len
      addFact(c.guards, canon(n[0], c.graph.operators))
      analyse(c, n[1])
      setLen(c.locals, oldState)
      setLen(c.guards.s, oldFacts)
      # we know after the loop the negation holds:
      if not hasSubnodeWith(n[1], nkBreakStmt):
        addFactNeg(c.guards, canon(n[0], c.graph.operators))
    dec c.inLoop
  of nkTypeSection, nkProcDef, nkConverterDef, nkMethodDef, nkIteratorDef,
      nkMacroDef, nkTemplateDef, nkConstSection, nkPragma, nkFuncDef,
      nkMixinStmt, nkBindStmt, nkExportStmt:
    discard
  else:
    analyseSons(c, n)
  367. proc transformSlices(g: ModuleGraph; idgen: IdGenerator; n: PNode): PNode =
  368. if n.kind in nkCallKinds and n[0].kind == nkSym:
  369. let op = n[0].sym
  370. if op.name.s == "[]" and op.fromSystem:
  371. result = copyNode(n)
  372. var typ = newType(tyOpenArray, idgen, result.typ.owner)
  373. typ.add result.typ.elementType
  374. result.typ = typ
  375. let opSlice = newSymNode(createMagic(g, idgen, "slice", mSlice))
  376. opSlice.typ = getSysType(g, n.info, tyInt)
  377. result.add opSlice
  378. result.add n[1]
  379. let slice = n[2].skipStmtList
  380. result.add slice[1]
  381. result.add slice[2]
  382. return result
  383. if n.safeLen > 0:
  384. result = shallowCopy(n)
  385. for i in 0..<n.len:
  386. result[i] = transformSlices(g, idgen, n[i])
  387. else:
  388. result = n
  389. proc transformSpawn(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n, barrier: PNode): PNode
  390. proc transformSpawnSons(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n, barrier: PNode): PNode =
  391. result = shallowCopy(n)
  392. for i in 0..<n.len:
  393. result[i] = transformSpawn(g, idgen, owner, n[i], barrier)
proc transformSpawn(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n, barrier: PNode): PNode =
  ## Rewrites each `spawn` in `n` through `wrapProcForSpawn`, passing
  ## `barrier`; slices in spawned calls are first converted by
  ## `transformSlices`.
  ##
  ## For var/let sections the section node is modified in place. When a
  ## spawn's result kind is `srByVar`, the initializer is emptied and the
  ## wrapped call is emitted after the section in a new statement list.
  case n.kind
  of nkVarSection, nkLetSection:
    result = nil
    for it in n:
      let b = it.lastSon
      if getMagic(b) == mSpawn:
        if it.len != 3: localError(g.config, it.info, "invalid context for 'spawn'")
        let m = transformSlices(g, idgen, b)
        if result.isNil:
          result = newNodeI(nkStmtList, n.info)
          result.add n
        let t = b[1][0].typ.returnType
        if spawnResult(t, true) == srByVar:
          # declare the variable without initializer; the wrapped spawn
          # receives `it[0]` as its destination
          result.add wrapProcForSpawn(g, idgen, owner, m, b.typ, barrier, it[0])
          it[^1] = newNodeI(nkEmpty, it.info)
        else:
          it[^1] = wrapProcForSpawn(g, idgen, owner, m, b.typ, barrier, nil)
    # no spawn in this section: return it unchanged
    if result.isNil: result = n
  of nkAsgn, nkFastAsgn, nkSinkAsgn:
    let b = n[1]
    if getMagic(b) == mSpawn and (let t = b[1][0].typ.returnType;
        spawnResult(t, true) == srByVar):
      let m = transformSlices(g, idgen, b)
      return wrapProcForSpawn(g, idgen, owner, m, b.typ, barrier, n[0])
    result = transformSpawnSons(g, idgen, owner, n, barrier)
  of nkCallKinds:
    if getMagic(n) == mSpawn:
      result = transformSlices(g, idgen, n)
      return wrapProcForSpawn(g, idgen, owner, result, n.typ, barrier, nil)
    result = transformSpawnSons(g, idgen, owner, n, barrier)
  elif n.safeLen > 0:
    result = transformSpawnSons(g, idgen, owner, n, barrier)
  else:
    result = n
proc checkArgs(a: var AnalysisCtx; n: PNode) =
  ## Placeholder: currently does nothing.
  ## NOTE(review): presumably meant to enforce the "args must be deeply
  ## immutable" rule for the symbols collected in `a.args` — confirm.
  discard "to implement"
proc generateAliasChecks(a: AnalysisCtx; result: PNode) =
  ## Placeholder: currently adds nothing to `result`.
  ## NOTE(review): presumably meant to emit runtime checks that passed arrays
  ## do not alias (see the TODO list at the top of the file) — confirm.
  discard "to implement"
proc liftParallel*(g: ModuleGraph; idgen: IdGenerator; owner: PSym; n: PNode): PNode =
  ## Checks and lowers the `parallel` section `n`, whose body is `n.lastSon`.
  ##
  ## The body is analysed first; diagnostics are reported through
  ## `localError`/`message`. The returned statement list:
  ## 1. declares a `barrier` temporary of the compiler proc type `Barrier`;
  ## 2. calls `openBarrier`;
  ## 3. runs the body with spawns rewritten by `transformSpawn`;
  ## 4. calls `closeBarrier`.
  # this needs to be called after the 'for' loop elimination
  # first pass:
  # - detect monotonic local integer variables
  # - detect used slices
  # - detect used arguments
  #echo "PAR ", renderTree(n)
  var a = initAnalysisCtx(g)
  let body = n.lastSon
  analyse(a, body)
  if a.spawns == 0:
    localError(g.config, n.info, "'parallel' section without 'spawn'")
  checkSlicesAreDisjoint(a)
  checkArgs(a, body)
  var varSection = newNodeI(nkVarSection, n.info)
  var temp = newSym(skTemp, getIdent(g.cache, "barrier"), idgen, owner, n.info)
  temp.typ = magicsys.getCompilerProc(g, "Barrier").typ
  incl(temp.flags, sfFromGeneric)
  let tempNode = newSymNode(temp)
  varSection.addVar tempNode
  # the barrier procs and every wrapped spawn receive the barrier's address
  let barrier = genAddrOf(tempNode, idgen)
  result = newNodeI(nkStmtList, n.info)
  generateAliasChecks(a, result)
  result.add varSection
  result.add callCodegenProc(g, "openBarrier", barrier.info, barrier)
  result.add transformSpawn(g, idgen, owner, body, barrier)
  result.add callCodegenProc(g, "closeBarrier", barrier.info, barrier)