// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Authors: vitus@google.com (Michael Vitus),
//          dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)

#ifndef CERES_INTERNAL_PARALLEL_INVOKE_H_
#define CERES_INTERNAL_PARALLEL_INVOKE_H_

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <tuple>
#include <type_traits>

#include "ceres/context_impl.h"
#include "glog/logging.h"

namespace ceres::internal {

// InvokeWithThreadId handles passing thread_id to the function.
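//
// For example (illustrative sketch), given
//   auto f = [](int thread_id, int i) { /* uses thread_id */ };
//   auto g = [](int i) { /* does not take a thread id */ };
// InvokeWithThreadId(tid, f, i) calls f(tid, i), while
// InvokeWithThreadId(tid, g, i) calls g(i); the dispatch is resolved at
// compile time via std::is_invocable_v.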
template <typename F, typename... Args>
void InvokeWithThreadId(int thread_id, F&& function, Args&&... args) {
  constexpr bool kPassThreadId = std::is_invocable_v<F, int, Args...>;

  if constexpr (kPassThreadId) {
    function(thread_id, std::forward<Args>(args)...);
  } else {
    function(std::forward<Args>(args)...);
  }
}

// InvokeOnSegment either loops over the indices of the segment, invoking the
// function once per index, or passes the whole segment (as a tuple) to the
// function.
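//
// For example (sketch), with range = {lo, hi}:
//   [](int i) { ... }                       is invoked once per index in [lo, hi)
//   [](int thread_id, int i) { ... }        likewise, and also receives thread_id
//   [](std::tuple<int, int> range) { ... }  receives the whole range once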
template <typename F>
void InvokeOnSegment(int thread_id, std::tuple<int, int> range, F&& function) {
  constexpr bool kExplicitLoop =
      std::is_invocable_v<F, int> || std::is_invocable_v<F, int, int>;

  if constexpr (kExplicitLoop) {
    const auto [start, end] = range;
    for (int i = start; i != end; ++i) {
      InvokeWithThreadId(thread_id, std::forward<F>(function), i);
    }
  } else {
    InvokeWithThreadId(thread_id, std::forward<F>(function), range);
  }
}

// This class creates a thread-safe barrier which will block until a
// pre-specified number of jobs have been reported finished via Finished().
// This allows us to block the main thread until all the parallel threads have
// finished processing all the work.
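//
// Usage sketch:
//   BlockUntilFinished barrier(num_total_jobs);
//   // Each worker, after processing some jobs:
//   barrier.Finished(num_jobs_it_processed);
//   // Main thread:
//   barrier.Block();  // returns once num_total_jobs jobs have been reported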
class BlockUntilFinished {
 public:
  explicit BlockUntilFinished(int num_total_jobs);

  // Increment the count of finished jobs by the number of jobs processed by
  // the caller, and signal the blocking thread if all jobs have finished.
  void Finished(int num_jobs_finished);

  // Block until receiving confirmation of all jobs being finished.
  void Block();

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  int num_total_jobs_finished_;
  const int num_total_jobs_;
};

// Shared state between the parallel tasks. Each thread will use this
// information to get the next block of work to be performed.
struct ParallelInvokeState {
  // The entire range [start, end) is split into num_work_blocks contiguous
  // disjoint intervals (blocks), which are as equal in size as possible given
  // the total index count and the requested number of blocks.
  //
  // Those num_work_blocks blocks are then processed in parallel.
  //
  // The total number of integer indices in the interval [start, end) is
  // end - start, and when splitting them into num_work_blocks blocks we can
  // either
  //  - Split into equal blocks when (end - start) is divisible by
  //    num_work_blocks
  //  - Split into blocks whose sizes differ by at most 1:
  //    - The size of the smallest block(s) is (end - start) / num_work_blocks
  //    - (end - start) % num_work_blocks of the blocks are 1 index larger
  //
  // Note that this splitting is optimal in the sense that it minimizes the
  // maximal difference between block sizes, since splitting into equal blocks
  // is possible if and only if the number of indices is divisible by the
  // number of blocks.
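  //
  // Worked example (illustrative): splitting [0, 10) into num_work_blocks = 3
  // gives base_block_size = 10 / 3 = 3 and num_base_p1_sized_blocks =
  // 10 % 3 = 1, i.e. the blocks [0, 4), [4, 7), [7, 10).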
  ParallelInvokeState(int start, int end, int num_work_blocks);

  // The start and end index of the for loop.
  const int start;
  const int end;
  // The number of blocks that need to be processed.
  const int num_work_blocks;
  // Size of the smallest block.
  const int base_block_size;
  // Number of blocks of size base_block_size + 1.
  const int num_base_p1_sized_blocks;

  // The next block of work to be assigned to a worker. The parallel for loop
  // range is split into num_work_blocks blocks of work, with a single block
  // of work being of size
  //  - base_block_size + 1 for the first num_base_p1_sized_blocks blocks
  //  - base_block_size for the rest of the blocks
  // Blocks of indices are contiguous and disjoint.
  std::atomic<int> block_id;

  // Provides a unique thread id among all active threads. We do not schedule
  // more than num_threads threads via the thread pool, and the caller thread
  // might steal one id.
  std::atomic<int> thread_id;

  // Used to signal when all the work has been completed. Thread safe.
  BlockUntilFinished block_until_finished;
};

// This implementation uses a fixed-size worker pool with a shared task queue.
// The problem of executing the function over the interval [start, end) is
// broken up into at most num_threads * kWorkBlocksPerThread blocks (each of
// size at least min_block_size) and added to the thread pool. To avoid
// deadlocks, the calling thread is allowed to steal work from the worker
// pool. This is implemented via a shared state between the tasks. In order
// for the calling thread or thread pool to get a block of work, it will query
// the shared state for the next block of work to be done. If there is nothing
// left, it will return. We will exit the ParallelFor call when all of the
// work has been done, not when all of the tasks have been popped off the task
// queue.
//
// A unique thread ID among all active tasks will be acquired once for each
// block of work. This avoids the significant performance penalty of acquiring
// it on every iteration of the for loop. The thread ID is guaranteed to be in
// [0, num_threads).
//
// A performance analysis has shown this implementation is on par with OpenMP
// and TBB.
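//
// Usage sketch (illustrative; Compute() is a placeholder and context must be
// a valid ContextImpl*):
//   std::vector<double> x(n);
//   ParallelInvoke(
//       context, 0, n, num_threads,
//       [&x](int thread_id, int i) { x[i] = Compute(thread_id, i); },
//       /*min_block_size=*/1);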
template <typename F>
void ParallelInvoke(ContextImpl* context,
                    int start,
                    int end,
                    int num_threads,
                    F&& function,
                    int min_block_size) {
  CHECK(context != nullptr);

  // Maximal number of work blocks scheduled for a single thread:
  //  - a lower number of work blocks results in larger runtimes on unequal
  //    tasks
  //  - a higher number of work blocks results in larger synchronization
  //    overhead
  constexpr int kWorkBlocksPerThread = 4;

  // The interval [start, end) is split into at most
  // num_threads * kWorkBlocksPerThread contiguous disjoint blocks.
  //
  // In order to avoid creating empty blocks of work, the number of work
  // blocks is limited so that each block covers at least min_block_size
  // indices.
  const int num_work_blocks = std::min((end - start) / min_block_size,
                                       num_threads * kWorkBlocksPerThread);
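  // For example (illustrative): with end - start = 100, min_block_size = 8
  // and num_threads = 4, this gives
  // num_work_blocks = std::min(100 / 8, 4 * 4) = std::min(12, 16) = 12.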

  // We use a std::shared_ptr because the main thread can finish all the work
  // before the tasks have been popped off the queue. So the shared state
  // needs to exist for the duration of all the tasks.
  auto shared_state =
      std::make_shared<ParallelInvokeState>(start, end, num_work_blocks);

  // A function which tries to schedule another task in the thread pool and
  // performs several chunks of work. The function expects itself as the
  // argument in order to schedule the next task in the thread pool.
  auto task = [context, shared_state, num_threads, &function](auto& task_copy) {
    int num_jobs_finished = 0;
    const int thread_id = shared_state->thread_id.fetch_add(1);
    // In order to avoid deadlocks in nested parallel for loops, task() will
    // be invoked at most num_threads + 1 times:
    //  - num_threads times via enqueueing the task into the thread pool
    //  - one more time in the main thread
    // Tasks enqueued to the thread pool might take some time before they are
    // executed, and the last task to run is terminated here in order to avoid
    // having more than num_threads active threads.
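    //
    // For example (illustrative), with num_threads = 4 each invocation draws
    // a unique id 0, 1, 2, ... from shared_state->thread_id; any invocation
    // that draws an id >= 4 returns immediately below, so at most 4
    // invocations perform work.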
    if (thread_id >= num_threads) return;
    const int num_work_blocks = shared_state->num_work_blocks;
    if (thread_id + 1 < num_threads &&
        shared_state->block_id < num_work_blocks) {
      // Add another thread to the thread pool.
      // Note that task_copy is captured by value in the lambda passed to
      // AddTask, so the shared_state shared pointer (captured by value when
      // the task lambda was declared) is copied as well and its ref count is
      // increased. This prevents it from being deleted when the main thread
      // finishes all the work and exits before the enqueued tasks finish.
      context->thread_pool.AddTask([task_copy]() { task_copy(task_copy); });
    }

    const int start = shared_state->start;
    const int base_block_size = shared_state->base_block_size;
    const int num_base_p1_sized_blocks = shared_state->num_base_p1_sized_blocks;

    while (true) {
      // Get the next available chunk of work to be performed. If there is no
      // work, return.
      int block_id = shared_state->block_id.fetch_add(1);
      if (block_id >= num_work_blocks) {
        break;
      }
      ++num_jobs_finished;

      // The for-loop interval [start, end) was split into num_work_blocks
      // blocks, with num_base_p1_sized_blocks of them of size
      // base_block_size + 1 and the remaining
      // num_work_blocks - num_base_p1_sized_blocks of size base_block_size.
      //
      // The start index of block #block_id is then given by the total length
      // of the preceding blocks:
      //  * Total length of preceding blocks of size base_block_size + 1:
      //      min(block_id, num_base_p1_sized_blocks) * (base_block_size + 1)
      //
      //  * Total length of preceding blocks of size base_block_size:
      //      (block_id - min(block_id, num_base_p1_sized_blocks)) *
      //      base_block_size
      //
      // Simplifying the sum of those quantities yields the following
      // expression for the start index of block #block_id:
      const int curr_start = start + block_id * base_block_size +
                             std::min(block_id, num_base_p1_sized_blocks);
      // The first num_base_p1_sized_blocks blocks have size
      // base_block_size + 1.
      //
      // Note that it is guaranteed that all blocks lie within the
      // [start, end) interval.
      const int curr_end = curr_start + base_block_size +
                           (block_id < num_base_p1_sized_blocks ? 1 : 0);
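      // For example (continuing the earlier illustration): with start = 0,
      // base_block_size = 3 and num_base_p1_sized_blocks = 1, block_id = 1
      // yields curr_start = 0 + 1 * 3 + min(1, 1) = 4 and
      // curr_end = 4 + 3 + 0 = 7, i.e. the block [4, 7).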
      // Perform each task in the current block.
      const auto range = std::make_tuple(curr_start, curr_end);
      InvokeOnSegment(thread_id, range, function);
    }
    shared_state->block_until_finished.Finished(num_jobs_finished);
  };

  // Start scheduling threads and doing work. We might end up with fewer
  // threads scheduled than expected, if the scheduling overhead is larger
  // than the amount of work to be done.
  task(task);

  // Wait until all tasks have finished.
  shared_state->block_until_finished.Block();
}

}  // namespace ceres::internal

#endif  // CERES_INTERNAL_PARALLEL_INVOKE_H_