intsets.nim 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. #
  2. #
  3. # Nim's Runtime Library
  4. # (c) Copyright 2012 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  9. ## The ``intsets`` module implements an efficient int set implemented as a
  10. ## `sparse bit set`:idx:.
  11. ## **Note**: Since Nim currently does not allow the assignment operator to
  12. ## be overloaded, ``=`` for int sets performs some rather meaningless shallow
  13. ## copy; use ``assign`` to get a deep copy.
  14. import
  15. hashes, math
type
  BitScalar = int # machine word holding one slice of a trunk's bit vector

const
  InitIntSetSize = 8 # must be a power of two!
                     # (initial slot count of the trunk hash table)
  TrunkShift = 9     # an element belongs to trunk `element shr TrunkShift`
  BitsPerTrunk = 1 shl TrunkShift # needs to be a power of 2 and
                                  # divisible by 64
  TrunkMask = BitsPerTrunk - 1    # bit offset of an element inside its trunk
  IntsPerTrunk = BitsPerTrunk div (sizeof(BitScalar) * 8) # words per trunk
  IntShift = 5 + ord(sizeof(BitScalar) == 8) # 5 or 6, depending on int width
                                             # (log2 of bits per BitScalar)
  IntMask = 1 shl IntShift - 1    # bit index inside one BitScalar word

