// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Authors: vitus@google.com (Michael Vitus),
//          dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)

#ifndef CERES_INTERNAL_PARALLEL_INVOKE_H_
#define CERES_INTERNAL_PARALLEL_INVOKE_H_

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <tuple>
#include <type_traits>

#include "ceres/context_impl.h"
#include "glog/logging.h"

namespace ceres::internal {

// InvokeWithThreadId passes thread_id to the function as its first argument if
// the function accepts it; otherwise the function is invoked without it.
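//
// A minimal usage sketch (the lambdas below are hypothetical, for illustration
// only):
//
//   auto with_id = [](int thread_id, int i) { /* may use thread_id */ };
//   auto without_id = [](int i) { /* does not need thread_id */ };
//   InvokeWithThreadId(/*thread_id=*/0, with_id, 5);     // calls with_id(0, 5)
//   InvokeWithThreadId(/*thread_id=*/0, without_id, 5);  // calls without_id(5)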
template <typename F, typename... Args>
void InvokeWithThreadId(int thread_id, F&& function, Args&&... args) {
  constexpr bool kPassThreadId = std::is_invocable_v<F, int, Args...>;

  if constexpr (kPassThreadId) {
    function(thread_id, std::forward<Args>(args)...);
  } else {
    function(std::forward<Args>(args)...);
  }
}

// InvokeOnSegment either runs a loop over the individual indices of the
// segment or passes the whole segment to the function, depending on the
// function's signature.
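//
// A minimal sketch of the two supported callable shapes (the lambdas are
// hypothetical, for illustration only):
//
//   // Invoked once per index in [start, end).
//   auto per_index = [](int thread_id, int i) { /* process index i */ };
//   // Invoked once with the whole range.
//   auto per_range = [](std::tuple<int, int> range) { /* process range */ };
//   InvokeOnSegment(/*thread_id=*/0, std::make_tuple(0, 8), per_index);
//   InvokeOnSegment(/*thread_id=*/0, std::make_tuple(0, 8), per_range);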
template <typename F>
void InvokeOnSegment(int thread_id, std::tuple<int, int> range, F&& function) {
  constexpr bool kExplicitLoop =
      std::is_invocable_v<F, int> || std::is_invocable_v<F, int, int>;

  if constexpr (kExplicitLoop) {
    const auto [start, end] = range;
    for (int i = start; i != end; ++i) {
      InvokeWithThreadId(thread_id, std::forward<F>(function), i);
    }
  } else {
    InvokeWithThreadId(thread_id, std::forward<F>(function), range);
  }
}

// This class implements a thread-safe barrier that blocks until a
// pre-specified number of jobs have been reported as finished. This allows us
// to block the main thread until all the parallel threads are finished
// processing all the work.
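//
// A minimal usage sketch (values are hypothetical, for illustration only):
//
//   BlockUntilFinished barrier(/*num_total_jobs=*/8);
//   // Worker threads report how many jobs each of them completed...
//   barrier.Finished(/*num_jobs_finished=*/3);  // e.g. from one worker
//   barrier.Finished(/*num_jobs_finished=*/5);  // e.g. from another worker
//   // ...and the main thread returns from Block() once all 8 are accounted
//   // for.
//   barrier.Block();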
class BlockUntilFinished {
 public:
  explicit BlockUntilFinished(int num_total_jobs);

  // Increment the number of jobs that have been processed by the number of
  // jobs processed by the caller and signal the blocking thread if all jobs
  // have finished.
  void Finished(int num_jobs_finished);

  // Block until receiving confirmation of all jobs being finished.
  void Block();

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  int num_total_jobs_finished_;
  const int num_total_jobs_;
};

// Shared state between the parallel tasks. Each thread will use this
// information to get the next block of work to be performed.
struct ParallelInvokeState {
  // The entire range [start, end) is split into num_work_blocks contiguous
  // disjoint intervals (blocks), which are as equal in size as possible given
  // the total index count and the requested number of blocks.
  //
  // Those num_work_blocks blocks are then processed in parallel.
  //
  // The total number of integer indices in the interval [start, end) is
  // end - start, and when splitting them into num_work_blocks blocks we can
  // either
  //  - Split into equal blocks when (end - start) is divisible by
  //    num_work_blocks
  //  - Split into blocks whose sizes differ by at most 1:
  //    - The size of the smallest block(s) is (end - start) / num_work_blocks
  //    - (end - start) % num_work_blocks of the blocks are 1 index larger
  //
  // Note that this splitting is optimal in the sense that the maximal
  // difference between block sizes is as small as possible (at most 1), since
  // splitting into equal blocks is possible if and only if the number of
  // indices is divisible by the number of blocks.
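  //
  // As a concrete example (numbers purely illustrative): splitting [0, 10)
  // into num_work_blocks = 4 blocks gives base_block_size = 10 / 4 = 2 and
  // num_base_p1_sized_blocks = 10 % 4 = 2, i.e. the blocks
  // [0, 3), [3, 6), [6, 8), [8, 10).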
  ParallelInvokeState(int start, int end, int num_work_blocks);

  // The start and end index of the for loop.
  const int start;
  const int end;
  // The number of blocks that need to be processed.
  const int num_work_blocks;
  // Size of the smallest block.
  const int base_block_size;
  // Number of blocks of size base_block_size + 1.
  const int num_base_p1_sized_blocks;

  // The next block of work to be assigned to a worker. The parallel for loop
  // range is split into num_work_blocks blocks of work, with a single block of
  // work being of size
  //  - base_block_size + 1 for the first num_base_p1_sized_blocks blocks
  //  - base_block_size for the rest of the blocks
  // Blocks of indices are contiguous and disjoint.
  std::atomic<int> block_id;

  // Provides a unique thread ID among all active threads. We do not schedule
  // more than num_threads threads via the thread pool, and the caller thread
  // might steal one ID.
  std::atomic<int> thread_id;

  // Used to signal when all the work has been completed. Thread safe.
  BlockUntilFinished block_until_finished;
};

// This implementation uses a fixed-size worker pool with a shared task queue.
// The problem of executing the function over the interval [start, end) is
// broken up into at most num_threads * kWorkBlocksPerThread blocks (each of
// size at least min_block_size) and added to the thread pool. To avoid
// deadlocks, the calling thread is allowed to steal work from the worker pool.
// This is implemented via a shared state between the tasks. In order for
// the calling thread or the thread pool to get a block of work, it will query
// the shared state for the next block of work to be done. If there is nothing
// left, it will return. We will exit the ParallelFor call when all of the work
// has been done, not when all of the tasks have been popped off the task
// queue.
//
// A unique thread ID among all active tasks will be acquired once for each
// block of work. This avoids the significant performance penalty of acquiring
// it on every iteration of the for loop. The thread ID is guaranteed to be in
// [0, num_threads).
//
// A performance analysis has shown this implementation is on par with OpenMP
// and TBB.
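//
// A minimal usage sketch (the context setup, data, and parameter values below
// are hypothetical, for illustration only):
//
//   ContextImpl context;
//   context.EnsureMinimumThreads(4);
//   std::vector<double> x(1000, 1.0);
//   ParallelInvoke(&context, /*start=*/0, /*end=*/1000, /*num_threads=*/4,
//                  [&x](int i) { x[i] *= 2.0; },
//                  /*min_block_size=*/1);
//
// A callable of the form [](int thread_id, int i) or one taking the whole
// std::tuple<int, int> range is also accepted, as described above.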
template <typename F>
void ParallelInvoke(ContextImpl* context,
                    int start,
                    int end,
                    int num_threads,
                    F&& function,
                    int min_block_size) {
  CHECK(context != nullptr);

  // Maximal number of work blocks scheduled for a single thread:
  //  - A lower number of work blocks results in larger runtimes on unequal
  //    tasks
  //  - A higher number of work blocks results in larger synchronization losses
  constexpr int kWorkBlocksPerThread = 4;

  // The interval [start, end) is split into
  // num_threads * kWorkBlocksPerThread contiguous disjoint blocks.
  //
  // In order to avoid creating empty blocks of work, we need to limit the
  // number of work blocks by the total number of indices.
  const int num_work_blocks = std::min((end - start) / min_block_size,
                                       num_threads * kWorkBlocksPerThread);
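  // For example (numbers purely illustrative): with end - start = 1000,
  // min_block_size = 100, and num_threads = 4, this yields
  // min(1000 / 100, 4 * 4) = min(10, 16) = 10 work blocks.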

  // We use a std::shared_ptr because the main thread can finish all
  // the work before the tasks have been popped off the queue. So the
  // shared state needs to exist for the duration of all the tasks.
  auto shared_state =
      std::make_shared<ParallelInvokeState>(start, end, num_work_blocks);

  // A function which tries to schedule another task in the thread pool and
  // then performs several chunks of work. The function expects itself as the
  // argument in order to be able to schedule the next task in the thread pool.
  auto task = [context, shared_state, num_threads, &function](auto& task_copy) {
    int num_jobs_finished = 0;
    const int thread_id = shared_state->thread_id.fetch_add(1);
    // In order to avoid deadlocks in nested parallel for loops, task() will be
    // invoked num_threads + 1 times:
    //  - num_threads times via enqueueing the task into the thread pool
    //  - one more time in the main thread
    // Tasks enqueued to the thread pool might take some time before execution,
    // and the last task to be executed is terminated here in order to avoid
    // having more than num_threads active threads.
    if (thread_id >= num_threads) return;
    const int num_work_blocks = shared_state->num_work_blocks;
    if (thread_id + 1 < num_threads &&
        shared_state->block_id < num_work_blocks) {
      // Add another task to the thread pool.
      // Note that the task is captured by value, so the shared_state shared
      // pointer (captured by value at the declaration of the task
      // lambda-function) is copied and its reference count is increased. This
      // prevents the state from being deleted when the main thread finishes
      // all the work and exits before the other threads finish.
      context->thread_pool.AddTask([task_copy]() { task_copy(task_copy); });
    }

    const int start = shared_state->start;
    const int base_block_size = shared_state->base_block_size;
    const int num_base_p1_sized_blocks = shared_state->num_base_p1_sized_blocks;

    while (true) {
      // Get the next available chunk of work to be performed. If there is no
      // work, return.
      int block_id = shared_state->block_id.fetch_add(1);
      if (block_id >= num_work_blocks) {
        break;
      }
      ++num_jobs_finished;

      // The for-loop interval [start, end) was split into num_work_blocks
      // blocks, with the first num_base_p1_sized_blocks of size
      // base_block_size + 1 and the remaining
      // num_work_blocks - num_base_p1_sized_blocks of size base_block_size.
      //
      // Then, the start index of block #block_id is given by the total
      // length of the preceding blocks:
      //  * Total length of the preceding blocks of size base_block_size + 1:
      //      min(block_id, num_base_p1_sized_blocks) * (base_block_size + 1)
      //
      //  * Total length of the preceding blocks of size base_block_size:
      //      (block_id - min(block_id, num_base_p1_sized_blocks)) *
      //      base_block_size
      //
      // Simplifying the sum of those quantities yields the following
      // expression for the start index of block #block_id:
      const int curr_start = start + block_id * base_block_size +
                             std::min(block_id, num_base_p1_sized_blocks);
      // The first num_base_p1_sized_blocks blocks have size
      // base_block_size + 1.
      //
      // Note that it is guaranteed that all blocks are within the
      // [start, end) interval.
      const int curr_end = curr_start + base_block_size +
                           (block_id < num_base_p1_sized_blocks ? 1 : 0);
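      // As a purely illustrative numeric check: with start = 0, end = 10,
      // num_work_blocks = 4 (so base_block_size = 2 and
      // num_base_p1_sized_blocks = 2), block_id = 3 gives
      // curr_start = 0 + 3 * 2 + min(3, 2) = 8 and curr_end = 8 + 2 + 0 = 10,
      // i.e. the last block [8, 10).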
      // Perform each task in the current block.
      const auto range = std::make_tuple(curr_start, curr_end);
      InvokeOnSegment(thread_id, range, function);
    }
    shared_state->block_until_finished.Finished(num_jobs_finished);
  };

  // Start scheduling threads and doing work. We might end up with fewer
  // threads scheduled than expected, if the scheduling overhead is larger than
  // the amount of work to be done.
  task(task);

  // Wait until all tasks have finished.
  shared_state->block_until_finished.Block();
}

}  // namespace ceres::internal

#endif  // CERES_INTERNAL_PARALLEL_INVOKE_H_