thread_pool.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. #!/usr/bin/env python3
  2. '''
  3. This file is MIT Licensed because I'm posting it on Stack Overflow:
  4. https://stackoverflow.com/questions/19369724/the-right-way-to-limit-maximum-number-of-threads-running-at-once/55263676#55263676
  5. '''
  6. from typing import Any, Callable, Dict, Iterable, Union
  7. import os
  8. import queue
  9. import sys
  10. import threading
  11. import time
  12. import traceback
  13. class ThreadPoolExitException(Exception):
  14. '''
  15. An object of this class may be raised by output_handler_function to
  16. request early termination.
  17. It is also raised by submit() if submit_raise_exit=True.
  18. '''
  19. pass
  20. class ThreadPool:
  21. '''
  22. Start a pool of a limited number of threads to do some work.
  23. This is similar to the stdlib concurrent, but I could not find
  24. how to reach all my design goals with that implementation:
  25. * the input function does not need to be modified
  26. * limit the number of threads
  27. * queue sizes closely follow number of threads
  28. * if an exception happens, optionally stop soon afterwards
  29. This class form allows to use your own while loops with submit().
  30. Exit soon after the first failure happens:
  31. ....
  32. python3 thread_pool.py 2 -10 20 handle_output_print
  33. ....
  34. Sample output:
  35. ....
  36. {'i': -9} -1.1111111111111112 None
  37. {'i': -8} -1.25 None
  38. {'i': -10} -1.0 None
  39. {'i': -6} -1.6666666666666667 None
  40. {'i': -7} -1.4285714285714286 None
  41. {'i': -4} -2.5 None
  42. {'i': -5} -2.0 None
  43. {'i': -2} -5.0 None
  44. {'i': -3} -3.3333333333333335 None
  45. {'i': 0} None ZeroDivisionError('float division by zero')
  46. {'i': -1} -10.0 None
  47. {'i': 1} 10.0 None
  48. {'i': 2} 5.0 None
  49. work_function or handle_output raised:
  50. Traceback (most recent call last):
  51. File "thread_pool.py", line 181, in _func_runner
  52. work_function_return = self.work_function(**work_function_input)
  53. File "thread_pool.py", line 281, in work_function_maybe_raise
  54. return 10.0 / i
  55. ZeroDivisionError: float division by zero
  56. work_function_input: {'i': 0}
  57. work_function_return: None
  58. ....
  59. Don't exit after first failure, run until end:
  60. ....
  61. python3 thread_pool.py 2 -10 20 handle_output_print_no_exit
  62. ....
  63. Store results in a queue for later inspection instead of printing immediately,
  64. then print everything at the end:
  65. ....
  66. python3 thread_pool.py 2 -10 20 handle_output_queue
  67. ....
  68. Exit soon after the handle_output raise.
  69. ....
  70. python3 thread_pool.py 2 -10 20 handle_output_raise
  71. ....
  72. Relying on this interface to abort execution is discouraged, this should
  73. usually only happen due to a programming error in the handler.
  74. Test that the argument called "thread_id" is passed to work_function and printed:
  75. ....
  76. python3 thread_pool.py 2 -10 20 handle_output_print thread_id
  77. ....
  78. Test with, ThreadPoolExitException and submit_raise_exit=True, same behaviour handle_output_print
  79. except for the different exit cause report:
  80. ....
  81. python3 thread_pool.py 2 -10 20 handle_output_raise_exit_exception
  82. ....
  83. '''
  84. def __init__(
  85. self,
  86. work_function: Callable,
  87. handle_output: Union[Callable[[Any,Any,Exception],Any],None] = None,
  88. nthreads: Union[int,None] = None,
  89. thread_id_arg: Union[str,None] = None,
  90. submit_raise_exit: bool = False,
  91. submit_skip_exit: bool = False,
  92. ):
  93. '''
  94. Start in a thread pool immediately.
  95. join() must be called afterwards at some point.
  96. :param work_function: main work function to be evaluated.
  97. :param handle_output: called on work_function return values as they
  98. are returned.
  99. The function signature is:
  100. ....
  101. handle_output(
  102. work_function_input: Union[Dict,None],
  103. work_function_return,
  104. work_function_exception: Exception
  105. ) -> Union[Exception,None]
  106. ....
  107. where work_function_exception the exception that work_function raised,
  108. or None otherwise
  109. The first non-None return value of a call to this function is returned by
  110. submit(), get_handle_output_result() and join().
  111. The intended semantic for this, is to return:
  112. * on success:
  113. ** None to continue execution
  114. ** ThreadPoolExitException() to request stop execution
  115. * if work_function_input or work_function_exception raise:
  116. ** the exception raised
  117. The ThreadPool user can then optionally terminate execution early on error
  118. or request with either:
  119. * an explicit submit() return value check + break if a submit loop is used
  120. * `with` + submit_raise_exit=True
  121. Default: a handler that just returns `exception`, which can normally be used
  122. by the submit loop to detect an error and exit immediately.
  123. :param nthreads: number of threads to use. Default: nproc.
  124. :param thread_id_arg: if not None, set the argument of work_function with this name
  125. to a 0-indexed thread ID. This allows function calls to coordinate
  126. usage of external resources such as files or ports.
  127. :param submit_raise_exit: if True, submit() raises ThreadPoolExitException() if
  128. get_handle_output_result() is not None.
  129. :param submit_skip_exit: if True, submit() does nothing if
  130. get_handle_output_result() is not None.
  131. You should avoid this interface if
  132. you can use use submit_raise_exit with `with` instead ideally.
  133. However, when you can't work with with and are in a deeply nested loop,
  134. it might just be easier to set this.
  135. '''
  136. self.work_function = work_function
  137. if handle_output is None:
  138. handle_output = lambda input, output, exception: exception
  139. self.handle_output = handle_output
  140. if nthreads is None:
  141. nthreads = len(os.sched_getaffinity(0))
  142. self.thread_id_arg = thread_id_arg
  143. self.submit_raise_exit = submit_raise_exit
  144. self.submit_skip_exit = submit_skip_exit
  145. self.nthreads = nthreads
  146. self.handle_output_result = None
  147. self.handle_output_result_lock = threading.Lock()
  148. self.in_queue = queue.Queue(maxsize=nthreads)
  149. self.threads = []
  150. for i in range(self.nthreads):
  151. thread = threading.Thread(
  152. target=self._func_runner,
  153. args=(i,)
  154. )
  155. self.threads.append(thread)
  156. thread.start()
  157. def __enter__(self):
  158. '''
  159. __exit__ automatically calls join() for you.
  160. This is cool because it automatically ends the loop if an exception occurs.
  161. But don't forget that errors may happen after the last submit was called, so you
  162. likely want to check for that with get_handle_output_result() after the with.
  163. '''
  164. return self
  165. def __exit__(self, exception_type, exception_value, exception_traceback):
  166. self.join()
  167. return exception_type is ThreadPoolExitException
  168. def _func_runner(self, thread_id):
  169. while True:
  170. work_function_input = self.in_queue.get(block=True)
  171. if work_function_input is None:
  172. break
  173. if self.thread_id_arg is not None:
  174. work_function_input[self.thread_id_arg] = thread_id
  175. try:
  176. work_function_exception = None
  177. work_function_return = self.work_function(**work_function_input)
  178. except Exception as e:
  179. work_function_exception = e
  180. work_function_return = None
  181. handle_output_exception = None
  182. try:
  183. handle_output_return = self.handle_output(
  184. work_function_input,
  185. work_function_return,
  186. work_function_exception
  187. )
  188. except Exception as e:
  189. handle_output_exception = e
  190. handle_output_result = None
  191. if handle_output_exception is not None:
  192. handle_output_result = handle_output_exception
  193. elif handle_output_return is not None:
  194. handle_output_result = handle_output_return
  195. if handle_output_result is not None and self.handle_output_result is None:
  196. with self.handle_output_result_lock:
  197. self.handle_output_result = (
  198. work_function_input,
  199. work_function_return,
  200. handle_output_result
  201. )
  202. self.in_queue.task_done()
  203. @staticmethod
  204. def exception_traceback_string(exception):
  205. '''
  206. Helper to get the traceback from an exception object.
  207. This is usually what you want to print if an error happens in a thread:
  208. https://stackoverflow.com/questions/3702675/how-to-print-the-full-traceback-without-halting-the-program/56199295#56199295
  209. '''
  210. return ''.join(traceback.format_exception(
  211. None, exception, exception.__traceback__)
  212. )
  213. def get_handle_output_result(self):
  214. '''
  215. :return: if a handle_output call has raised previously, return a tuple:
  216. ....
  217. (work_function_input, work_function_return, exception_raised)
  218. ....
  219. corresponding to the first such raise.
  220. Otherwise, if a handle_output returned non-None, a tuple:
  221. (work_function_input, work_function_return, handle_output_return)
  222. Otherwise, None.
  223. '''
  224. return self.handle_output_result
  225. def join(self):
  226. '''
  227. Request all threads to stop after they finish currently submitted work.
  228. :return: same as get_handle_output_result()
  229. '''
  230. for thread in range(self.nthreads):
  231. self.in_queue.put(None)
  232. for thread in self.threads:
  233. thread.join()
  234. return self.get_handle_output_result()
  235. def submit(
  236. self,
  237. work_function_input: Union[Dict,None] =None
  238. ):
  239. '''
  240. Submit work. Block if there is already enough work scheduled (~nthreads).
  241. :return: the same as get_handle_output_result
  242. '''
  243. handle_output_result = self.get_handle_output_result()
  244. if handle_output_result is not None:
  245. if self.submit_raise_exit:
  246. raise ThreadPoolExitException()
  247. if self.submit_skip_exit:
  248. return handle_output_result
  249. if work_function_input is None:
  250. work_function_input = {}
  251. self.in_queue.put(work_function_input)
  252. return handle_output_result
  253. if __name__ == '__main__':
  254. def get_work(min_, max_):
  255. '''
  256. Generate simple range work for work_function.
  257. '''
  258. for i in range(min_, max_):
  259. yield {'i': i}
  260. def work_function_maybe_raise(i):
  261. '''
  262. The main function that will be evaluated.
  263. It sleeps to simulate an IO operation.
  264. '''
  265. time.sleep((abs(i) % 4) / 10.0)
  266. return 10.0 / i
  267. def work_function_get_thread(i, thread_id):
  268. time.sleep((abs(i) % 4) / 10.0)
  269. return thread_id
  270. def handle_output_print(input, output, exception):
  271. '''
  272. Print outputs and exit immediately on failure.
  273. '''
  274. print('{!r} {!r} {!r}'.format(input, output, exception))
  275. return exception
  276. def handle_output_print_no_exit(input, output, exception):
  277. '''
  278. Print outputs, don't exit on failure.
  279. '''
  280. print('{!r} {!r} {!r}'.format(input, output, exception))
  281. out_queue = queue.Queue()
  282. def handle_output_queue(input, output, exception):
  283. '''
  284. Store outputs in a queue for later usage.
  285. '''
  286. global out_queue
  287. out_queue.put((input, output, exception))
  288. return exception
  289. def handle_output_raise(input, output, exception):
  290. '''
  291. Raise if input == 0, to test that execution
  292. stops nicely if this raises.
  293. '''
  294. print('{!r} {!r} {!r}'.format(input, output, exception))
  295. if input['i'] == 0:
  296. raise Exception
  297. def handle_output_raise_exit_exception(input, output, exception):
  298. '''
  299. Return a ThreadPoolExitException() if input == -5.
  300. Return the work_function exception if it raised.
  301. '''
  302. print('{!r} {!r} {!r}'.format(input, output, exception))
  303. if exception:
  304. return exception
  305. if output == 10.0 / -5:
  306. return ThreadPoolExitException()
  307. # CLI arguments.
  308. argv_len = len(sys.argv)
  309. if argv_len > 1:
  310. nthreads = int(sys.argv[1])
  311. if nthreads == 0:
  312. nthreads = None
  313. else:
  314. nthreads = None
  315. if argv_len > 2:
  316. min_ = int(sys.argv[2])
  317. else:
  318. min_ = 1
  319. if argv_len > 3:
  320. max_ = int(sys.argv[3])
  321. else:
  322. max_ = 100
  323. if argv_len > 4:
  324. handle_output_funtion_string = sys.argv[4]
  325. else:
  326. handle_output_funtion_string = 'handle_output_print'
  327. handle_output = eval(handle_output_funtion_string)
  328. if argv_len > 5:
  329. work_function = work_function_get_thread
  330. thread_id_arg = sys.argv[5]
  331. else:
  332. work_function = work_function_maybe_raise
  333. thread_id_arg = None
  334. # Action.
  335. if handle_output is handle_output_raise_exit_exception:
  336. # `with` version with implicit join and submit raise
  337. # immediately when desired with ThreadPoolExitException.
  338. #
  339. # This is the more safe and convenient and DRY usage if
  340. # you can use `with`, so prefer it generally.
  341. with ThreadPool(
  342. work_function,
  343. handle_output,
  344. nthreads,
  345. thread_id_arg,
  346. submit_raise_exit=True
  347. ) as my_thread_pool:
  348. for work in get_work(min_, max_):
  349. my_thread_pool.submit(work)
  350. handle_output_result = my_thread_pool.get_handle_output_result()
  351. else:
  352. # Explicit error checking in submit loop to exit immediately
  353. # on error.
  354. my_thread_pool = ThreadPool(
  355. work_function,
  356. handle_output,
  357. nthreads,
  358. thread_id_arg,
  359. )
  360. for work_function_input in get_work(min_, max_):
  361. handle_output_result = my_thread_pool.submit(work_function_input)
  362. if handle_output_result is not None:
  363. break
  364. handle_output_result = my_thread_pool.join()
  365. if handle_output_result is not None:
  366. work_function_input, work_function_return, exception = handle_output_result
  367. if type(exception) is ThreadPoolExitException:
  368. print('Early exit requested by handle_output with ThreadPoolExitException:')
  369. else:
  370. print('work_function or handle_output raised:')
  371. print(ThreadPool.exception_traceback_string(exception), end='')
  372. print('work_function_input: {!r}'.format(work_function_input))
  373. print('work_function_return: {!r}'.format(work_function_return))
  374. if handle_output == handle_output_queue:
  375. while not out_queue.empty():
  376. print(out_queue.get())