// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2018 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: vitus@google.com (Michael Vitus)

// This include must come before any #ifndef check on Ceres compile options.
#include "ceres/internal/port.h"

#ifdef CERES_USE_CXX_THREADS

#include <cmath>
#include <condition_variable>
#include <memory>
#include <mutex>

#include "ceres/concurrent_queue.h"
#include "ceres/parallel_for.h"
#include "ceres/scoped_thread_token.h"
#include "ceres/thread_token_provider.h"
#include "glog/logging.h"

namespace ceres {
namespace internal {
namespace {
// This class creates a thread-safe barrier which will block until a
// pre-specified number of threads call Finished. This allows us to block the
// main thread until all the parallel threads are finished processing all the
// work.
class BlockUntilFinished {
 public:
  explicit BlockUntilFinished(int num_total)
      : num_finished_(0), num_total_(num_total) {}

  // Increment the number of jobs that have finished and signal the blocking
  // thread if all jobs have finished.
  void Finished() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++num_finished_;
    CHECK_LE(num_finished_, num_total_);
    if (num_finished_ == num_total_) {
      condition_.notify_one();
    }
  }

  // Block until all threads have signaled they are finished.
  void Block() {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [&]() { return num_finished_ == num_total_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  // The current number of jobs finished.
  int num_finished_;
  // The total number of jobs.
  int num_total_;
};
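
// Illustrative usage sketch (not part of the build; names such as num_jobs
// are placeholders): the coordinating thread constructs the barrier with the
// total job count, every worker calls Finished() exactly once, and Block()
// returns only after the last Finished() call.
//
//   BlockUntilFinished barrier(num_jobs);
//   // ... each of the num_jobs workers eventually calls barrier.Finished();
//   barrier.Block();  // Returns once Finished() has been called num_jobs times.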

// Shared state between the parallel tasks. Each thread will use this
// information to get the next block of work to be performed.
struct SharedState {
  SharedState(int start, int end, int num_work_items)
      : start(start),
        end(end),
        num_work_items(num_work_items),
        i(0),
        thread_token_provider(num_work_items),
        block_until_finished(num_work_items) {}

  // The start and end index of the for loop.
  const int start;
  const int end;
  // The number of blocks that need to be processed.
  const int num_work_items;

  // The next block of work to be assigned to a worker. The parallel for loop
  // range is split into num_work_items blocks of work, i.e. a single block of
  // work is:
  //  for (int j = start + i; j < end; j += num_work_items) { ... }.
  int i;
  std::mutex mutex_i;

  // Provides a unique thread ID among all active threads working on the same
  // group of tasks. Thread-safe.
  ThreadTokenProvider thread_token_provider;

  // Used to signal when all the work has been completed. Thread safe.
  BlockUntilFinished block_until_finished;
};

}  // namespace

int MaxNumThreadsAvailable() { return ThreadPool::MaxNumThreadsAvailable(); }

// See ParallelFor (below) for more details.
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != nullptr);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    for (int i = start; i < end; ++i) {
      function(i);
    }
    return;
  }

  ParallelFor(
      context, start, end, num_threads, [&function](int /*thread_id*/, int i) {
        function(i);
      });
}
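
// Illustrative usage of the single-argument overload above (a sketch, not
// library code; assumes a valid ContextImpl* named context is available):
//
//   std::vector<double> squares(100);
//   ParallelFor(context, 0, 100, /*num_threads=*/4,
//               [&squares](int i) { squares[i] = static_cast<double>(i) * i; });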

// This implementation uses a fixed size max worker pool with a shared task
// queue. The problem of executing the function for the interval of [start, end)
// is broken up into at most num_threads blocks and added to the thread pool. To
// avoid deadlocks, the calling thread is allowed to steal work from the worker
// pool. This is implemented via a shared state between the tasks. In order for
// the calling thread or thread pool to get a block of work, it will query the
// shared state for the next block of work to be done. If there is nothing left,
// it will return. We will exit the ParallelFor call when all of the work has
// been done, not when all of the tasks have been popped off the task queue.
//
// A unique thread ID among all active tasks will be acquired once for each
// block of work. This avoids the significant performance penalty for acquiring
// it on every iteration of the for loop. The thread ID is guaranteed to be in
// [0, num_threads).
//
// A performance analysis has shown this implementation is on par with OpenMP
// and TBB.
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int thread_id, int i)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != nullptr);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    // Even though we only have one thread, use the thread token provider to
    // guarantee the exact same behavior when running with multiple threads.
    ThreadTokenProvider thread_token_provider(num_threads);
    const ScopedThreadToken scoped_thread_token(&thread_token_provider);
    const int thread_id = scoped_thread_token.token();
    for (int i = start; i < end; ++i) {
      function(thread_id, i);
    }
    return;
  }

  // We use a std::shared_ptr because the main thread can finish all
  // the work before the tasks have been popped off the queue. So the
  // shared state needs to exist for the duration of all the tasks.
  const int num_work_items = std::min((end - start), num_threads);
  std::shared_ptr<SharedState> shared_state(
      new SharedState(start, end, num_work_items));

  // A function which tries to perform a chunk of work. This returns false if
  // there is no work to be done.
  auto task_function = [shared_state, &function]() {
    int i = 0;
    {
      // Get the next available chunk of work to be performed. If there is no
      // work, return false.
      std::lock_guard<std::mutex> lock(shared_state->mutex_i);
      if (shared_state->i >= shared_state->num_work_items) {
        return false;
      }
      i = shared_state->i;
      ++shared_state->i;
    }

    const ScopedThreadToken scoped_thread_token(
        &shared_state->thread_token_provider);
    const int thread_id = scoped_thread_token.token();

    // Perform each task.
    for (int j = shared_state->start + i; j < shared_state->end;
         j += shared_state->num_work_items) {
      function(thread_id, j);
    }
    shared_state->block_until_finished.Finished();
    return true;
  };

  // Add all the tasks to the thread pool.
  for (int i = 0; i < num_work_items; ++i) {
    // Note we are taking the task_function as value so the shared_state
    // shared pointer is copied and the ref count is increased. This is to
    // prevent it from being deleted when the main thread finishes all the
    // work and exits before the threads finish.
    context->thread_pool.AddTask([task_function]() { task_function(); });
  }

  // Try to do any available work on the main thread. This may steal work from
  // the thread pool, but when there is no work left the thread pool tasks
  // will be no-ops.
  while (task_function()) {
  }

  // Wait until all tasks have finished.
  shared_state->block_until_finished.Block();
}
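
// Illustrative usage of the thread_id overload above (a sketch, not library
// code; assumes a valid ContextImpl* named context and a std::vector<double>
// named values). Since the thread ID is guaranteed to be in [0, num_threads),
// it can index per-thread scratch space so workers never contend on a shared
// accumulator:
//
//   const int num_threads = 4;
//   std::vector<double> partial_sums(num_threads, 0.0);
//   ParallelFor(context, 0, static_cast<int>(values.size()), num_threads,
//               [&](int thread_id, int i) {
//                 partial_sums[thread_id] += values[i];
//               });
//   const double sum =
//       std::accumulate(partial_sums.begin(), partial_sums.end(), 0.0);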

}  // namespace internal
}  // namespace ceres

#endif  // CERES_USE_CXX_THREADS