-- set.lua
  -- A set class for fast union/diff that can always return its items, via the
  -- to_table method, in the same relative order in which they were added.
  -- It does this by keeping two Lua tables that mirror each other:
  -- 1) tbl:   index => item (may contain gaps left by remove)
  -- 2) items: item => index
  --- @class Set
  --- @field nelem integer number of items currently in the set
  --- @field items table<string, integer> item => its index in tbl
  --- @field tbl table<integer, string> index => item; not necessarily a sequence
  local Set = {}
  -- Set up method lookup once, at load time, instead of relying on Set:new
  -- having run before an instance (e.g. one built by copy) is used.
  Set.__index = Set
  12. --- @param items? string[]
  --- Creates a new Set whose metatable is the receiver (normally `Set`),
  --- optionally seeded with the values of `items`. Duplicate values are
  --- stored only once.
  --- @param items? string[] initial items; a non-table argument is ignored
  --- @return Set
  function Set:new(items)
    local obj = setmetatable({ tbl = {}, items = {}, nelem = 0 }, self) --- @type Set
    -- Instances resolve their methods through the class they were built from.
    self.__index = self
    if type(items) == 'table' then
      -- Fill the new object directly rather than building a temporary Set
      -- and copying its internal tables over.
      obj:union_table(items)
    end
    return obj
  end
  30. --- @return Set
  --- Returns an independent copy of this Set. Both mirror tables are
  --- duplicated (shallowly), so later add/remove calls on either Set do not
  --- affect the other.
  --- @return Set
  function Set:copy()
    local obj = { nelem = self.nelem, tbl = {}, items = {} } --- @type Set
    for idx, item in pairs(self.tbl) do
      obj.tbl[idx] = item
    end
    for item, idx in pairs(self.items) do
      obj.items[item] = idx
    end
    -- Reuse the source's metatable so copies of derived classes keep their
    -- methods. Method lookup goes through the metatable's __index, so no
    -- __index field needs to be stored on the instance itself.
    return setmetatable(obj, getmetatable(self) or Set)
  end
  43. -- adds the argument Set to this Set
  44. --- @param other Set
  -- Adds every item of the argument Set to this Set.
  -- `other` is walked via to_table (its items ordered by index) rather than
  -- pairs() over its item => index map, whose order is unspecified. The
  -- items therefore land here in the same relative order they had in
  -- `other`. The snapshot also makes `s:union(s)` safe.
  --- @param other Set
  function Set:union(other)
    for _, e in ipairs(other:to_table()) do
      self:add(e)
    end
  end
  50. -- adds the argument table to this Set
  -- Adds every value of the argument table to this Set.
  -- The sequence part (t[1], t[2], ...) is added first, in index order,
  -- because pairs() does not specify a traversal order even for integer
  -- keys. The following pairs() pass picks up values stored under any other
  -- keys; values already added are skipped by add().
  --- @param t table
  function Set:union_table(t)
    for _, v in ipairs(t) do
      self:add(v)
    end
    for _, v in pairs(t) do
      self:add(v)
    end
  end
  56. -- subtracts the argument Set from this Set
  57. --- @param other Set
  -- Removes from this Set every item that is also a member of the argument
  -- Set. Only the smaller of the two sets is iterated; membership in the
  -- other one is checked by hash lookup.
  --- @param other Set
  function Set:diff(other)
    if other:size() <= self:size() then
      -- `other` is no larger than this set: walk it and drop shared items.
      for item in other:iterator() do
        if self.items[item] then
          self:remove(item)
        end
      end
      return
    end
    -- This set is strictly smaller: walk it and drop items `other` also has.
    for item in self:iterator() do
      if other:contains(item) then
        self:remove(item)
      end
    end
  end
  75. --- @param it string
  -- Appends `it` to the Set unless it is already a member.
  -- Each new item gets an index strictly greater than every index already
  -- used in `tbl`, so to_table keeps it after all earlier items. `#self.tbl`
  -- cannot be used for this: once remove() has left holes in `tbl`, the
  -- length operator may return any border (e.g. 0), which would place the
  -- new item in front of older ones.
  -- The highest index handed out so far is cached in `self.last_idx`. When
  -- the field is absent (objects whose tables were filled elsewhere), it is
  -- recomputed once from the keys of `tbl`.
  --- @param it string
  function Set:add(it)
    if self:contains(it) then
      return
    end
    local last = self.last_idx
    if last == nil then
      last = 0
      for idx in pairs(self.tbl) do
        if idx > last then
          last = idx
        end
      end
    end
    local idx = last + 1
    -- Never overwrite an occupied slot, even if `tbl` was modified directly
    -- (e.g. through raw_tbl) after the cache was last updated.
    while self.tbl[idx] ~= nil do
      idx = idx + 1
    end
    self.tbl[idx] = it
    self.items[it] = idx
    self.last_idx = idx
    self.nelem = self.nelem + 1
  end
  84. --- @param it string
  -- Drops `it` from the Set; does nothing when `it` is not a member.
  -- Both mirror entries are cleared, which can leave a gap in `tbl` at the
  -- removed item's index.
  --- @param it string
  function Set:remove(it)
    if not self:contains(it) then
      return
    end
    local slot = self.items[it]
    self.items[it] = nil
    self.tbl[slot] = nil
    self.nelem = self.nelem - 1
  end
  93. --- @param it string
  94. --- @return boolean
  --- Reports whether `it` is a member of the Set.
  --- @param it string
  --- @return boolean
  function Set:contains(it)
    -- Compare against nil so callers get a real boolean, as annotated,
    -- rather than the item's internal index in tbl.
    return self.items[it] ~= nil
  end
  98. --- @return integer
  --- Number of items currently in the Set, read from the counter that
  --- add/remove maintain.
  --- @return integer
  function Set:size()
    local count = self.nelem
    return count
  end
  --- Exposes the internal index => item table itself (not a copy). It may
  --- contain gaps left by remove().
  --- @return table<integer, string>
  function Set:raw_tbl()
    local by_index = self.tbl
    return by_index
  end
  --- Exposes the internal item => index table itself (not a copy).
  --- @return table<string, integer>
  function Set:raw_items()
    local by_item = self.items
    return by_item
  end
  --- Returns a generic-for iterator over the item => index table, yielding
  --- (item, index) pairs in unspecified order. This is the same triple that
  --- pairs(self.items) produces for a table without metamethods.
  function Set:iterator()
    return next, self.items, nil
  end
  111. --- @return string[]
  --- Returns a new array holding the Set's items ordered by their index in
  --- `tbl`. Because `tbl` may have gaps, its keys are collected and sorted
  --- first, and the items are then read back in that order.
  --- @return string[]
  function Set:to_table()
    local indices = {} --- @type integer[]
    local count = 0
    for idx in pairs(self.tbl) do
      count = count + 1
      indices[count] = idx
    end
    table.sort(indices)
    local out = {} --- @type string[]
    for i = 1, count do
      out[i] = self.tbl[indices[i]]
    end
    return out
  end
  -- Module export: the Set class table.
  return Set