type
  PTrunk = ref Trunk
  Trunk = object
    next: PTrunk # all nodes are connected with this pointer
    key: int     # trunk index, i.e. `element shr TrunkShift` for every
                 # element whose bit lives in this trunk
    bits: array[0..IntsPerTrunk - 1, BitScalar] # a bit vector
  TrunkSeq = seq[PTrunk] # open-addressing hash table of trunks
  IntSet* = object ## an efficient set of 'int' implemented as a sparse bit set
    elems: int # small mode (elems <= a.len): number of entries used in `a`;
               # set to a.len + 1 as a marker once `incl` switches to
               # trunk mode
    counter, max: int # counter: trunks stored in `data`;
                      # max: `data.len - 1`, used as the hash mask
    head: PTrunk      # linked list of every trunk (used for iteration)
    data: TrunkSeq    # trunk hash table; only allocated in trunk mode
    a: array[0..33, int] # profiling shows that 34 elements are enough
                         # (inline storage used while the set is small)
{.deprecated: [TIntSet: IntSet, TTrunk: Trunk, TTrunkSeq: TrunkSeq].}
  41. proc mustRehash(length, counter: int): bool {.inline.} =
  42. assert(length > counter)
  43. result = (length * 2 < counter * 3) or (length - counter < 4)
  44. proc nextTry(h, maxHash: Hash): Hash {.inline.} =
  45. result = ((5 * h) + 1) and maxHash
  46. proc intSetGet(t: IntSet, key: int): PTrunk =
  47. var h = key and t.max
  48. while t.data[h] != nil:
  49. if t.data[h].key == key:
  50. return t.data[h]
  51. h = nextTry(h, t.max)
  52. result = nil
  53. proc intSetRawInsert(t: IntSet, data: var TrunkSeq, desc: PTrunk) =
  54. var h = desc.key and t.max
  55. while data[h] != nil:
  56. assert(data[h] != desc)
  57. h = nextTry(h, t.max)
  58. assert(data[h] == nil)
  59. data[h] = desc
  60. proc intSetEnlarge(t: var IntSet) =
  61. var n: TrunkSeq
  62. var oldMax = t.max
  63. t.max = ((t.max + 1) * 2) - 1
  64. newSeq(n, t.max + 1)
  65. for i in countup(0, oldMax):
  66. if t.data[i] != nil: intSetRawInsert(t, n, t.data[i])
  67. swap(t.data, n)
  68. proc intSetPut(t: var IntSet, key: int): PTrunk =
  69. var h = key and t.max
  70. while t.data[h] != nil:
  71. if t.data[h].key == key:
  72. return t.data[h]
  73. h = nextTry(h, t.max)
  74. if mustRehash(t.max + 1, t.counter): intSetEnlarge(t)
  75. inc(t.counter)
  76. h = key and t.max
  77. while t.data[h] != nil: h = nextTry(h, t.max)
  78. assert(t.data[h] == nil)
  79. new(result)
  80. result.next = t.head
  81. result.key = key
  82. t.head = result
  83. t.data[h] = result
  84. proc contains*(s: IntSet, key: int): bool =
  85. ## returns true iff `key` is in `s`.
  86. if s.elems <= s.a.len:
  87. for i in 0..<s.elems:
  88. if s.a[i] == key: return true
  89. else:
  90. var t = intSetGet(s, `shr`(key, TrunkShift))
  91. if t != nil:
  92. var u = key and TrunkMask
  93. result = (t.bits[`shr`(u, IntShift)] and `shl`(1, u and IntMask)) != 0
  94. else:
  95. result = false
  96. iterator items*(s: IntSet): int {.inline.} =
  97. ## iterates over any included element of `s`.
  98. if s.elems <= s.a.len:
  99. for i in 0..<s.elems:
  100. yield s.a[i]
  101. else:
  102. var r = s.head
  103. while r != nil:
  104. var i = 0
  105. while i <= high(r.bits):
  106. var w = r.bits[i]
  107. # taking a copy of r.bits[i] here is correct, because
  108. # modifying operations are not allowed during traversation
  109. var j = 0
  110. while w != 0: # test all remaining bits for zero
  111. if (w and 1) != 0: # the bit is set!
  112. yield (r.key shl TrunkShift) or (i shl IntShift +% j)
  113. inc(j)
  114. w = w shr 1
  115. inc(i)
  116. r = r.next
  117. proc bitincl(s: var IntSet, key: int) {.inline.} =
  118. var t = intSetPut(s, `shr`(key, TrunkShift))
  119. var u = key and TrunkMask
  120. t.bits[`shr`(u, IntShift)] = t.bits[`shr`(u, IntShift)] or
  121. `shl`(1, u and IntMask)
  122. proc incl*(s: var IntSet, key: int) =
  123. ## includes an element `key` in `s`.
  124. if s.elems <= s.a.len:
  125. for i in 0..<s.elems:
  126. if s.a[i] == key: return
  127. if s.elems < s.a.len:
  128. s.a[s.elems] = key
  129. inc s.elems
  130. return
  131. newSeq(s.data, InitIntSetSize)
  132. s.max = InitIntSetSize-1
  133. for i in 0..<s.elems:
  134. bitincl(s, s.a[i])
  135. s.elems = s.a.len + 1
  136. # fall through:
  137. bitincl(s, key)
  138. proc incl*(s: var IntSet, other: IntSet) =
  139. ## Includes all elements from `other` into `s`.
  140. for item in other: incl(s, item)
  141. proc exclImpl(s: var IntSet, key: int) =
  142. if s.elems <= s.a.len:
  143. for i in 0..<s.elems:
  144. if s.a[i] == key:
  145. s.a[i] = s.a[s.elems-1]
  146. dec s.elems
  147. return
  148. else:
  149. var t = intSetGet(s, `shr`(key, TrunkShift))
  150. if t != nil:
  151. var u = key and TrunkMask
  152. t.bits[`shr`(u, IntShift)] = t.bits[`shr`(u, IntShift)] and
  153. not `shl`(1, u and IntMask)
  154. proc excl*(s: var IntSet, key: int) =
  155. ## excludes `key` from the set `s`.
  156. exclImpl(s, key)
  157. proc excl*(s: var IntSet, other: IntSet) =
  158. ## Excludes all elements from `other` from `s`.
  159. for item in other: excl(s, item)
  160. proc missingOrExcl*(s: var IntSet, key: int) : bool =
  161. ## returns true if `s` does not contain `key`, otherwise
  162. ## `key` is removed from `s` and false is returned.
  163. var count = s.elems
  164. exclImpl(s, key)
  165. result = count == s.elems
  166. proc containsOrIncl*(s: var IntSet, key: int): bool =
  167. ## returns true if `s` contains `key`, otherwise `key` is included in `s`
  168. ## and false is returned.
  169. if s.elems <= s.a.len:
  170. for i in 0..<s.elems:
  171. if s.a[i] == key:
  172. return true
  173. incl(s, key)
  174. result = false
  175. else:
  176. var t = intSetGet(s, `shr`(key, TrunkShift))
  177. if t != nil:
  178. var u = key and TrunkMask
  179. result = (t.bits[`shr`(u, IntShift)] and `shl`(1, u and IntMask)) != 0
  180. if not result:
  181. t.bits[`shr`(u, IntShift)] = t.bits[`shr`(u, IntShift)] or
  182. `shl`(1, u and IntMask)
  183. else:
  184. incl(s, key)
  185. result = false
  186. proc initIntSet*: IntSet =
  187. ## creates a new int set that is empty.
  188. #newSeq(result.data, InitIntSetSize)
  189. #result.max = InitIntSetSize-1
  190. when defined(nimNoNilSeqs):
  191. result.data = @[]
  192. else:
  193. result.data = nil
  194. result.max = 0
  195. result.counter = 0
  196. result.head = nil
  197. result.elems = 0
  198. proc clear*(result: var IntSet) =
  199. #setLen(result.data, InitIntSetSize)
  200. #for i in 0..InitIntSetSize-1: result.data[i] = nil
  201. #result.max = InitIntSetSize-1
  202. when defined(nimNoNilSeqs):
  203. result.data = @[]
  204. else:
  205. result.data = nil
  206. result.max = 0
  207. result.counter = 0
  208. result.head = nil
  209. result.elems = 0
  210. proc isNil*(x: IntSet): bool {.inline.} = x.head.isNil and x.elems == 0
  211. proc assign*(dest: var IntSet, src: IntSet) =
  212. ## copies `src` to `dest`. `dest` does not need to be initialized by
  213. ## `initIntSet`.
  214. if src.elems <= src.a.len:
  215. when defined(nimNoNilSeqs):
  216. dest.data = @[]
  217. else:
  218. dest.data = nil
  219. dest.max = 0
  220. dest.counter = src.counter
  221. dest.head = nil
  222. dest.elems = src.elems
  223. dest.a = src.a
  224. else:
  225. dest.counter = src.counter
  226. dest.max = src.max
  227. newSeq(dest.data, src.data.len)
  228. var it = src.head
  229. while it != nil:
  230. var h = it.key and dest.max
  231. while dest.data[h] != nil: h = nextTry(h, dest.max)
  232. assert(dest.data[h] == nil)
  233. var n: PTrunk
  234. new(n)
  235. n.next = dest.head
  236. n.key = it.key
  237. n.bits = it.bits
  238. dest.head = n
  239. dest.data[h] = n
  240. it = it.next
  241. proc union*(s1, s2: IntSet): IntSet =
  242. ## Returns the union of the sets `s1` and `s2`.
  243. result.assign(s1)
  244. incl(result, s2)
  245. proc intersection*(s1, s2: IntSet): IntSet =
  246. ## Returns the intersection of the sets `s1` and `s2`.
  247. result = initIntSet()
  248. for item in s1:
  249. if contains(s2, item):
  250. incl(result, item)
  251. proc difference*(s1, s2: IntSet): IntSet =
  252. ## Returns the difference of the sets `s1` and `s2`.
  253. result = initIntSet()
  254. for item in s1:
  255. if not contains(s2, item):
  256. incl(result, item)
  257. proc symmetricDifference*(s1, s2: IntSet): IntSet =
  258. ## Returns the symmetric difference of the sets `s1` and `s2`.
  259. result.assign(s1)
  260. for item in s2:
  261. if containsOrIncl(result, item): excl(result, item)
  262. proc `+`*(s1, s2: IntSet): IntSet {.inline.} =
  263. ## Alias for `union(s1, s2) <#union>`_.
  264. result = union(s1, s2)
  265. proc `*`*(s1, s2: IntSet): IntSet {.inline.} =
  266. ## Alias for `intersection(s1, s2) <#intersection>`_.
  267. result = intersection(s1, s2)
  268. proc `-`*(s1, s2: IntSet): IntSet {.inline.} =
  269. ## Alias for `difference(s1, s2) <#difference>`_.
  270. result = difference(s1, s2)
  271. proc disjoint*(s1, s2: IntSet): bool =
  272. ## Returns true iff the sets `s1` and `s2` have no items in common.
  273. for item in s1:
  274. if contains(s2, item):
  275. return false
  276. return true
  277. proc len*(s: IntSet): int {.inline.} =
  278. ## Returns the number of keys in `s`.
  279. if s.elems < s.a.len:
  280. result = s.elems
  281. else:
  282. result = 0
  283. for _ in s:
  284. inc(result)
  285. proc card*(s: IntSet): int {.inline.} =
  286. ## alias for `len() <#len>` _.
  287. result = s.len()
  288. proc `<=`*(s1, s2: IntSet): bool =
  289. ## Returns true iff `s1` is subset of `s2`.
  290. for item in s1:
  291. if not s2.contains(item):
  292. return false
  293. return true
  294. proc `<`*(s1, s2: IntSet): bool =
  295. ## Returns true iff `s1` is proper subset of `s2`.
  296. return s1 <= s2 and not (s2 <= s1)
  297. proc `==`*(s1, s2: IntSet): bool =
  298. ## Returns true if both `s` and `t` have the same members and set size.
  299. return s1 <= s2 and s2 <= s1
