// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Authors: vitus@google.com (Michael Vitus),
//          dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)

#ifndef CERES_INTERNAL_PARALLEL_FOR_H_
#define CERES_INTERNAL_PARALLEL_FOR_H_

#include <mutex>
#include <vector>

#include "ceres/context_impl.h"
#include "ceres/internal/eigen.h"
#include "ceres/internal/export.h"
#include "ceres/parallel_invoke.h"
#include "ceres/partition_range_for_parallel_for.h"
#include "glog/logging.h"

namespace ceres::internal {

// Lock the given mutex only if num_threads > 1; for single-threaded execution
// an empty (no-op) std::unique_lock is returned, avoiding locking overhead.
inline decltype(auto) MakeConditionalLock(const int num_threads,
                                          std::mutex& m) {
  return (num_threads == 1) ? std::unique_lock<std::mutex>{}
                            : std::unique_lock<std::mutex>{m};
}

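// Usage sketch (hypothetical; `num_threads_` and `mutex_` stand in for
// whatever state the calling code actually uses):
//
//   auto lock = MakeConditionalLock(num_threads_, mutex_);
//   // ... critical section, guarded only when running multithreaded ...
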
// Execute the function for every element in the range [start, end) with at
// most num_threads. All of the work is executed on the calling thread if
// num_threads or (end - start) is equal to 1, or if the range contains fewer
// than 2 * min_block_size iterations.
// Depending on its signature, the function will be supplied with either a
// single loop index or a range of loop indices; the function can also be
// supplied with a thread_id. The following function signatures are supported:
//  - Functions accepting a single loop index:
//     - [](int index) { ... }
//     - [](int thread_id, int index) { ... }
//  - Functions accepting a range of loop indices:
//     - [](std::tuple<int, int> index) { ... }
//     - [](int thread_id, std::tuple<int, int> index) { ... }
//
// When distributing the workload between threads, it is assumed that each
// loop iteration takes approximately the same time to complete; a usage
// sketch follows below.
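//
// Usage sketch (hypothetical; assumes a ContextImpl* `context` whose thread
// pool has already been initialized, and a std::vector<double> `values` to be
// scaled in place):
//
//   ParallelFor(context, 0, static_cast<int>(values.size()),
//               /*num_threads=*/4,
//               [&values](int i) { values[i] *= 2.0; });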
template <typename F>
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 F&& function,
                 int min_block_size = 1) {
  CHECK_GT(num_threads, 0);
  if (start >= end) {
    return;
  }

  if (num_threads == 1 || end - start < min_block_size * 2) {
    InvokeOnSegment(0, std::make_tuple(start, end), std::forward<F>(function));
    return;
  }

  CHECK(context != nullptr);
  ParallelInvoke(context,
                 start,
                 end,
                 num_threads,
                 std::forward<F>(function),
                 min_block_size);
}

// Execute the function for every element in the range [start, end) with at
// most num_threads, using the user-provided partitions array.
// When distributing the workload between threads, it is assumed that each
// segment bounded by adjacent elements of the partitions array takes
// approximately equal time to process; a usage sketch follows below.
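//
// Usage sketch (hypothetical; the boundary values are assumed to have been
// chosen by the caller so that the segments [0, 10), [10, 40) and [40, 100)
// take roughly equal time to process):
//
//   const std::vector<int> partitions = {0, 10, 40, 100};
//   ParallelFor(context, 0, 100, /*num_threads=*/4,
//               [](std::tuple<int, int> range) {
//                 const auto [s, e] = range;
//                 for (int i = s; i < e; ++i) {
//                   // ... process index i ...
//                 }
//               },
//               partitions);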
template <typename F>
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 F&& function,
                 const std::vector<int>& partitions) {
  CHECK_GT(num_threads, 0);
  if (start >= end) {
    return;
  }
  CHECK_EQ(partitions.front(), start);
  CHECK_EQ(partitions.back(), end);
  if (num_threads == 1 || end - start <= num_threads) {
    ParallelFor(context, start, end, num_threads, std::forward<F>(function));
    return;
  }
  CHECK_GT(partitions.size(), 1);
  const int num_partitions = partitions.size() - 1;
  ParallelFor(context,
              0,
              num_partitions,
              num_threads,
              [&function, &partitions](int thread_id,
                                       std::tuple<int, int> partition_ids) {
                // partition_ids is a range of partition indices.
                const auto [partition_start, partition_end] = partition_ids;
                // Execution over several adjacent segments is equivalent to
                // execution over the union of those segments (which is also
                // a contiguous segment).
                const int range_start = partitions[partition_start];
                const int range_end = partitions[partition_end];
                // Range of original loop indices.
                const auto range = std::make_tuple(range_start, range_end);
                InvokeOnSegment(thread_id, range, function);
              });
}

// Execute the function for every element in the range [start, end) with at
// most num_threads, taking into account user-provided integer cumulative
// costs of iterations. Cumulative costs of iterations for indices in the
// range [0, end) are stored in objects from cumulative_cost_data. The
// user-provided cumulative_cost_fun returns non-decreasing integer values
// corresponding to the inclusive cumulative cost of the loop iterations,
// given a reference to a user-defined object. Only indices from [start, end)
// will be referenced. This routine assumes that cumulative_cost_fun is
// non-decreasing (in other words, that all costs are non-negative).
// When distributing the workload between threads, the input range of loop
// indices is partitioned into disjoint contiguous intervals such that the
// maximal interval cost is minimized.
// For example, with iteration costs of [1, 1, 5, 3, 1, 4], cumulative_cost_fun
// should return [1, 2, 7, 10, 11, 15], and with num_threads = 4 this range
// will be split into segments [0, 2), [2, 3), [3, 5), [5, 6) with costs
// [2, 5, 4, 4]; this example is spelled out in the sketch below.
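//
// Usage sketch (hypothetical; here cumulative_cost_data is simply a
// precomputed array of inclusive cumulative costs, so cumulative_cost_fun
// reduces to the identity):
//
//   const std::vector<int> cumulative_costs = {1, 2, 7, 10, 11, 15};
//   ParallelFor(context, 0, 6, /*num_threads=*/4,
//               [](int i) { /* ... process iteration i ... */ },
//               cumulative_costs.data(),
//               [](const int& cumulative_cost) { return cumulative_cost; });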
template <typename F, typename CumulativeCostData, typename CumulativeCostFun>
void ParallelFor(ContextImpl* context,
                 int start,
                 int end,
                 int num_threads,
                 F&& function,
                 const CumulativeCostData* cumulative_cost_data,
                 CumulativeCostFun&& cumulative_cost_fun) {
  CHECK_GT(num_threads, 0);
  if (start >= end) {
    return;
  }
  if (num_threads == 1 || end - start <= num_threads) {
    ParallelFor(context, start, end, num_threads, std::forward<F>(function));
    return;
  }
  // Creating several partitions allows us to tolerate imperfections of
  // partitioning and user-supplied iteration costs up to a certain extent.
  constexpr int kNumPartitionsPerThread = 4;
  const int kMaxPartitions = num_threads * kNumPartitionsPerThread;
  const auto& partitions = PartitionRangeForParallelFor(
      start,
      end,
      kMaxPartitions,
      cumulative_cost_data,
      std::forward<CumulativeCostFun>(cumulative_cost_fun));
  CHECK_GT(partitions.size(), 1);
  ParallelFor(
      context, start, end, num_threads, std::forward<F>(function), partitions);
}
}  // namespace ceres::internal

#endif  // CERES_INTERNAL_PARALLEL_FOR_H_