#pragma once

/**
 * @file BS_thread_pool_light.hpp
 * @author Barak Shoshany (baraksh@gmail.com) (http://baraksh.com)
 * @version 3.5.0
 * @date 2023-05-25
 * @copyright Copyright (c) 2023 Barak Shoshany. Licensed under the MIT license. If you found this project useful, please consider starring it on GitHub! If you use this library in software of any kind, please provide a link to the GitHub repository https://github.com/bshoshany/thread-pool in the source code and documentation. If you use this library in published research, please cite it as follows: Barak Shoshany, "A C++17 Thread Pool for High-Performance Scientific Computing", doi:10.5281/zenodo.4742687, arXiv:2105.00613 (May 2021)
 *
 * @brief BS::thread_pool_light: a fast, lightweight, and easy-to-use C++17 thread pool library. This header file contains a light version of the main library, for use when advanced features are not needed.
 */

#define BS_THREAD_POOL_LIGHT_VERSION "v3.5.0 (2023-05-25)"

#include <condition_variable> // std::condition_variable
#include <exception>          // std::current_exception
#include <functional>         // std::bind, std::function, std::invoke
#include <future>             // std::future, std::promise
#include <memory>             // std::make_shared, std::make_unique, std::shared_ptr, std::unique_ptr
#include <mutex>              // std::mutex, std::scoped_lock, std::unique_lock
#include <queue>              // std::queue
#include <thread>             // std::thread
#include <type_traits>        // std::common_type_t, std::decay_t, std::invoke_result_t, std::is_void_v
#include <utility>            // std::forward, std::move, std::swap

namespace BS
{
/**
 * @brief A convenient shorthand for the type of std::thread::hardware_concurrency(). Should evaluate to unsigned int.
 */
using concurrency_t = std::invoke_result_t<decltype(std::thread::hardware_concurrency)>;

/**
 * @brief A fast, lightweight, and easy-to-use C++17 thread pool class. This is a lighter version of the main thread pool class.
 */
class [[nodiscard]] thread_pool_light
{
public:
    // ============================
    // Constructors and destructors
    // ============================

    /**
     * @brief Construct a new thread pool.
     *
     * @param thread_count_ The number of threads to use. The default value is the total number of hardware threads available, as reported by the implementation. This is usually determined by the number of cores in the CPU. If a core is hyperthreaded, it will count as two threads.
     */
    thread_pool_light(const concurrency_t thread_count_ = 0) : thread_count(determine_thread_count(thread_count_)), threads(std::make_unique<std::thread[]>(determine_thread_count(thread_count_)))
    {
        create_threads();
    }
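    // A minimal construction sketch (illustrative, not part of the original header);
    // the variable names "pool" and "pool_4" are assumptions made for this example.
    //
    //     BS::thread_pool_light pool;        // as many threads as the hardware reports
    //     BS::thread_pool_light pool_4(4);   // exactly 4 threads
    //     BS::concurrency_t n = pool.get_thread_count();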
    /**
     * @brief Destruct the thread pool. Waits for all tasks to complete, then destroys all threads.
     */
    ~thread_pool_light()
    {
        wait_for_tasks();
        destroy_threads();
    }

    // =======================
    // Public member functions
    // =======================

    /**
     * @brief Get the number of threads in the pool.
     *
     * @return The number of threads.
     */
    [[nodiscard]] concurrency_t get_thread_count() const
    {
        return thread_count;
    }
    /**
     * @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue. The user must use wait_for_tasks() or some other method to ensure that the loop finishes executing, otherwise bad things will happen.
     *
     * @tparam F The type of the function to loop through.
     * @tparam T1 The type of the first index in the loop. Should be a signed or unsigned integer.
     * @tparam T2 The type of the index after the last index in the loop. Should be a signed or unsigned integer. If T1 is not the same as T2, a common type will be automatically inferred.
     * @tparam T The common type of T1 and T2.
     * @param first_index_ The first index in the loop.
     * @param index_after_last_ The index after the last index in the loop. The loop will iterate from first_index to (index_after_last - 1) inclusive. In other words, it will be equivalent to "for (T i = first_index; i < index_after_last; ++i)". Note that if index_after_last == first_index, no blocks will be submitted.
     * @param loop The function to loop through. Will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. loop(start, end) should typically involve a loop of the form "for (T i = start; i < end; ++i)".
     * @param num_blocks The maximum number of blocks to split the loop into. The default is to use the number of threads in the pool.
     */
    template <typename F, typename T1, typename T2, typename T = std::common_type_t<T1, T2>>
    void push_loop(T1 first_index_, T2 index_after_last_, F&& loop, size_t num_blocks = 0)
    {
        T first_index = static_cast<T>(first_index_);
        T index_after_last = static_cast<T>(index_after_last_);
        if (num_blocks == 0)
            num_blocks = thread_count;
        if (index_after_last < first_index)
            std::swap(index_after_last, first_index);
        size_t total_size = static_cast<size_t>(index_after_last - first_index);
        size_t block_size = static_cast<size_t>(total_size / num_blocks);
        if (block_size == 0)
        {
            block_size = 1;
            num_blocks = (total_size > 1) ? total_size : 1;
        }
        if (total_size > 0)
        {
            for (size_t i = 0; i < num_blocks; ++i)
                push_task(std::forward<F>(loop), static_cast<T>(i * block_size) + first_index, (i == num_blocks - 1) ? index_after_last : (static_cast<T>((i + 1) * block_size) + first_index));
        }
    }
    /**
     * @brief Parallelize a loop by automatically splitting it into blocks and submitting each block separately to the queue. The user must use wait_for_tasks() or some other method to ensure that the loop finishes executing, otherwise bad things will happen. This overload is used for the special case where the first index is 0.
     *
     * @tparam F The type of the function to loop through.
     * @tparam T The type of the loop indices. Should be a signed or unsigned integer.
     * @param index_after_last The index after the last index in the loop. The loop will iterate from 0 to (index_after_last - 1) inclusive. In other words, it will be equivalent to "for (T i = 0; i < index_after_last; ++i)". Note that if index_after_last == 0, no blocks will be submitted.
     * @param loop The function to loop through. Will be called once per block. Should take exactly two arguments: the first index in the block and the index after the last index in the block. loop(start, end) should typically involve a loop of the form "for (T i = start; i < end; ++i)".
     * @param num_blocks The maximum number of blocks to split the loop into. The default is to use the number of threads in the pool.
     */
    template <typename F, typename T>
    void push_loop(const T index_after_last, F&& loop, const size_t num_blocks = 0)
    {
        push_loop(0, index_after_last, std::forward<F>(loop), num_blocks);
    }
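    // A minimal usage sketch for push_loop() (illustrative, not part of the original
    // header); the vector "data" and its fill rule are assumptions made for this
    // example, and <vector> is assumed to be included. The results must not be read
    // before wait_for_tasks() (defined below) returns.
    //
    //     std::vector<double> data(100000);
    //     BS::thread_pool_light pool;
    //     pool.push_loop(data.size(),
    //         [&data](const size_t start, const size_t end)
    //         {
    //             for (size_t i = start; i < end; ++i)
    //                 data[i] = 0.5 * static_cast<double>(i);
    //         });
    //     pool.wait_for_tasks();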
    /**
     * @brief Push a function with zero or more arguments, but no return value, into the task queue. Does not return a future, so the user must use wait_for_tasks() or some other method to ensure that the task finishes executing, otherwise bad things will happen.
     *
     * @tparam F The type of the function.
     * @tparam A The types of the arguments.
     * @param task The function to push.
     * @param args The zero or more arguments to pass to the function. Note that if the task is a class member function, the first argument must be a pointer to the object, i.e. &object (or this), followed by the actual arguments.
     */
    template <typename F, typename... A>
    void push_task(F&& task, A&&... args)
    {
        {
            const std::scoped_lock tasks_lock(tasks_mutex);
            tasks.push(std::bind(std::forward<F>(task), std::forward<A>(args)...)); // cppcheck-suppress ignoredReturnValue
        }
        task_available_cv.notify_one();
    }
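    // A minimal usage sketch for push_task() (illustrative, not part of the original
    // header); the function "print_sum" is an assumption made for this example, and
    // <iostream> is assumed to be included. Since no future is returned, the caller
    // must wait explicitly.
    //
    //     void print_sum(const int a, const int b)
    //     {
    //         std::cout << (a + b) << '\n';
    //     }
    //
    //     BS::thread_pool_light pool;
    //     pool.push_task(print_sum, 2, 3);     // free function with arguments
    //     pool.push_task([] { /* work */ });   // lambda with no arguments
    //     pool.wait_for_tasks();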
    /**
     * @brief Submit a function with zero or more arguments into the task queue. If the function has a return value, get a future for the eventual returned value. If the function has no return value, get an std::future<void> which can be used to wait until the task finishes.
     *
     * @tparam F The type of the function.
     * @tparam A The types of the zero or more arguments to pass to the function.
     * @tparam R The return type of the function (can be void).
     * @param task The function to submit.
     * @param args The zero or more arguments to pass to the function. Note that if the task is a class member function, the first argument must be a pointer to the object, i.e. &object (or this), followed by the actual arguments.
     * @return A future to be used later to wait for the function to finish executing and/or obtain its returned value if it has one.
     */
    template <typename F, typename... A, typename R = std::invoke_result_t<std::decay_t<F>, std::decay_t<A>...>>
    [[nodiscard]] std::future<R> submit(F&& task, A&&... args)
    {
        std::shared_ptr<std::promise<R>> task_promise = std::make_shared<std::promise<R>>();
        push_task(
            [task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...), task_promise]
            {
                try
                {
                    if constexpr (std::is_void_v<R>)
                    {
                        std::invoke(task_function);
                        task_promise->set_value();
                    }
                    else
                    {
                        task_promise->set_value(std::invoke(task_function));
                    }
                }
                catch (...)
                {
                    try
                    {
                        task_promise->set_exception(std::current_exception());
                    }
                    catch (...)
                    {
                    }
                }
            });
        return task_promise->get_future();
    }
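    // A minimal usage sketch for submit() (illustrative, not part of the original
    // header); the lambdas and variable names are assumptions made for this example.
    // The returned std::future waits for the task and yields its result, or rethrows
    // any exception thrown by the task.
    //
    //     BS::thread_pool_light pool;
    //     std::future<int> square = pool.submit([](const int x) { return x * x; }, 7);
    //     const int result = square.get(); // blocks until the task finishes; result == 49
    //
    //     std::future<void> done = pool.submit([] { /* side effects only */ });
    //     done.wait(); // wait without retrieving a value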
    /**
     * @brief Wait for tasks to be completed, both those that are currently running in the threads and those that are still waiting in the queue. Note: To wait for just one specific task, use submit() instead, and call the wait() member function of the generated future.
     */
    void wait_for_tasks()
    {
        std::unique_lock tasks_lock(tasks_mutex);
        waiting = true;
        tasks_done_cv.wait(tasks_lock, [this] { return !tasks_running && tasks.empty(); });
        waiting = false;
    }

private:
    // ========================
    // Private member functions
    // ========================

    /**
     * @brief Create the threads in the pool and assign a worker to each thread.
     */
    void create_threads()
    {
        {
            const std::scoped_lock tasks_lock(tasks_mutex);
            workers_running = true;
        }
        for (concurrency_t i = 0; i < thread_count; ++i)
        {
            threads[i] = std::thread(&thread_pool_light::worker, this);
        }
    }

    /**
     * @brief Destroy the threads in the pool.
     */
    void destroy_threads()
    {
        {
            const std::scoped_lock tasks_lock(tasks_mutex);
            workers_running = false;
        }
        task_available_cv.notify_all();
        for (concurrency_t i = 0; i < thread_count; ++i)
        {
            threads[i].join();
        }
    }

    /**
     * @brief Determine how many threads the pool should have, based on the parameter passed to the constructor.
     *
     * @param thread_count_ The parameter passed to the constructor. If the parameter is a positive number, then the pool will be created with this number of threads. If the parameter is non-positive, or a parameter was not supplied (in which case it will have the default value of 0), then the pool will be created with the total number of hardware threads available, as obtained from std::thread::hardware_concurrency(). If the latter returns a non-positive number for some reason, then the pool will be created with just one thread.
     * @return The number of threads to use for constructing the pool.
     */
    [[nodiscard]] concurrency_t determine_thread_count(const concurrency_t thread_count_) const
    {
        if (thread_count_ > 0)
            return thread_count_;
        else
        {
            if (std::thread::hardware_concurrency() > 0)
                return std::thread::hardware_concurrency();
            else
                return 1;
        }
    }

    /**
     * @brief A worker function to be assigned to each thread in the pool. Waits until it is notified by push_task() that a task is available, and then retrieves the task from the queue and executes it. Once the task finishes, the worker notifies wait_for_tasks() in case it is waiting.
     */
    void worker()
    {
        std::function<void()> task;
        while (true)
        {
            std::unique_lock tasks_lock(tasks_mutex);
            task_available_cv.wait(tasks_lock, [this] { return !tasks.empty() || !workers_running; });
            if (!workers_running)
                break;
            task = std::move(tasks.front());
            tasks.pop();
            ++tasks_running;
            tasks_lock.unlock();
            task();
            tasks_lock.lock();
            --tasks_running;
            if (waiting && !tasks_running && tasks.empty())
                tasks_done_cv.notify_all();
        }
    }

    // ============
    // Private data
    // ============

    /**
     * @brief A condition variable to notify worker() that a new task has become available.
     */
    std::condition_variable task_available_cv = {};

    /**
     * @brief A condition variable to notify wait_for_tasks() that the tasks are done.
     */
    std::condition_variable tasks_done_cv = {};

    /**
     * @brief A queue of tasks to be executed by the threads.
     */
    std::queue<std::function<void()>> tasks = {};

    /**
     * @brief A counter for the total number of currently running tasks.
     */
    size_t tasks_running = 0;

    /**
     * @brief A mutex to synchronize access to the task queue by different threads.
     */
    mutable std::mutex tasks_mutex = {};

    /**
     * @brief The number of threads in the pool.
     */
    concurrency_t thread_count = 0;

    /**
     * @brief A smart pointer to manage the memory allocated for the threads.
     */
    std::unique_ptr<std::thread[]> threads = nullptr;

    /**
     * @brief A flag indicating that wait_for_tasks() is active and expects to be notified whenever a task is done.
     */
    bool waiting = false;

    /**
     * @brief A flag indicating to the workers to keep running. When set to false, the workers terminate permanently.
     */
    bool workers_running = false;
};
} // namespace BS