template dollarImpl(): untyped =
  # Renders the set as "{a, b, c}" (elements in `items` order, separated
  # by ", "). Neither `result` nor `s` is declared here; they presumably
  # resolve in the instantiating routine's scope, which must provide a
  # string `result` and an IntSet named `s` (as `$` does).
  result = "{"
  for key in items(s):
    if result.len > 1: result.add(", ") # not the first element
    result.add($key)
  result.add("}")
  306. proc `$`*(s: IntSet): string =
  307. ## The `$` operator for int sets.
  308. dollarImpl()
  309. proc empty*(s: IntSet): bool {.inline, deprecated.} =
  310. ## returns true if `s` is empty. This is safe to call even before
  311. ## the set has been initialized with `initIntSet`. Note this never
  312. ## worked reliably and so is deprecated.
  313. result = s.counter == 0
when isMainModule:
  # Self-test: exercises the small (inline-array) representation, the
  # switch to trunk mode, and the set-algebra operations.
  import sequtils, algorithm

  # small set: incl/excl and the combined query-and-update procs
  var x = initIntSet()
  x.incl(1)
  x.incl(2)
  x.incl(7)
  x.incl(1056)
  x.incl(1044)
  x.excl(1044)

  assert x.containsOrIncl(888) == false
  assert 888 in x
  assert x.containsOrIncl(888) == true

  assert x.missingOrExcl(888) == false
  assert 888 notin x
  assert x.missingOrExcl(888) == true

  # iteration order is not relied upon: sort before comparing
  var xs = toSeq(items(x))
  xs.sort(cmp[int])
  assert xs == @[1, 2, 7, 1056]

  # deep copy into a zero-initialized (not initIntSet'ed) set
  var y: IntSet
  assign(y, x)
  var ys = toSeq(items(y))
  ys.sort(cmp[int])
  assert ys == @[1, 2, 7, 1056]

  assert x == y

  # growing far past the 34-entry inline array forces trunk mode;
  # len must stay exact across the transition
  var z: IntSet
  for i in 0..1000:
    incl z, i
    assert z.len() == i+1
  for i in 0..1000:
    assert z.contains(i)

  # set algebra between x = {1, 2, 7, 1056} and w
  var w = initIntSet()
  w.incl(1)
  w.incl(4)
  w.incl(50)
  w.incl(1001)
  w.incl(1056)

  var xuw = x.union(w)
  var xuws = toSeq(items(xuw))
  xuws.sort(cmp[int])
  assert xuws == @[1, 2, 4, 7, 50, 1001, 1056]

  var xiw = x.intersection(w)
  var xiws = toSeq(items(xiw))
  xiws.sort(cmp[int])
  assert xiws == @[1, 1056]

  var xdw = x.difference(w)
  var xdws = toSeq(items(xdw))
  xdws.sort(cmp[int])
  assert xdws == @[2, 7]

  var xsw = x.symmetricDifference(w)
  var xsws = toSeq(items(xsw))
  xsws.sort(cmp[int])
  assert xsws == @[2, 4, 7, 50, 1001]

  # in-place union, then subset / disjointness relations
  x.incl(w)
  xs = toSeq(items(x))
  xs.sort(cmp[int])
  assert xs == @[1, 2, 4, 7, 50, 1001, 1056]

  assert w <= x

  assert w < x

  assert(not disjoint(w, x))

  var u = initIntSet()
  u.incl(3)
  u.incl(5)
  u.incl(500)
  assert disjoint(u, x)

  # in-place difference
  var v = initIntSet()
  v.incl(2)
  v.incl(50)

  x.excl(v)
  xs = toSeq(items(x))
  xs.sort(cmp[int])
  assert xs == @[1, 4, 7, 1001, 1056]