intsets.nim 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. #
  2. #
  3. # Nim's Runtime Library
  4. # (c) Copyright 2012 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  ## The ``intsets`` module implements an efficient set of ``int`` values as a
  ## `sparse bit set`:idx:.
  ##
  ## **Note**: Currently the assignment operator ``=`` for ``intsets``
  ## performs a rather meaningless shallow copy. Since Nim currently does
  ## not allow the assignment operator to be overloaded, use ``assign`` to
  ## get a deep copy.
  15. import
  16. hashes, math
  type BitScalar = int  # machine word that stores one chunk of a trunk's bits
  const
    # Initial number of slots in the trunk hash table; must be a power of two!
    InitIntSetSize = 8
    # log2 of the number of bits covered by one trunk.
    TrunkShift = 9
    # Bits per trunk; needs to be a power of 2 and divisible by 64.
    BitsPerTrunk = 1 shl TrunkShift
    # Selects the bit position of a key inside its trunk.
    TrunkMask = BitsPerTrunk - 1
    # Number of `BitScalar` words needed to hold one trunk's bits.
    IntsPerTrunk = BitsPerTrunk div (8 * sizeof(BitScalar))
    # log2 of the bit width of `BitScalar`: 5 or 6, depending on int width.
    IntShift = when sizeof(BitScalar) == 8: 6 else: 5
    # Selects the bit position inside a single `BitScalar` word.
    IntMask = (1 shl IntShift) - 1
  type
    TrunkSeq = seq[PTrunk]

    PTrunk = ref Trunk

    Trunk = object
      next: PTrunk  # all nodes are connected with this pointer
      key: int  # start address at bit 0
      bits: array[IntsPerTrunk, BitScalar]  # a bit vector

    IntSet* = object ## an efficient set of 'int' implemented as a sparse bit set
      # Number of slots of `a` in use; only valid for small numbers. It is
      # set to `a.len + 1` once the elements have been moved into trunks.
      elems: int
      # Number of trunks stored in `data`.
      counter: int
      # Highest valid index of `data`, also used as the hash mask.
      max: int
      # First node of the linked list that chains all trunks together.
      head: PTrunk
      # Open-addressing hash table mapping trunk keys to trunks.
      data: TrunkSeq
      # Inline storage for small sets; profiling shows that 34 elements are
      # enough.
      a: array[34, int]
  proc mustRehash(length, counter: int): bool {.inline.} =
    ## Decides whether a table with `length` slots, `counter` of them in use,
    ## has to grow: true when more than two thirds of the slots are used or
    ## when fewer than four slots are still free.
    assert(length > counter)
    if length * 2 < counter * 3:
      result = true
    else:
      result = length - counter < 4
  proc nextTry(h, maxHash: Hash): Hash {.inline.} =
    ## Returns the slot to probe after `h`: `5 * h + 1`, masked by `maxHash`.
    (h * 5 + 1) and maxHash
  proc intSetGet(t: IntSet, key: int): PTrunk =
    ## Looks up the trunk registered under `key` in the hash table `t.data`.
    ## Probing starts at `key and t.max` and follows `nextTry` until either
    ## the trunk is found or an empty slot is hit; `nil` is returned in the
    ## latter case.
    var slot = key and t.max
    result = t.data[slot]
    while result != nil and result.key != key:
      slot = nextTry(slot, t.max)
      result = t.data[slot]
  proc intSetRawInsert(t: IntSet, data: var TrunkSeq, desc: PTrunk) =
    ## Stores `desc` in the first free slot of `data` along its probe
    ## sequence, using `t.max` as the hash mask. `desc` must not already be
    ## present in `data`.
    var slot = desc.key and t.max
    while not data[slot].isNil:
      assert(data[slot] != desc)
      slot = nextTry(slot, t.max)
    assert(data[slot] == nil)
    data[slot] = desc
  proc intSetEnlarge(t: var IntSet) =
    ## Doubles the capacity of the trunk hash table and re-inserts every
    ## existing trunk under the new mask.
    let oldMax = t.max
    t.max = (oldMax + 1) * 2 - 1
    var grown = newSeq[PTrunk](t.max + 1)
    for i in 0..oldMax:
      let trunk = t.data[i]
      if trunk != nil:
        intSetRawInsert(t, grown, trunk)
    swap(t.data, grown)
  proc intSetPut(t: var IntSet, key: int): PTrunk =
    ## Returns the trunk registered under `key`, creating it when missing.
    ## A new trunk is prepended to the linked list starting at `t.head` and
    ## stored in the hash table, which is enlarged first if `mustRehash`
    ## asks for it.
    var slot = key and t.max
    while true:
      let existing = t.data[slot]
      if existing == nil: break
      if existing.key == key: return existing
      slot = nextTry(slot, t.max)

    # Not found: make room if needed, then probe again because the mask may
    # have changed.
    if mustRehash(t.max + 1, t.counter):
      intSetEnlarge(t)
    t.counter += 1
    slot = key and t.max
    while t.data[slot] != nil:
      slot = nextTry(slot, t.max)
    assert(t.data[slot] == nil)

    result = PTrunk(next: t.head, key: key)
    t.head = result
    t.data[slot] = result
  proc contains*(s: IntSet, key: int): bool =
    ## Returns true iff `key` is in `s`.
    if s.elems > s.a.len:
      # Trunk mode: find the trunk covering `key` and test its bit.
      let trunk = intSetGet(s, key shr TrunkShift)
      if trunk != nil:
        let bit = key and TrunkMask
        result = (trunk.bits[bit shr IntShift] and
                  (1 shl (bit and IntMask))) != 0
    else:
      # Small mode: linear scan over the used part of the inline array.
      for i in 0 ..< s.elems:
        if s.a[i] == key:
          return true
  iterator items*(s: IntSet): int {.inline.} =
    ## Iterates over any included element of `s`.
    ##
    ## Small sets yield the used entries of the inline array in storage
    ## order; larger sets walk the linked list of trunks and yield one value
    ## per set bit. The set must not be modified during the traversal.
    if s.elems <= s.a.len:
      for i in 0..<s.elems:
        yield s.a[i]
    else:
      var r = s.head
      while r != nil:
        var i = 0
        while i <= high(r.bits):
          # Scan an unsigned copy of the word. With a signed word, `shr` is
          # an arithmetic shift on Nim versions where it preserves the sign,
          # so a word whose top bit is set would never become zero and this
          # loop would not terminate. Taking a copy is correct because
          # modifying operations are not allowed during traversal.
          var w = cast[uint](r.bits[i])
          var j = 0
          while w != 0'u:  # test all remaining bits for zero
            if (w and 1'u) != 0'u:  # the bit is set!
              yield (r.key shl TrunkShift) or (i shl IntShift +% j)
            inc(j)
            w = w shr 1
          inc(i)
        r = r.next
  proc bitincl(s: var IntSet, key: int) {.inline.} =
    ## Sets the bit for `key` in its trunk, creating the trunk when needed.
    let trunk = intSetPut(s, key shr TrunkShift)
    let bit = key and TrunkMask
    let word = bit shr IntShift
    trunk.bits[word] = trunk.bits[word] or (1 shl (bit and IntMask))
  proc incl*(s: var IntSet, key: int) =
    ## Includes an element `key` in `s`.
    if s.elems <= s.a.len:
      # Small mode: nothing to do when the key is already stored.
      for i in 0 ..< s.elems:
        if s.a[i] == key:
          return
      if s.elems < s.a.len:
        # Still room in the inline array.
        s.a[s.elems] = key
        s.elems += 1
        return
      # The inline array is full: set up the hash table, move every stored
      # element into trunks and mark the set as being in trunk mode.
      s.data = newSeq[PTrunk](InitIntSetSize)
      s.max = InitIntSetSize - 1
      for i in 0 ..< s.elems:
        bitincl(s, s.a[i])
      s.elems = s.a.len + 1
    # Trunk mode (possibly just entered): record the key as a bit.
    bitincl(s, key)
  proc incl*(s: var IntSet, other: IntSet) =
    ## Includes all elements from `other` into `s`.
    for element in items(other):
      s.incl(element)
  proc exclImpl(s: var IntSet, key: int) =
    ## Removes `key` from `s` if present. In small mode the last used entry
    ## of the inline array is moved into the freed slot and `s.elems` shrinks
    ## by one; in trunk mode only the key's bit is cleared.
    if s.elems > s.a.len:
      let trunk = intSetGet(s, key shr TrunkShift)
      if trunk != nil:
        let bit = key and TrunkMask
        let word = bit shr IntShift
        trunk.bits[word] = trunk.bits[word] and not (1 shl (bit and IntMask))
      return
    for i in 0 ..< s.elems:
      if s.a[i] == key:
        let last = s.elems - 1
        s.a[i] = s.a[last]
        s.elems = last
        return
  proc excl*(s: var IntSet, key: int) =
    ## Excludes `key` from the set `s`. Nothing happens when `key` is not
    ## a member.
    s.exclImpl(key)
  proc excl*(s: var IntSet, other: IntSet) =
    ## Excludes all elements from `other` from `s`.
    for element in items(other):
      s.excl(element)
  proc missingOrExcl*(s: var IntSet, key: int): bool =
    ## Returns true if `s` does not contain `key`, otherwise
    ## `key` is removed from `s` and false is returned.
    # Membership is tested directly. Comparing `s.elems` before and after the
    # removal only works for small sets: once the elements live in trunks,
    # `exclImpl` no longer changes `s.elems`, so that comparison would always
    # report the key as missing.
    result = not contains(s, key)
    if not result:
      exclImpl(s, key)
  proc containsOrIncl*(s: var IntSet, key: int): bool =
    ## Returns true if `s` contains `key`, otherwise `key` is included in `s`
    ## and false is returned.
    if s.elems > s.a.len:
      let trunk = intSetGet(s, key shr TrunkShift)
      if trunk == nil:
        # No trunk covers the key yet, so it cannot be a member.
        incl(s, key)
        result = false
      else:
        let bit = key and TrunkMask
        let word = bit shr IntShift
        let mask = 1 shl (bit and IntMask)
        result = (trunk.bits[word] and mask) != 0
        if not result:
          trunk.bits[word] = trunk.bits[word] or mask
    else:
      for i in 0 ..< s.elems:
        if s.a[i] == key:
          return true
      incl(s, key)
      result = false
  proc initIntSet*: IntSet =
    ## Returns an empty IntSet. Example:
    ##
    ## .. code-block ::
    ##   var a = initIntSet()
    ##   a.incl(2)
    # The hash table is not allocated here; `incl` creates it once the
    # inline array `a` is full.
    result.elems = 0
    result.counter = 0
    result.max = 0
    result.head = nil
    when defined(nimNoNilSeqs):
      result.data = @[]
    else:
      result.data = nil
  proc clear*(result: var IntSet) =
    ## Clears the IntSet back to an empty state.
    # Dropping the trunk list and the hash table and zeroing `elems` puts
    # the set back into small mode with no entries.
    result.elems = 0
    result.counter = 0
    result.max = 0
    result.head = nil
    when defined(nimNoNilSeqs):
      result.data = @[]
    else:
      result.data = nil
  proc isNil*(x: IntSet): bool {.inline.} =
    ## Returns true when `x` has no trunk list and no entries in its inline
    ## array.
    result = x.elems == 0 and x.head.isNil
  proc assign*(dest: var IntSet, src: IntSet) =
    ## copies `src` to `dest`. `dest` does not need to be initialized by
    ## `initIntSet`.
    if src.elems <= src.a.len:
      # Small mode: the inline array holds everything.
      when defined(nimNoNilSeqs):
        dest.data = @[]
      else:
        dest.data = nil
      dest.max = 0
      dest.counter = src.counter
      dest.head = nil
      dest.elems = src.elems
      dest.a = src.a
    else:
      # Trunk mode: rebuild the hash table with a private copy of each trunk.
      var it = src.head
      dest.counter = src.counter
      dest.max = src.max
      # Without copying `elems`, a fresh `dest` would keep `elems == 0` and
      # every other routine would treat it as an empty small set.
      dest.elems = src.elems
      # Start from an empty list so that trunks previously owned by `dest`
      # are not chained behind the copies.
      dest.head = nil
      newSeq(dest.data, src.data.len)

      while it != nil:
        var h = it.key and dest.max
        while dest.data[h] != nil: h = nextTry(h, dest.max)
        assert(dest.data[h] == nil)

        var n: PTrunk
        new(n)
        n.next = dest.head
        n.key = it.key
        n.bits = it.bits
        dest.head = n
        dest.data[h] = n

        it = it.next
  proc union*(s1, s2: IntSet): IntSet =
    ## Returns the union of the sets `s1` and `s2`.
    assign(result, s1)
    for element in items(s2):
      incl(result, element)
  proc intersection*(s1, s2: IntSet): IntSet =
    ## Returns the intersection of the sets `s1` and `s2`.
    result = initIntSet()
    for element in items(s1):
      if element in s2:
        result.incl(element)
  proc difference*(s1, s2: IntSet): IntSet =
    ## Returns the difference of the sets `s1` and `s2`, i.e. the elements
    ## of `s1` that are not in `s2`.
    result = initIntSet()
    for element in items(s1):
      if element notin s2:
        result.incl(element)
  proc symmetricDifference*(s1, s2: IntSet): IntSet =
    ## Returns the symmetric difference of the sets `s1` and `s2`.
    assign(result, s1)
    for element in items(s2):
      # An element found in both sets is removed again; one found only in
      # `s2` has just been added by `containsOrIncl`.
      if result.containsOrIncl(element):
        result.excl(element)
  proc `+`*(s1, s2: IntSet): IntSet {.inline.} =
    ## Alias for `union(s1, s2) <#union>`_.
    s1.union(s2)
  proc `*`*(s1, s2: IntSet): IntSet {.inline.} =
    ## Alias for `intersection(s1, s2) <#intersection>`_.
    s1.intersection(s2)
  proc `-`*(s1, s2: IntSet): IntSet {.inline.} =
    ## Alias for `difference(s1, s2) <#difference>`_.
    s1.difference(s2)
  proc disjoint*(s1, s2: IntSet): bool =
    ## Returns true if the sets `s1` and `s2` have no items in common.
    result = true
    for element in items(s1):
      if element in s2:
        result = false
        break
  proc len*(s: IntSet): int {.inline.} =
    ## Returns the number of keys in `s`.
    # Every other routine treats `elems <= a.len` as small mode, where
    # `elems` is the exact element count; this includes a completely filled
    # inline array, which therefore needs no counting.
    if s.elems <= s.a.len:
      result = s.elems
    else:
      # Trunk mode keeps no element count, so the set bits are counted by
      # iterating.
      result = 0
      for _ in s:
        inc(result)
  proc card*(s: IntSet): int {.inline.} =
    ## Alias for `len() <#len>`_.
    len(s)
  proc `<=`*(s1, s2: IntSet): bool =
    ## Returns true iff `s1` is subset of `s2`.
    result = true
    for element in items(s1):
      if element notin s2:
        result = false
        break
  proc `<`*(s1, s2: IntSet): bool =
    ## Returns true iff `s1` is proper subset of `s2`, i.e. `s1` is a subset
    ## of `s2` but `s2` is not a subset of `s1`.
    if s1 <= s2:
      result = not (s2 <= s1)
    else:
      result = false
  proc `==`*(s1, s2: IntSet): bool =
    ## Returns true if `s1` and `s2` have the same members, i.e. each set is
    ## a subset of the other.
    if s1 <= s2:
      result = s2 <= s1
    else:
      result = false
  template dollarImpl(): untyped =
    ## Builds the textual form `{a, b, c}` of the set `s` in `result`; both
    ## names are taken from the scope in which the template is expanded.
    result = "{"
    var first = true
    for key in items(s):
      if first:
        first = false
      else:
        result.add(", ")
      result.add($key)
    result.add("}")
  proc `$`*(s: IntSet): string =
    ## The `$` operator for int sets. The elements are listed in iteration
    ## order, separated by ", " and enclosed in braces.
    result = "{"
    for key in items(s):
      if result.len > 1: result.add(", ")
      result.add($key)
    result.add("}")
  proc empty*(s: IntSet): bool {.inline, deprecated.} =
    ## Returns true if `s` is empty. This is safe to call even before
    ## the set has been initialized with `initIntSet`. Note this never
    ## worked reliably and so is deprecated.
    # Only the trunk counter is inspected.
    s.counter == 0
  when isMainModule:
    import sequtils, algorithm

    proc sortedItems(s: IntSet): seq[int] =
      ## Collects the elements of `s` in ascending order.
      result = toSeq(items(s))
      result.sort(cmp[int])

    # Basic inclusion / exclusion on a small set.
    var x = initIntSet()
    for k in [1, 2, 7, 1056, 1044]:
      x.incl(k)
    x.excl(1044)

    assert(not x.containsOrIncl(888))
    assert 888 in x
    assert x.containsOrIncl(888)

    assert(not x.missingOrExcl(888))
    assert 888 notin x
    assert x.missingOrExcl(888)

    assert sortedItems(x) == @[1, 2, 7, 1056]

    # Deep copy into a set that was never initialized.
    var y: IntSet
    assign(y, x)
    assert sortedItems(y) == @[1, 2, 7, 1056]
    assert x == y

    # Growing past the inline array.
    var z: IntSet
    for i in 0..1000:
      incl z, i
      assert z.len() == i + 1
    for i in 0..1000:
      assert i in z

    # Set algebra.
    var w = initIntSet()
    for k in [1, 4, 50, 1001, 1056]:
      w.incl(k)

    assert sortedItems(x.union(w)) == @[1, 2, 4, 7, 50, 1001, 1056]
    assert sortedItems(x.intersection(w)) == @[1, 1056]
    assert sortedItems(x.difference(w)) == @[2, 7]
    assert sortedItems(x.symmetricDifference(w)) == @[2, 4, 7, 50, 1001]

    x.incl(w)
    assert sortedItems(x) == @[1, 2, 4, 7, 50, 1001, 1056]

    assert w <= x
    assert w < x

    assert(not disjoint(w, x))

    var u = initIntSet()
    for k in [3, 5, 500]:
      u.incl(k)
    assert disjoint(u, x)

    var v = initIntSet()
    for k in [2, 50]:
      v.incl(k)

    x.excl(v)
    assert sortedItems(x) == @[1, 4, 7, 1001, 1056]