ThreadPool.h
/* Distributed under the Apache License, Version 2.0.
   See accompanying NOTICE file for details.*/

#pragma once

#include "cdm/CommonDefs.h"

#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <queue>
#include <atomic>
#include <type_traits>
#include <vector>

template<typename T>
class MultiThreadedVectorProcessor
{
public:
  MultiThreadedVectorProcessor(const std::vector<T*>& v) : m_Vector(v)
  {
    m_NextIdx.store(0);
    m_NumComplete.store(0);
    m_Stop = false;
  }
  virtual ~MultiThreadedVectorProcessor()
  {
    Stop();
  }

  void Start(size_t numThreads)
  {
    Stop();
    // Park the workers: no element is handed out until ProcessVectorContents resets the index
    m_NextIdx.store(m_Vector.size());
    m_NumComplete.store(0);
    for (size_t i = 0; i < numThreads; i++)
      m_Threads.push_back(std::thread(&MultiThreadedVectorProcessor::Run, this));
  }

  void Stop()
  {
    m_Stop = true;
    for (std::thread& t : m_Threads)
      t.join();
    m_Threads.clear();
    m_Stop = false;
  }

  void ProcessVectorContents()
  {
    m_NumComplete.store(0);
    m_NextIdx.store(0);
    // Wait for all items in vector to be processed
    while (m_NumComplete.load(std::memory_order_consume) < m_Vector.size());
  }

  void Run()
  {
    while (!m_Stop)
    {
      if (m_NextIdx.load(std::memory_order_consume) >= m_Vector.size())
      {
        continue;
      }
      T* m = nullptr;
      m_Mutex.lock();
      size_t idx = m_NextIdx.load();
      if (idx < m_Vector.size())
      {
        m = m_Vector[idx];
        m_NextIdx++;
      }
      m_Mutex.unlock();
      if (m != nullptr)
      {
        Work(m);
        m_NumComplete++;
      }
    }
    m_NumComplete.store(m_Vector.size());
  }

  virtual void Work(T*) = 0;
protected:

  std::atomic<size_t> m_NextIdx;
  std::atomic<size_t> m_NumComplete;
  const std::vector<T*>& m_Vector;

  bool m_Stop;
  std::mutex m_Mutex;
  std::vector<std::thread> m_Threads;
};
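
/*
*** Example usage (a sketch, not part of the original header): derive from
*** MultiThreadedVectorProcessor, implement Work(), Start() the workers once,
*** then call ProcessVectorContents() for each pass over the vector.
*** "Model" and "Advance" are hypothetical names used only for illustration.

struct Model { void Advance() {} };

class ModelProcessor : public MultiThreadedVectorProcessor<Model>
{
public:
  ModelProcessor(const std::vector<Model*>& models) : MultiThreadedVectorProcessor<Model>(models) {}
  void Work(Model* m) override { m->Advance(); } // called concurrently, one element per call
};

void ModelProcessorExample(std::vector<Model*>& models)
{
  ModelProcessor p(models);
  p.Start(4);                // e.g. four worker threads
  p.ProcessVectorContents(); // blocks until every element has been handed to Work()
  p.Stop();
}
*/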

/*
*** Alternative implementation: assigns index ranges to threads instead of first-come-first-serve
*** and keeps the threads alive as long as possible.
*** Adopting it would require passing numThreads to the constructor; Start() would no longer be called.
template<typename T>
class MultiThreadedVectorProcessor
{
public:
  MultiThreadedVectorProcessor(const std::vector<T*>& v, size_t numThreads = 1) : m_Vector(v)
  {
    m_NumThreads = numThreads;
    m_Stop = false;
    SplitVector();

    StartThreads();
  }
  virtual ~MultiThreadedVectorProcessor()
  {
    Stop();
  }

  void Stop()
  {
    m_Stop = true;
    for(std::thread& t : m_Threads)
      t.join();
    m_Threads.clear();
    m_Stop = false;
  }

  void SetNumberOfThreads(size_t numThreads)
  {
    if(m_NumThreads != numThreads)
    {
      Stop();

      m_NumThreads = numThreads;
      SplitVector();
      StartThreads();
    }
  }

  void ProcessVectorContents()
  {
    // Ensure the proper number of threads are running
    if(m_Threads.size() != m_NumThreads)
    {
      size_t nThreads = m_NumThreads;
      m_NumThreads = std::numeric_limits<std::size_t>::max();
      SetNumberOfThreads(nThreads);
    }

    // Elements have been added/removed from vector
    if(m_Vector.size() != m_VectorSize)
      SplitVector();

    // Tell threads to process
    m_NumComplete = 0;
    for( auto p : m_Process )
    {
      p = true;
    }

    // Wait for all threads to process
    while (m_NumComplete < m_Vector.size())
      std::this_thread::sleep_for(std::chrono::nanoseconds(1));
  }

  void Run(size_t id)
  {
    while(true)
    {
      if(m_Stop)
        return;

      if(m_Process[id])
      {
        auto& indices = m_IndexAssignments[id];
        for(size_t idx = indices.first; idx < indices.second; ++idx)
        {
          T* m = m_Vector[idx];
          if( m != nullptr )
          {
            Work(m);
          }
        }
        m_Mutex.lock();
        m_Process[id] = false;
        m_NumComplete += indices.second - indices.first;
        m_Mutex.unlock();
      }
      else
      {
        std::this_thread::sleep_for(std::chrono::nanoseconds(1));
      }
    }
  }

  virtual void Work(T*) = 0;
protected:

  const std::vector<T*>& m_Vector;
  size_t m_VectorSize;
  std::vector<std::pair<size_t, size_t>> m_IndexAssignments;

  size_t m_NumThreads;
  std::vector<std::thread> m_Threads;

  bool m_Stop;
  size_t m_NumComplete;
  std::vector<bool> m_Process;
  std::mutex m_Mutex;

  void StartThreads()
  {
    m_Process = std::vector<bool>(m_NumThreads, false);
    for(size_t i = 0; i < m_NumThreads; ++i)
    {
      m_Threads.push_back(std::thread(&MultiThreadedVectorProcessor::Run, this, i));
    }
  }

  void SplitVector()
  {
    m_IndexAssignments.clear();
    m_VectorSize = m_Vector.size();

    size_t length = m_Vector.size() / m_NumThreads;
    size_t remain = m_Vector.size() % m_NumThreads;

    size_t begin = 0;
    size_t end = 0;

    for (size_t i = 0; i < std::min(m_NumThreads, m_Vector.size()); ++i)
    {
      end += (remain > 0) ? (length + !!(remain--)) : length;

      m_IndexAssignments.push_back(std::make_pair(begin, end));

      begin = end;
    }
  }
};
*/

/*
*** Alternative MultiThreadedVectorProcessor implementation using std::future:
*** assigns index ranges to threads instead of first-come-first-serve and
*** launches std::async tasks per pass instead of keeping full threads alive.
*** Adopting it would require passing numThreads to the constructor; Start() would no longer be called.
template<typename T>
class MultiThreadedVectorProcessor
{
public:
  MultiThreadedVectorProcessor(const std::vector<T*>& v, size_t numThreads = 1) : m_Vector(v)
  {
    m_NumThreads = numThreads;
  }
  virtual ~MultiThreadedVectorProcessor()
  {
  }

  void SetNumberOfThreads(size_t numThreads)
  {
    m_NumThreads = numThreads;
  }

  void ProcessVectorContents()
  {
    const std::vector<std::pair<size_t, size_t>> threadSplit = SplitVector(m_Vector, m_NumThreads);

    std::vector<std::future<void>> futures;
    for(auto indices: threadSplit)
      futures.emplace_back(std::async(std::launch::async, [=](){ Run(indices); }));

    for(auto& future: futures)
      future.get();
  }

  void Run(std::pair<size_t, size_t> indices)
  {
    for(size_t idx = indices.first; idx < indices.second; ++idx)
    {
      T* m = m_Vector[idx];
      if( m != nullptr )
      {
        Work(m);
      }
    }
  }

  virtual void Work(T*) = 0;
protected:

  const std::vector<T*>& m_Vector;
  size_t m_NumThreads;

  std::vector<std::pair<size_t,size_t>> SplitVector(const std::vector<T*>& vec, size_t n)
  {
    std::vector<std::pair<size_t, size_t>> outVec;

    size_t length = vec.size() / n;
    size_t remain = vec.size() % n;

    size_t begin = 0;
    size_t end = 0;

    for (size_t i = 0; i < std::min(n, vec.size()); ++i)
    {
      end += (remain > 0) ? (length + !!(remain--)) : length;

      outVec.push_back(std::make_pair(begin, end));

      begin = end;
    }

    return outVec;
  }
};
*/

/*
*** Alternative MultiThreadedVectorProcessor implementation using std::condition_variable:
*** assigns index ranges to threads instead of first-come-first-serve and
*** keeps the threads alive as long as possible.
*** Adopting it would require passing numThreads to the constructor; Start() would no longer be called.
template<typename T>
class MultiThreadedVectorProcessor
{
public:
  MultiThreadedVectorProcessor(const std::vector<T*>& v, size_t numThreads = 1) : m_Vector(v)
  {
    m_NumThreads = numThreads;
    m_Stop = false;

    // Assign vector indices to threads and start up threads
    SplitVector();
    StartThreads();
  }
  virtual ~MultiThreadedVectorProcessor()
  {
    Stop();
  }

  void Stop()
  {
    m_Stop = true;
    for(size_t i = 0; i < m_NumThreads; ++i)
    {
      m_CondVar[i].notify_all();
      m_Threads[i].join();
    }
    m_Threads.clear();
    m_Stop = false;
  }

  void SetNumberOfThreads(size_t numThreads)
  {
    if(m_NumThreads != numThreads)
    {
      Stop();

      m_NumThreads = numThreads;
      SplitVector();
      StartThreads();
    }
  }

  void ProcessVectorContents()
  {
    // Ensure the proper number of threads are running
    if(m_Threads.size() != m_NumThreads)
    {
      size_t nThreads = m_NumThreads;
      m_NumThreads = std::numeric_limits<std::size_t>::max();
      SetNumberOfThreads(nThreads);
    }

    // Elements have been added/removed from vector
    if(m_Vector.size() != m_VectorSize)
      SplitVector();

    // Tell threads to process
    for(size_t i = 0; i < m_Process.size(); ++i)
    {
      std::lock_guard<std::mutex> guard(m_Mutex[i]);
      m_Process[i] = true;
      m_CondVar[i].notify_all();
    }

    // Wait for all threads to process
    for(size_t i = 0; i < m_Process.size(); ++i)
    {
      std::unique_lock<std::mutex> L{m_Mutex[i]};
      m_CondVar[i].wait(L, [&]()
      {
        // Acquire lock only if we've processed
        return !m_Process[i];
      });
    }
  }

  void Run(size_t id)
  {
    while(true)
    {
      // Wait until told to process or told to stop
      std::unique_lock<std::mutex> L{m_Mutex[id]};
      m_CondVar[id].wait(L, [&]()
      {
        // Acquire lock only if we've stopped or have something to process
        return m_Stop || m_Process[id];
      });

      if(m_Stop)
        return;

      if(m_Process[id])
      {
        auto& indices = m_IndexAssignments[id];
        for(size_t idx = indices.first; idx < indices.second; ++idx)
        {
          T* m = m_Vector[idx];
          if( m != nullptr )
          {
            Work(m);
          }
        }
        m_Process[id] = false;
        m_CondVar[id].notify_all();
      }
    }
  }

  virtual void Work(T*) = 0;
protected:

  const std::vector<T*>& m_Vector;
  size_t m_VectorSize;
  std::vector<std::pair<size_t, size_t>> m_IndexAssignments;

  size_t m_NumThreads;
  std::vector<std::thread> m_Threads;

  std::vector<bool> m_Process;
  std::vector<std::mutex> m_Mutex;
  std::vector<std::condition_variable> m_CondVar;
  bool m_Stop;

  void StartThreads()
  {
    m_Process = std::vector<bool>(m_NumThreads, false);
    m_Mutex = std::vector<std::mutex>(m_NumThreads);
    m_CondVar = std::vector<std::condition_variable>(m_NumThreads);
    for(size_t i = 0; i < m_NumThreads; ++i)
    {
      m_Threads.push_back(std::thread(&MultiThreadedVectorProcessor::Run, this, i));
    }
  }

  void SplitVector()
  {
    m_IndexAssignments.clear();
    m_VectorSize = m_Vector.size();

    size_t length = m_Vector.size() / m_NumThreads;
    size_t remain = m_Vector.size() % m_NumThreads;

    size_t begin = 0;
    size_t end = 0;

    for (size_t i = 0; i < std::min(m_NumThreads, m_Vector.size()); ++i)
    {
      end += (remain > 0) ? (length + !!(remain--)) : length;

      m_IndexAssignments.push_back(std::make_pair(begin, end));

      begin = end;
    }
  }
};
*/
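
/*
*** Worked illustration (not part of the original header) of the index-range split used by the
*** alternatives above: the first (size % n) ranges get one extra element, e.g. 10 elements over
*** 3 threads -> [0,4), [4,7), [7,10).

inline std::vector<std::pair<size_t, size_t>> SplitRangesExample(size_t size, size_t n)
{
  std::vector<std::pair<size_t, size_t>> ranges;
  size_t length = size / n;  // base elements per range
  size_t remain = size % n;  // leftover elements spread across the first ranges
  size_t begin = 0;
  for (size_t i = 0; i < std::min(n, size); ++i)
  {
    size_t end = begin + length + (remain > 0 ? !!(remain--) : 0);
    ranges.push_back(std::make_pair(begin, end));
    begin = end;
  }
  return ranges;
}
*/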


// Original code https://github.com/progschj/ThreadPool
class ThreadPool
{
public:
  ThreadPool(size_t = 0);
  template<class F, class... Args>
  auto enqueue(F&& f, Args&&... args)
    -> std::future<typename std::invoke_result<F, Args...>::type>;
  size_t workerCount() const { return workers.size(); }
  ~ThreadPool();
private:
  // need to keep track of threads so we can join them
  std::vector< std::thread > workers;
  // the task queue
  std::queue< std::function<void()> > tasks;

  // synchronization
  std::mutex queue_mutex;
  std::condition_variable condition;
  bool stop;
};

// the constructor launches the requested number of workers
// (or one per hardware thread when 0 is passed)
inline ThreadPool::ThreadPool(size_t threads)
  : stop(false)
{
  if (threads == 0)
    threads = std::thread::hardware_concurrency();
  for (size_t i = 0; i < threads; ++i)
    workers.emplace_back(
      [this]
      {
        for (;;)
        {
          std::function<void()> task;

          {
            std::unique_lock<std::mutex> lock(this->queue_mutex);
            this->condition.wait(lock,
              [this] { return this->stop || !this->tasks.empty(); });
            if (this->stop && this->tasks.empty())
              return;
            task = std::move(this->tasks.front());
            this->tasks.pop();
          }

          task();
        }
      }
    );
}

// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
  -> std::future<typename std::invoke_result<F, Args...>::type>
{
  using return_type = typename std::invoke_result<F, Args...>::type;

  auto task = std::make_shared< std::packaged_task<return_type()> >(
    std::bind(std::forward<F>(f), std::forward<Args>(args)...)
  );

  std::future<return_type> res = task->get_future();
  {
    std::unique_lock<std::mutex> lock(queue_mutex);

    // don't allow enqueueing after stopping the pool
    if (stop)
      throw std::runtime_error("enqueue on stopped ThreadPool");

    tasks.emplace([task]() { (*task)(); });
  }
  condition.notify_one();
  return res;
}

// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
  {
    std::unique_lock<std::mutex> lock(queue_mutex);
    stop = true;
  }
  condition.notify_all();
  for (std::thread& worker : workers)
    worker.join();
}
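
/*
*** Example usage (a sketch, not part of the original header; assumes <iostream> is included):
*** enqueue() returns a std::future per task, so results can be collected after dispatch.

#include <iostream>

void ThreadPoolExample()
{
  ThreadPool pool; // 0 workers requested -> defaults to std::thread::hardware_concurrency()

  std::vector<std::future<int>> results;
  for (int i = 0; i < 8; ++i)
    results.push_back(pool.enqueue([](int x) { return x * x; }, i));

  for (auto& r : results)
    std::cout << r.get() << ' '; // get() blocks until the corresponding task has run
  std::cout << std::endl;
} // ~ThreadPool() joins all workers here
*/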