// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2018 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: vitus@google.com (Michael Vitus)

// This include must come before any #ifndef check on Ceres compile options.
#include "ceres/internal/port.h"

#ifdef CERES_USE_CXX11_THREADS

#include "ceres/parallel_for.h"

#include <algorithm>
#include <cmath>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>

#include "ceres/concurrent_queue.h"
#include "ceres/scoped_thread_token.h"
#include "ceres/thread_token_provider.h"
#include "glog/logging.h"

namespace ceres {
namespace internal {
namespace {
// This class creates a thread safe barrier which will block until a
// pre-specified number of threads call Finished. This allows us to block the
// main thread until all the parallel threads are finished processing all the
// work.
class BlockUntilFinished {
 public:
  explicit BlockUntilFinished(int num_total)
      : num_finished_(0), num_total_(num_total) {}

  // Increment the number of jobs that have finished and signal the blocking
  // thread if all jobs have finished.
  void Finished() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++num_finished_;
    CHECK_LE(num_finished_, num_total_);
    if (num_finished_ == num_total_) {
      condition_.notify_one();
    }
  }

  // Block until all threads have signaled they are finished.
  void Block() {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [&]() { return num_finished_ == num_total_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  // The current number of jobs finished.
  int num_finished_;
  // The total number of jobs.
  int num_total_;
};
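
// An illustrative sketch of the intended usage pattern (not part of the
// implementation below): each task calls Finished() exactly once, and the
// coordinating thread calls Block() to wait for all of them. DoChunkOfWork,
// thread_pool and num_tasks are hypothetical placeholders.
//
//   BlockUntilFinished barrier(num_tasks);
//   for (int i = 0; i < num_tasks; ++i) {
//     thread_pool.AddTask([&barrier, i]() {
//       DoChunkOfWork(i);    // hypothetical work function
//       barrier.Finished();  // signal that this task is done
//     });
//   }
//   barrier.Block();  // returns once Finished() has run num_tasks times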

// Shared state between the parallel tasks. Each thread will use this
// information to get the next block of work to be performed.
struct SharedState {
  SharedState(int start, int end, int num_work_items)
      : start(start),
        end(end),
        num_work_items(num_work_items),
        i(0),
        thread_token_provider(num_work_items),
        block_until_finished(num_work_items) {}

  // The start and end index of the for loop.
  const int start;
  const int end;
  // The number of blocks that need to be processed.
  const int num_work_items;

  // The next block of work to be assigned to a worker. The parallel for loop
  // range is split into num_work_items blocks of work, i.e. a single block of
  // work is:
  //   for (int j = start + i; j < end; j += num_work_items) { ... }.
  int i;
  std::mutex mutex_i;

  // Provides a unique thread ID among all active threads working on the same
  // group of tasks. Thread-safe.
  ThreadTokenProvider thread_token_provider;

  // Used to signal when all the work has been completed. Thread safe.
  BlockUntilFinished block_until_finished;
};
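
// A small worked example of the partitioning above (purely illustrative):
// with start = 0, end = 10 and num_work_items = 4, block i = 0 processes
// j = 0, 4, 8; block i = 1 processes j = 1, 5, 9; block i = 2 processes
// j = 2, 6; and block i = 3 processes j = 3, 7. Every index in [start, end)
// is visited exactly once across the blocks.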

}  // namespace

int MaxNumThreadsAvailable() {
  return ThreadPool::MaxNumThreadsAvailable();
}

// See ParallelFor (below) for more details.
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != NULL);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    for (int i = start; i < end; ++i) {
      function(i);
    }
    return;
  }

  ParallelFor(context, start, end, num_threads,
              [&function](int /*thread_id*/, int i) { function(i); });
}
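
// A minimal caller-side sketch of this overload (illustrative only; `x` and
// `y` are hypothetical vectors, and `context` is assumed to point to a
// ContextImpl whose thread pool has already been sized):
//
//   ParallelFor(context, 0, static_cast<int>(x.size()), 4,
//               [&x, &y](int i) { y[i] = 2.0 * x[i]; });
//
// Each index in [0, x.size()) is processed exactly once, possibly from
// different threads.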

// This implementation uses a worker pool of fixed maximum size with a shared
// task queue. The problem of executing the function for the interval of
// [start, end) is broken up into at most num_threads blocks and added to the
// thread pool. To avoid deadlocks, the calling thread is allowed to steal work
// from the worker pool. This is implemented via a shared state between the
// tasks. In order for the calling thread or thread pool to get a block of
// work, it will query the shared state for the next block of work to be done.
// If there is nothing left, it will return. We will exit the ParallelFor call
// when all of the work has been done, not when all of the tasks have been
// popped off the task queue.
//
// A unique thread ID among all active tasks will be acquired once for each
// block of work. This avoids the significant performance penalty for acquiring
// it on every iteration of the for loop. The thread ID is guaranteed to be in
// [0, num_threads).
//
// A performance analysis has shown this implementation is on par with OpenMP
// and TBB.
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 const std::function<void(int thread_id, int i)>& function) {
  CHECK_GT(num_threads, 0);
  CHECK(context != NULL);
  if (end <= start) {
    return;
  }

  // Fast path for when it is single threaded.
  if (num_threads == 1) {
    // Even though we only have one thread, use the thread token provider to
    // guarantee the exact same behavior when running with multiple threads.
    ThreadTokenProvider thread_token_provider(num_threads);
    const ScopedThreadToken scoped_thread_token(&thread_token_provider);
    const int thread_id = scoped_thread_token.token();
    for (int i = start; i < end; ++i) {
      function(thread_id, i);
    }
    return;
  }

  // We use a std::shared_ptr because the main thread can finish all
  // the work before the tasks have been popped off the queue. So the
  // shared state needs to exist for the duration of all the tasks.
  const int num_work_items = std::min((end - start), num_threads);
  std::shared_ptr<SharedState> shared_state(
      new SharedState(start, end, num_work_items));

  // A function which tries to perform a chunk of work. This returns false if
  // there is no work to be done.
  auto task_function = [shared_state, &function]() {
    int i = 0;
    {
      // Get the next available chunk of work to be performed. If there is no
      // work, return false.
      std::lock_guard<std::mutex> lock(shared_state->mutex_i);
      if (shared_state->i >= shared_state->num_work_items) {
        return false;
      }
      i = shared_state->i;
      ++shared_state->i;
    }

    const ScopedThreadToken scoped_thread_token(
        &shared_state->thread_token_provider);
    const int thread_id = scoped_thread_token.token();

    // Perform each task.
    for (int j = shared_state->start + i;
         j < shared_state->end;
         j += shared_state->num_work_items) {
      function(thread_id, j);
    }
    shared_state->block_until_finished.Finished();
    return true;
  };

  // Add all the tasks to the thread pool.
  for (int i = 0; i < num_work_items; ++i) {
    // Note we are taking the task_function as value so the shared_state
    // shared pointer is copied and the ref count is increased. This is to
    // prevent it from being deleted when the main thread finishes all the
    // work and exits before the threads finish.
    context->thread_pool.AddTask([task_function]() { task_function(); });
  }

  // Try to do any available work on the main thread. This may steal work from
  // the thread pool, but when there is no work left the thread pool tasks
  // will be no-ops.
  while (task_function()) {
  }

  // Wait until all tasks have finished.
  shared_state->block_until_finished.Block();
}
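
// An illustrative sketch of how the thread_id overload can be used (not part
// of this file's implementation; `x`, `context` and `num_threads` are
// hypothetical, and the per-thread buffer assumes its size matches the
// num_threads value passed to ParallelFor):
//
//   std::vector<double> partial_sums(num_threads, 0.0);
//   ParallelFor(context, 0, static_cast<int>(x.size()), num_threads,
//               [&](int thread_id, int i) { partial_sums[thread_id] += x[i]; });
//   const double sum =
//       std::accumulate(partial_sums.begin(), partial_sums.end(), 0.0);
//
// Because thread_id is guaranteed to lie in [0, num_threads), each thread
// writes to its own slot and no additional synchronization is needed.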

}  // namespace internal
}  // namespace ceres

#endif  // CERES_USE_CXX11_THREADS