after.lua 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  -- This is an implementation of a job scheduling mechanism. It guarantees that
  -- coexisting jobs will execute primarily in order of least expiry, and
  -- secondarily in order of first registration.
  -- These functions implement an intrusive singly linked list of one or more
  -- elements where the first element has a pointer to the last. The next pointer
  -- is stored with key list_next. The pointer to the last element is stored with
  -- key list_end.
  -- Turns `first` into a one-element list: its tail pointer refers to itself.
  local function list_init(first)
    local head = first
    head.list_end = head
  end
  -- Appends the single element `append` to the list headed by `first` in O(1):
  -- the current tail (reached through the head's tail pointer) is linked to
  -- `append`, which then becomes the new tail.
  local function list_append(first, append)
    local old_tail = first.list_end
    old_tail.list_next = append
    first.list_end = append
  end
  -- Splices the whole list headed by `first_append` onto the end of the list
  -- headed by `first` in O(1); the appended list's tail becomes the new tail.
  local function list_append_list(first, first_append)
    local old_tail = first.list_end
    local new_tail = first_append.list_end
    old_tail.list_next = first_append
    first.list_end = new_tail
  end
  -- The jobs are stored in a map from expiration times to linked lists of jobs
  -- as above. The expiration times are also stored in an array representing a
  -- binary min-heap, which is a particular arrangement of a binary tree. A parent
  -- at index i has children at indices i*2 and i*2+1. Out-of-bounds indices
  -- represent nonexistent children. A parent is never greater than its children.
  -- This structure means that, if there is at least one job, the next expiration
  -- time is the first item in the array.
  -- Pushes `element` onto the binary min-heap `heap` (an array using indices
  -- 1..#heap). The element starts in the first free slot and is sifted up:
  -- while it is smaller than the value in its parent slot, that parent value
  -- moves down one level. Elements must be mutually comparable with `<`.
  local function heap_push(heap, element)
    local slot = #heap + 1
    local parent_slot = math.floor(slot / 2)
    while slot > 1 and element < heap[parent_slot] do
      heap[slot] = heap[parent_slot]
      slot = parent_slot
      parent_slot = math.floor(slot / 2)
    end
    heap[slot] = element
  end
  -- Removes and returns the smallest element of the binary min-heap `heap`
  -- (nil when the heap is empty). The last leaf is taken off the array and
  -- sifted down from the root: at each level the smaller child that is also
  -- smaller than the sinking value moves up into the hole, until neither child
  -- is smaller, and the sinking value fills the final hole.
  local function heap_pop(heap)
    local top = heap[1]
    local size = #heap
    local last = heap[size]
    heap[size] = nil
    size = size - 1
    if size <= 0 then
      return top
    end

    local hole = 1
    while true do
      local target, target_value = hole, last
      local child = hole * 2
      if child <= size and heap[child] < target_value then
        target, target_value = child, heap[child]
      end
      child = child + 1
      if child <= size and heap[child] < target_value then
        target = child
      end
      if target == hole then
        break
      end
      heap[hole] = heap[target]
      hole = target
    end
    heap[hole] = last
    return top
  end
  -- Scheduler state: job_map maps each pending expiration time to the head of
  -- its job list; expiries is the binary min-heap of those expiration times.
  local job_map, expiries = {}, {}
  -- Schedules the single job `job` for time `expiry`.
  -- Jobs sharing an expiry form one list (its head is stored in job_map) in
  -- registration order; only the first job at a given expiry pushes that
  -- expiry onto the heap. Worst case O(log n), where n is the number of
  -- distinct expiration times.
  local function add_job(expiry, job)
    local head = job_map[expiry]
    if not head then
      list_init(job)
      job_map[expiry] = job
      heap_push(expiries, expiry)
      return
    end
    list_append(head, job)
  end
  -- Detaches the jobs with the earliest expiry: pops that expiry from the heap,
  -- clears its job_map entry and returns the head of its job list.
  -- Intended to be called only while at least one job is pending; on an empty
  -- heap the popped expiry is nil and the job_map assignment would fail.
  -- Worst case O(log n), where n is the number of distinct expiration times.
  local function remove_first_jobs()
    local expiry = heap_pop(expiries)
    local jobs
    -- Every right-hand value is evaluated before any assignment happens, so
    -- `jobs` receives the list head that the same statement then clears.
    jobs, job_map[expiry] = job_map[expiry], nil
    return jobs
  end
  -- Scheduler clock (the running sum of every globalstep dtime) and the
  -- earliest pending expiry, or math.huge when nothing is scheduled.
  local time, time_next = 0.0, math.huge
  do
    -- Detaches every job list whose expiry is at or before the current time and
    -- chains them into one list, earliest expiry first. On return, time_next
    -- holds the earliest expiry still pending, or math.huge if none remain.
    -- Only called once `time` has reached time_next, i.e. when a job is due.
    local function collect_expired_jobs()
      local expired = remove_first_jobs()
      while true do
        time_next = expiries[1] or math.huge
        if time_next > time then
          return expired
        end
        list_append_list(expired, remove_first_jobs())
      end
    end

    -- Runs the jobs of the list headed by `first` in list order. Before each
    -- call, it sets the job's recorded origin mod as the last-run mod.
    local function run_jobs(first)
      local last = first.list_end
      local job = first
      while true do
        core.set_last_run_mod(job.mod_origin)
        job.func(unpack(job.args, 1, job.args.n))
        if job == last then
          return
        end
        job = job.list_next
      end
    end

    core.register_globalstep(function(dtime)
      time = time + dtime
      if time < time_next then
        return
      end
      -- All due jobs are detached before any callback runs. A job scheduled
      -- from inside a callback (e.g. core.after(0, ...)) therefore waits for a
      -- later step instead of looping forever in this one.
      run_jobs(collect_expired_jobs())
    end)
  end
  -- No-op that replaces the callback of a cancelled job.
  local function dummy_func() end

  -- Metatable for the job handles returned by core.after; its __index table
  -- holds the methods that can be called on a handle.
  local job_metatable = {
    __index = {
      -- Cancels the job. The job stays queued, but when it expires it calls
      -- dummy_func with no arguments instead of the original callback.
      cancel = function(self)
        self.func = dummy_func
        self.args = {n = 0}
      end,
    },
  }
  -- Schedules `func(...)` to run once the scheduler clock has advanced by
  -- `after`. The clock is the running sum of globalstep dtime values, presumably
  -- seconds. `after` may be a number or anything tonumber() accepts, but not NaN.
  -- The extra arguments, including trailing nils, are captured now and passed
  -- to `func` when it runs. Returns the job table, whose :cancel() method turns
  -- the pending call into a no-op.
  -- Raises "Invalid core.after invocation" on a bad delay or a non-function.
  function core.after(after, func, ...)
    -- Validate the converted number, not the raw argument. Some implementations'
    -- tonumber() turns a string such as "nan" into NaN. The raw string itself
    -- would not be flagged by core.is_nan (presumably an `x ~= x` test). It
    -- would then slip past this check and only fail later, when the NaN expiry
    -- is used as a table key.
    local delay = tonumber(after)
    assert(delay and not core.is_nan(delay) and type(func) == "function",
      "Invalid core.after invocation")
    local new_job = {
      mod_origin = core.get_last_run_mod(),
      func = func,
      args = {
        n = select("#", ...),
        ...
      },
    }
    local expiry = time + delay
    add_job(expiry, new_job)
    -- Keep time_next at the earliest pending expiry so the globalstep wakes up
    -- in time for this job.
    time_next = math.min(time_next, expiry)
    return setmetatable(new_job, job_metatable)
  end