-- set.lua
  --- Set class with fast union and difference.
  -- Every instance keeps two Lua tables that mirror each other:
  --   1) tbl:   index => item
  --   2) items: item  => index
  -- The to_table method uses the indices to return the members as a plain
  -- table, in the same relative order in which they were added.
  local Set = {}
  --- Creates a new set object.
  -- The receiver is installed as the object's metatable and its __index is
  -- pointed at itself, so methods are looked up on the receiver.
  -- @param items optional table; when given, each of its values becomes a
  --   member (they are added through union_table). Any non-table argument
  --   produces an empty set.
  -- @return the new set object
  function Set:new(items)
    self.__index = self
    local obj = setmetatable({}, self)

    if type(items) ~= 'table' then
      obj.tbl, obj.items, obj.nelem = {}, {}, 0
      return obj
    end

    -- Fill a scratch Set and adopt its backing tables and element count.
    local seed = Set:new()
    seed:union_table(items)
    obj.tbl = seed:raw_tbl()
    obj.items = seed:raw_items()
    obj.nelem = seed:size()
    return obj
  end
  --- Returns an independent copy of this set.
  -- Both backing tables are duplicated entry by entry, keeping every index
  -- as it is (gaps left in tbl by earlier removals are carried over), so
  -- adding to or removing from the copy never affects the original. The
  -- members themselves are not cloned.
  -- @return a new Set holding the same members under the same indices
  function Set:copy()
    local obj = { nelem = self.nelem, tbl = {}, items = {} }
    for idx, item in pairs(self.tbl) do
      obj.tbl[idx] = item
    end
    for item, idx in pairs(self.items) do
      obj.items[item] = idx
    end
    -- Method lookup goes through the metatable, so __index is set on Set
    -- itself rather than stored as a field on the copy.
    Set.__index = Set
    return setmetatable(obj, Set)
  end
  --- Adds every member of the argument Set to this Set.
  -- Members already present are left untouched. New members are added in
  -- the order produced by other:iterator(), i.e. pairs order over the other
  -- set's items table, which Lua does not specify.
  -- @param other the Set whose members are merged into this one
  function Set:union(other)
    for member in other:iterator() do
      self:add(member)
    end
  end
  --- Adds every value of the argument table to this Set.
  -- Keys of the table are ignored; only its values become members.
  -- @param t the table whose values are added
  function Set:union_table(t)
    -- pairs makes no promise about traversal order, not even for integer
    -- keys, so walk the sequence part (1..n) with ipairs first: array-like
    -- input is then always added in index order.
    for _, v in ipairs(t) do
      self:add(v)
    end
    -- Pick up the values stored under any other keys. Values already added
    -- by the loop above are skipped by add, which ignores existing members.
    for _, v in pairs(t) do
      self:add(v)
    end
  end
  --- Subtracts the argument Set from this Set.
  -- Every member of this set that is also a member of `other` is removed.
  -- The loop runs over whichever of the two sets has fewer members.
  -- Removal only clears existing fields of the items table, which Lua
  -- allows while that table is being traversed with pairs.
  -- @param other the Set whose members are removed from this one
  function Set:diff(other)
    local theirs = other:size()
    local mine = self:size()

    if theirs > mine then
      -- this set is the smaller one: test each of its members against other
      for member in self:iterator() do
        if other:contains(member) then
          self:remove(member)
        end
      end
      return
    end

    -- other is the smaller (or equally sized) one: walk its members instead
    for member in other:iterator() do
      if self.items[member] then
        self:remove(member)
      end
    end
  end
  --- Adds an item to the set; does nothing if it is already a member.
  -- A new item always receives an index greater than every index currently
  -- in tbl, so to_table keeps reporting members in insertion order even
  -- after removals have left gaps. The length operator is deliberately not
  -- used for this: on a table with holes `#` may return any border.
  -- @param it the value to add; must not be nil or NaN, since neither can be
  --   used as a table key
  function Set:add(it)
    if it == nil or it ~= it then
      error("Set:add: nil and NaN cannot be stored in a Set", 2)
    end
    if self:contains(it) then
      return
    end

    -- last_idx caches the highest index handed out so far. An object that
    -- does not carry the field yet gets it recovered once, by scanning tbl
    -- for the largest key in use.
    local last = self.last_idx
    if last == nil then
      last = 0
      for idx in pairs(self.tbl) do
        if idx > last then
          last = idx
        end
      end
    end

    local idx = last + 1
    self.last_idx = idx
    self.tbl[idx] = it
    self.items[it] = idx
    self.nelem = self.nelem + 1
  end
  --- Removes an item from the set; does nothing if it is not a member.
  -- The item's slot in tbl is cleared rather than compacted, so the indices
  -- of the remaining members stay as they are and tbl may contain gaps.
  -- @param it the value to remove
  function Set:remove(it)
    if not self:contains(it) then
      return
    end
    local slot = self.items[it]
    self.items[it] = nil
    self.tbl[slot] = nil
    self.nelem = self.nelem - 1
  end
  --- Membership test.
  -- @param it the value to look up
  -- @return the index recorded for `it` in the items table (a truthy value)
  --   when it is a member, otherwise false
  function Set:contains(it)
    local idx = self.items[it]
    if idx then
      return idx
    end
    return false
  end
  --- Number of members, as tracked by the nelem counter.
  -- @return the current value of nelem
  function Set:size()
    local count = self.nelem
    return count
  end
  --- Gives direct access to the index => item table.
  -- The table itself is returned, not a copy; it may contain gaps where
  -- members were removed.
  -- @return the tbl field of this set
  function Set:raw_tbl()
    local index_to_item = self.tbl
    return index_to_item
  end
  --- Gives direct access to the item => index table.
  -- The table itself is returned, not a copy.
  -- @return the items field of this set
  function Set:raw_items()
    local item_to_index = self.items
    return item_to_index
  end
  --- Iterator over the members, for use in a generic for loop.
  -- Returns whatever pairs yields for the items table, so the first loop
  -- variable is a member and the second is its index. Traversal order is
  -- that of pairs and therefore unspecified.
  -- @return the values returned by pairs(self.items)
  function Set:iterator()
    local item_to_index = self.items
    return pairs(item_to_index)
  end
  --- Returns the members as a new gap-free table, ordered by their index
  -- in tbl (lowest index first).
  -- @return a fresh array-like table of the members
  function Set:to_table()
    -- tbl may have gaps, so gather the indices actually in use and sort
    -- them instead of counting upwards from 1.
    local indices = {}
    for idx in pairs(self.tbl) do
      indices[#indices + 1] = idx
    end
    table.sort(indices)

    local ordered = {}
    for pos = 1, #indices do
      ordered[pos] = self.tbl[indices[pos]]
    end
    return ordered
  end
  -- Hand the Set table back to whoever loads this chunk.
  return Set