blob: 52575eec892a5e724b3d83911b2c7c494dbf948e [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
Austin Schuh3de38b02024-06-25 18:25:10 -07002// Copyright 2023 Google Inc. All rights reserved.
Austin Schuh70cc9552019-01-21 19:46:48 -08003// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30// mierle@gmail.com (Keir Mierle)
31
32#include "ceres/problem_impl.h"
33
34#include <algorithm>
35#include <cstddef>
36#include <cstdint>
37#include <iterator>
38#include <memory>
39#include <set>
40#include <string>
41#include <utility>
42#include <vector>
43
44#include "ceres/casts.h"
45#include "ceres/compressed_row_jacobian_writer.h"
46#include "ceres/compressed_row_sparse_matrix.h"
47#include "ceres/context_impl.h"
48#include "ceres/cost_function.h"
49#include "ceres/crs_matrix.h"
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080050#include "ceres/evaluation_callback.h"
Austin Schuh70cc9552019-01-21 19:46:48 -080051#include "ceres/evaluator.h"
Austin Schuh3de38b02024-06-25 18:25:10 -070052#include "ceres/internal/export.h"
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080053#include "ceres/internal/fixed_array.h"
Austin Schuh70cc9552019-01-21 19:46:48 -080054#include "ceres/loss_function.h"
Austin Schuh3de38b02024-06-25 18:25:10 -070055#include "ceres/manifold.h"
Austin Schuh70cc9552019-01-21 19:46:48 -080056#include "ceres/map_util.h"
57#include "ceres/parameter_block.h"
58#include "ceres/program.h"
59#include "ceres/program_evaluator.h"
60#include "ceres/residual_block.h"
61#include "ceres/scratch_evaluate_preparer.h"
62#include "ceres/stl_util.h"
63#include "ceres/stringprintf.h"
64#include "glog/logging.h"
65
Austin Schuh3de38b02024-06-25 18:25:10 -070066namespace ceres::internal {
Austin Schuh70cc9552019-01-21 19:46:48 -080067namespace {
// Reports whether the half-open element ranges [a, a + size_a) and
// [b, b + size_b) share at least one double. Sizes are element counts.
bool RegionsAlias(const double* a, int size_a, const double* b, int size_b) {
  if (a < b) {
    // a begins first, so the ranges overlap only if a extends past b's start.
    return b < (a + size_a);
  }
  // b begins first (or at the same address); overlap iff b extends past a's
  // start.
  return a < (b + size_b);
}
73
74void CheckForNoAliasing(double* existing_block,
75 int existing_block_size,
76 double* new_block,
77 int new_block_size) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080078 CHECK(!RegionsAlias(
79 existing_block, existing_block_size, new_block, new_block_size))
Austin Schuh70cc9552019-01-21 19:46:48 -080080 << "Aliasing detected between existing parameter block at memory "
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080081 << "location " << existing_block << " and has size "
82 << existing_block_size << " with new parameter "
Austin Schuh70cc9552019-01-21 19:46:48 -080083 << "block that has memory address " << new_block << " and would have "
84 << "size " << new_block_size << ".";
85}
86
// Drops one reference to `key` in the reference-count table `container`.
// When the stored count is exactly 1 the pointed-to object is deleted and its
// entry erased; otherwise the count is decremented. `key` is looked up without
// a presence check, so it must already be in the table.
template <typename KeyType>
void DecrementValueOrDeleteKey(const KeyType key,
                               std::map<KeyType, int>* container) {
  auto entry = container->find(key);
  int& ref_count = entry->second;
  if (ref_count != 1) {
    --ref_count;
    return;
  }

  // Last reference: destroy the object and forget about it.
  delete key;
  container->erase(entry);
}
98
// Deletes the `first` member of every pair in [begin, end). The container
// itself is left untouched, so its keys dangle afterwards.
template <typename ForwardIterator>
void STLDeleteContainerPairFirstPointers(ForwardIterator begin,
                                         ForwardIterator end) {
  for (; begin != end; ++begin) {
    delete begin->first;
  }
}
107
108void InitializeContext(Context* context,
109 ContextImpl** context_impl,
110 bool* context_impl_owned) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800111 if (context == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800112 *context_impl_owned = true;
113 *context_impl = new ContextImpl;
114 } else {
115 *context_impl_owned = false;
116 *context_impl = down_cast<ContextImpl*>(context);
117 }
118}
119
120} // namespace
121
122ParameterBlock* ProblemImpl::InternalAddParameterBlock(double* values,
123 int size) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800124 CHECK(values != nullptr) << "Null pointer passed to AddParameterBlock "
125 << "for a parameter with size " << size;
Austin Schuh70cc9552019-01-21 19:46:48 -0800126
127 // Ignore the request if there is a block for the given pointer already.
Austin Schuh3de38b02024-06-25 18:25:10 -0700128 auto it = parameter_block_map_.find(values);
Austin Schuh70cc9552019-01-21 19:46:48 -0800129 if (it != parameter_block_map_.end()) {
130 if (!options_.disable_all_safety_checks) {
131 int existing_size = it->second->Size();
132 CHECK(size == existing_size)
133 << "Tried adding a parameter block with the same double pointer, "
134 << values << ", twice, but with different block sizes. Original "
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800135 << "size was " << existing_size << " but new size is " << size;
Austin Schuh70cc9552019-01-21 19:46:48 -0800136 }
137 return it->second;
138 }
139
140 if (!options_.disable_all_safety_checks) {
141 // Before adding the parameter block, also check that it doesn't alias any
142 // other parameter blocks.
143 if (!parameter_block_map_.empty()) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700144 auto lb = parameter_block_map_.lower_bound(values);
Austin Schuh70cc9552019-01-21 19:46:48 -0800145
146 // If lb is not the first block, check the previous block for aliasing.
147 if (lb != parameter_block_map_.begin()) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700148 auto previous = lb;
Austin Schuh70cc9552019-01-21 19:46:48 -0800149 --previous;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800150 CheckForNoAliasing(
151 previous->first, previous->second->Size(), values, size);
Austin Schuh70cc9552019-01-21 19:46:48 -0800152 }
153
154 // If lb is not off the end, check lb for aliasing.
155 if (lb != parameter_block_map_.end()) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800156 CheckForNoAliasing(lb->first, lb->second->Size(), values, size);
Austin Schuh70cc9552019-01-21 19:46:48 -0800157 }
158 }
159 }
160
161 // Pass the index of the new parameter block as well to keep the index in
162 // sync with the position of the parameter in the program's parameter vector.
Austin Schuh3de38b02024-06-25 18:25:10 -0700163 auto* new_parameter_block =
Austin Schuh70cc9552019-01-21 19:46:48 -0800164 new ParameterBlock(values, size, program_->parameter_blocks_.size());
165
166 // For dynamic problems, add the list of dependent residual blocks, which is
167 // empty to start.
168 if (options_.enable_fast_removal) {
169 new_parameter_block->EnableResidualBlockDependencies();
170 }
171 parameter_block_map_[values] = new_parameter_block;
172 program_->parameter_blocks_.push_back(new_parameter_block);
173 return new_parameter_block;
174}
175
176void ProblemImpl::InternalRemoveResidualBlock(ResidualBlock* residual_block) {
177 CHECK(residual_block != nullptr);
178 // Perform no check on the validity of residual_block, that is handled in
179 // the public method: RemoveResidualBlock().
180
181 // If needed, remove the parameter dependencies on this residual block.
182 if (options_.enable_fast_removal) {
183 const int num_parameter_blocks_for_residual =
184 residual_block->NumParameterBlocks();
185 for (int i = 0; i < num_parameter_blocks_for_residual; ++i) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800186 residual_block->parameter_blocks()[i]->RemoveResidualBlock(
187 residual_block);
Austin Schuh70cc9552019-01-21 19:46:48 -0800188 }
189
Austin Schuh3de38b02024-06-25 18:25:10 -0700190 auto it = residual_block_set_.find(residual_block);
Austin Schuh70cc9552019-01-21 19:46:48 -0800191 residual_block_set_.erase(it);
192 }
193 DeleteBlockInVector(program_->mutable_residual_blocks(), residual_block);
194}
195
196// Deletes the residual block in question, assuming there are no other
197// references to it inside the problem (e.g. by another parameter). Referenced
198// cost and loss functions are tucked away for future deletion, since it is not
199// possible to know whether other parts of the problem depend on them without
200// doing a full scan.
201void ProblemImpl::DeleteBlock(ResidualBlock* residual_block) {
202 // The const casts here are legit, since ResidualBlock holds these
203 // pointers as const pointers but we have ownership of them and
204 // have the right to destroy them when the destructor is called.
Austin Schuh3de38b02024-06-25 18:25:10 -0700205 auto* cost_function =
Austin Schuh70cc9552019-01-21 19:46:48 -0800206 const_cast<CostFunction*>(residual_block->cost_function());
207 if (options_.cost_function_ownership == TAKE_OWNERSHIP) {
208 DecrementValueOrDeleteKey(cost_function, &cost_function_ref_count_);
209 }
210
Austin Schuh3de38b02024-06-25 18:25:10 -0700211 auto* loss_function =
Austin Schuh70cc9552019-01-21 19:46:48 -0800212 const_cast<LossFunction*>(residual_block->loss_function());
213 if (options_.loss_function_ownership == TAKE_OWNERSHIP &&
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800214 loss_function != nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800215 DecrementValueOrDeleteKey(loss_function, &loss_function_ref_count_);
216 }
217
218 delete residual_block;
219}
220
221// Deletes the parameter block in question, assuming there are no other
222// references to it inside the problem (e.g. by any residual blocks).
Austin Schuh70cc9552019-01-21 19:46:48 -0800223void ProblemImpl::DeleteBlock(ParameterBlock* parameter_block) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800224 parameter_block_map_.erase(parameter_block->mutable_user_state());
225 delete parameter_block;
226}
227
228ProblemImpl::ProblemImpl()
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800229 : options_(Problem::Options()), program_(new internal::Program) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800230 InitializeContext(options_.context, &context_impl_, &context_impl_owned_);
231}
232
233ProblemImpl::ProblemImpl(const Problem::Options& options)
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800234 : options_(options), program_(new internal::Program) {
235 program_->evaluation_callback_ = options.evaluation_callback;
Austin Schuh70cc9552019-01-21 19:46:48 -0800236 InitializeContext(options_.context, &context_impl_, &context_impl_owned_);
237}
238
239ProblemImpl::~ProblemImpl() {
240 STLDeleteContainerPointers(program_->residual_blocks_.begin(),
241 program_->residual_blocks_.end());
242
243 if (options_.cost_function_ownership == TAKE_OWNERSHIP) {
244 STLDeleteContainerPairFirstPointers(cost_function_ref_count_.begin(),
245 cost_function_ref_count_.end());
246 }
247
248 if (options_.loss_function_ownership == TAKE_OWNERSHIP) {
249 STLDeleteContainerPairFirstPointers(loss_function_ref_count_.begin(),
250 loss_function_ref_count_.end());
251 }
252
253 // Collect the unique parameterizations and delete the parameters.
Austin Schuh3de38b02024-06-25 18:25:10 -0700254 for (auto* parameter_block : program_->parameter_blocks_) {
255 DeleteBlock(parameter_block);
Austin Schuh70cc9552019-01-21 19:46:48 -0800256 }
257
Austin Schuh3de38b02024-06-25 18:25:10 -0700258 // Delete the owned manifolds.
259 STLDeleteUniqueContainerPointers(manifolds_to_delete_.begin(),
260 manifolds_to_delete_.end());
Austin Schuh70cc9552019-01-21 19:46:48 -0800261
262 if (context_impl_owned_) {
263 delete context_impl_;
264 }
265}
266
267ResidualBlockId ProblemImpl::AddResidualBlock(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800268 CostFunction* cost_function,
269 LossFunction* loss_function,
270 double* const* const parameter_blocks,
271 int num_parameter_blocks) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800272 CHECK(cost_function != nullptr);
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800273 CHECK_EQ(num_parameter_blocks, cost_function->parameter_block_sizes().size());
Austin Schuh70cc9552019-01-21 19:46:48 -0800274
275 // Check the sizes match.
Austin Schuh3de38b02024-06-25 18:25:10 -0700276 const std::vector<int32_t>& parameter_block_sizes =
Austin Schuh70cc9552019-01-21 19:46:48 -0800277 cost_function->parameter_block_sizes();
278
279 if (!options_.disable_all_safety_checks) {
280 CHECK_EQ(parameter_block_sizes.size(), num_parameter_blocks)
281 << "Number of blocks input is different than the number of blocks "
282 << "that the cost function expects.";
283
284 // Check for duplicate parameter blocks.
Austin Schuh3de38b02024-06-25 18:25:10 -0700285 std::vector<double*> sorted_parameter_blocks(
Austin Schuh70cc9552019-01-21 19:46:48 -0800286 parameter_blocks, parameter_blocks + num_parameter_blocks);
Austin Schuh3de38b02024-06-25 18:25:10 -0700287 std::sort(sorted_parameter_blocks.begin(), sorted_parameter_blocks.end());
Austin Schuh70cc9552019-01-21 19:46:48 -0800288 const bool has_duplicate_items =
289 (std::adjacent_find(sorted_parameter_blocks.begin(),
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800290 sorted_parameter_blocks.end()) !=
291 sorted_parameter_blocks.end());
Austin Schuh70cc9552019-01-21 19:46:48 -0800292 if (has_duplicate_items) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700293 std::string blocks;
Austin Schuh70cc9552019-01-21 19:46:48 -0800294 for (int i = 0; i < num_parameter_blocks; ++i) {
295 blocks += StringPrintf(" %p ", parameter_blocks[i]);
296 }
297
298 LOG(FATAL) << "Duplicate parameter blocks in a residual parameter "
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800299 << "are not allowed. Parameter block pointers: [" << blocks
300 << "]";
Austin Schuh70cc9552019-01-21 19:46:48 -0800301 }
302 }
303
304 // Add parameter blocks and convert the double*'s to parameter blocks.
Austin Schuh3de38b02024-06-25 18:25:10 -0700305 std::vector<ParameterBlock*> parameter_block_ptrs(num_parameter_blocks);
Austin Schuh70cc9552019-01-21 19:46:48 -0800306 for (int i = 0; i < num_parameter_blocks; ++i) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800307 parameter_block_ptrs[i] = InternalAddParameterBlock(
308 parameter_blocks[i], parameter_block_sizes[i]);
Austin Schuh70cc9552019-01-21 19:46:48 -0800309 }
310
311 if (!options_.disable_all_safety_checks) {
312 // Check that the block sizes match the block sizes expected by the
313 // cost_function.
314 for (int i = 0; i < parameter_block_ptrs.size(); ++i) {
315 CHECK_EQ(cost_function->parameter_block_sizes()[i],
316 parameter_block_ptrs[i]->Size())
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800317 << "The cost function expects parameter block " << i << " of size "
318 << cost_function->parameter_block_sizes()[i]
Austin Schuh70cc9552019-01-21 19:46:48 -0800319 << " but was given a block of size "
320 << parameter_block_ptrs[i]->Size();
321 }
322 }
323
Austin Schuh3de38b02024-06-25 18:25:10 -0700324 auto* new_residual_block =
Austin Schuh70cc9552019-01-21 19:46:48 -0800325 new ResidualBlock(cost_function,
326 loss_function,
327 parameter_block_ptrs,
328 program_->residual_blocks_.size());
329
330 // Add dependencies on the residual to the parameter blocks.
331 if (options_.enable_fast_removal) {
332 for (int i = 0; i < num_parameter_blocks; ++i) {
333 parameter_block_ptrs[i]->AddResidualBlock(new_residual_block);
334 }
335 }
336
337 program_->residual_blocks_.push_back(new_residual_block);
338
339 if (options_.enable_fast_removal) {
340 residual_block_set_.insert(new_residual_block);
341 }
342
343 if (options_.cost_function_ownership == TAKE_OWNERSHIP) {
344 // Increment the reference count, creating an entry in the table if
345 // needed. Note: C++ maps guarantee that new entries have default
346 // constructed values; this implies integers are zero initialized.
347 ++cost_function_ref_count_[cost_function];
348 }
349
350 if (options_.loss_function_ownership == TAKE_OWNERSHIP &&
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800351 loss_function != nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800352 ++loss_function_ref_count_[loss_function];
353 }
354
355 return new_residual_block;
356}
357
358void ProblemImpl::AddParameterBlock(double* values, int size) {
359 InternalAddParameterBlock(values, size);
360}
361
Austin Schuh3de38b02024-06-25 18:25:10 -0700362void ProblemImpl::InternalSetManifold(double* /*values*/,
363 ParameterBlock* parameter_block,
364 Manifold* manifold) {
365 if (manifold != nullptr && options_.manifold_ownership == TAKE_OWNERSHIP) {
366 manifolds_to_delete_.push_back(manifold);
Austin Schuh70cc9552019-01-21 19:46:48 -0800367 }
Austin Schuh3de38b02024-06-25 18:25:10 -0700368 parameter_block->SetManifold(manifold);
369}
370
371void ProblemImpl::AddParameterBlock(double* values,
372 int size,
373 Manifold* manifold) {
374 ParameterBlock* parameter_block = InternalAddParameterBlock(values, size);
375 InternalSetManifold(values, parameter_block, manifold);
Austin Schuh70cc9552019-01-21 19:46:48 -0800376}
377
378// Delete a block from a vector of blocks, maintaining the indexing invariant.
379// This is done in constant time by moving an element from the end of the
380// vector over the element to remove, then popping the last element. It
381// destroys the ordering in the interest of speed.
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800382template <typename Block>
Austin Schuh3de38b02024-06-25 18:25:10 -0700383void ProblemImpl::DeleteBlockInVector(std::vector<Block*>* mutable_blocks,
Austin Schuh70cc9552019-01-21 19:46:48 -0800384 Block* block_to_remove) {
385 CHECK_EQ((*mutable_blocks)[block_to_remove->index()], block_to_remove)
386 << "You found a Ceres bug! \n"
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800387 << "Block requested: " << block_to_remove->ToString() << "\n"
Austin Schuh70cc9552019-01-21 19:46:48 -0800388 << "Block present: "
389 << (*mutable_blocks)[block_to_remove->index()]->ToString();
390
391 // Prepare the to-be-moved block for the new, lower-in-index position by
392 // setting the index to the blocks final location.
393 Block* tmp = mutable_blocks->back();
394 tmp->set_index(block_to_remove->index());
395
396 // Overwrite the to-be-deleted residual block with the one at the end.
397 (*mutable_blocks)[block_to_remove->index()] = tmp;
398
399 DeleteBlock(block_to_remove);
400
401 // The block is gone so shrink the vector of blocks accordingly.
402 mutable_blocks->pop_back();
403}
404
405void ProblemImpl::RemoveResidualBlock(ResidualBlock* residual_block) {
406 CHECK(residual_block != nullptr);
407
408 // Verify that residual_block identifies a residual in the current problem.
Austin Schuh3de38b02024-06-25 18:25:10 -0700409 const std::string residual_not_found_message = StringPrintf(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800410 "Residual block to remove: %p not found. This usually means "
411 "one of three things have happened:\n"
412 " 1) residual_block is uninitialised and points to a random "
413 "area in memory.\n"
414 " 2) residual_block represented a residual that was added to"
415 " the problem, but referred to a parameter block which has "
416 "since been removed, which removes all residuals which "
417 "depend on that parameter block, and was thus removed.\n"
418 " 3) residual_block referred to a residual that has already "
419 "been removed from the problem (by the user).",
420 residual_block);
Austin Schuh70cc9552019-01-21 19:46:48 -0800421 if (options_.enable_fast_removal) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800422 CHECK(residual_block_set_.find(residual_block) != residual_block_set_.end())
Austin Schuh70cc9552019-01-21 19:46:48 -0800423 << residual_not_found_message;
424 } else {
425 // Perform a full search over all current residuals.
426 CHECK(std::find(program_->residual_blocks().begin(),
427 program_->residual_blocks().end(),
428 residual_block) != program_->residual_blocks().end())
429 << residual_not_found_message;
430 }
431
432 InternalRemoveResidualBlock(residual_block);
433}
434
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800435void ProblemImpl::RemoveParameterBlock(const double* values) {
436 ParameterBlock* parameter_block = FindWithDefault(
437 parameter_block_map_, const_cast<double*>(values), nullptr);
438 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800439 LOG(FATAL) << "Parameter block not found: " << values
440 << ". You must add the parameter block to the problem before "
441 << "it can be removed.";
442 }
443
444 if (options_.enable_fast_removal) {
445 // Copy the dependent residuals from the parameter block because the set of
446 // dependents will change after each call to RemoveResidualBlock().
Austin Schuh3de38b02024-06-25 18:25:10 -0700447 std::vector<ResidualBlock*> residual_blocks_to_remove(
Austin Schuh70cc9552019-01-21 19:46:48 -0800448 parameter_block->mutable_residual_blocks()->begin(),
449 parameter_block->mutable_residual_blocks()->end());
Austin Schuh3de38b02024-06-25 18:25:10 -0700450 for (auto* residual_block : residual_blocks_to_remove) {
451 InternalRemoveResidualBlock(residual_block);
Austin Schuh70cc9552019-01-21 19:46:48 -0800452 }
453 } else {
454 // Scan all the residual blocks to remove ones that depend on the parameter
455 // block. Do the scan backwards since the vector changes while iterating.
456 const int num_residual_blocks = NumResidualBlocks();
457 for (int i = num_residual_blocks - 1; i >= 0; --i) {
458 ResidualBlock* residual_block =
459 (*(program_->mutable_residual_blocks()))[i];
460 const int num_parameter_blocks = residual_block->NumParameterBlocks();
461 for (int j = 0; j < num_parameter_blocks; ++j) {
462 if (residual_block->parameter_blocks()[j] == parameter_block) {
463 InternalRemoveResidualBlock(residual_block);
464 // The parameter blocks are guaranteed unique.
465 break;
466 }
467 }
468 }
469 }
470 DeleteBlockInVector(program_->mutable_parameter_blocks(), parameter_block);
471}
472
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800473void ProblemImpl::SetParameterBlockConstant(const double* values) {
474 ParameterBlock* parameter_block = FindWithDefault(
475 parameter_block_map_, const_cast<double*>(values), nullptr);
476 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800477 LOG(FATAL) << "Parameter block not found: " << values
478 << ". You must add the parameter block to the problem before "
479 << "it can be set constant.";
480 }
481
482 parameter_block->SetConstant();
483}
484
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800485bool ProblemImpl::IsParameterBlockConstant(const double* values) const {
486 const ParameterBlock* parameter_block = FindWithDefault(
487 parameter_block_map_, const_cast<double*>(values), nullptr);
488 CHECK(parameter_block != nullptr)
489 << "Parameter block not found: " << values << ". You must add the "
490 << "parameter block to the problem before it can be queried.";
491 return parameter_block->IsConstant();
Austin Schuh70cc9552019-01-21 19:46:48 -0800492}
493
494void ProblemImpl::SetParameterBlockVariable(double* values) {
495 ParameterBlock* parameter_block =
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800496 FindWithDefault(parameter_block_map_, values, nullptr);
497 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800498 LOG(FATAL) << "Parameter block not found: " << values
499 << ". You must add the parameter block to the problem before "
500 << "it can be set varying.";
501 }
502
503 parameter_block->SetVarying();
504}
505
Austin Schuh3de38b02024-06-25 18:25:10 -0700506void ProblemImpl::SetManifold(double* values, Manifold* manifold) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800507 ParameterBlock* parameter_block =
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800508 FindWithDefault(parameter_block_map_, values, nullptr);
509 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800510 LOG(FATAL) << "Parameter block not found: " << values
511 << ". You must add the parameter block to the problem before "
Austin Schuh3de38b02024-06-25 18:25:10 -0700512 << "you can set its manifold.";
Austin Schuh70cc9552019-01-21 19:46:48 -0800513 }
514
Austin Schuh3de38b02024-06-25 18:25:10 -0700515 InternalSetManifold(values, parameter_block, manifold);
Austin Schuh70cc9552019-01-21 19:46:48 -0800516}
517
Austin Schuh3de38b02024-06-25 18:25:10 -0700518const Manifold* ProblemImpl::GetManifold(const double* values) const {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800519 ParameterBlock* parameter_block = FindWithDefault(
520 parameter_block_map_, const_cast<double*>(values), nullptr);
521 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800522 LOG(FATAL) << "Parameter block not found: " << values
523 << ". You must add the parameter block to the problem before "
Austin Schuh3de38b02024-06-25 18:25:10 -0700524 << "you can get its manifold.";
Austin Schuh70cc9552019-01-21 19:46:48 -0800525 }
526
Austin Schuh3de38b02024-06-25 18:25:10 -0700527 return parameter_block->manifold();
528}
529
530bool ProblemImpl::HasManifold(const double* values) const {
531 return GetManifold(values) != nullptr;
Austin Schuh70cc9552019-01-21 19:46:48 -0800532}
533
534void ProblemImpl::SetParameterLowerBound(double* values,
535 int index,
536 double lower_bound) {
537 ParameterBlock* parameter_block =
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800538 FindWithDefault(parameter_block_map_, values, nullptr);
539 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800540 LOG(FATAL) << "Parameter block not found: " << values
541 << ". You must add the parameter block to the problem before "
542 << "you can set a lower bound on one of its components.";
543 }
544
545 parameter_block->SetLowerBound(index, lower_bound);
546}
547
548void ProblemImpl::SetParameterUpperBound(double* values,
549 int index,
550 double upper_bound) {
551 ParameterBlock* parameter_block =
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800552 FindWithDefault(parameter_block_map_, values, nullptr);
553 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800554 LOG(FATAL) << "Parameter block not found: " << values
555 << ". You must add the parameter block to the problem before "
556 << "you can set an upper bound on one of its components.";
557 }
558 parameter_block->SetUpperBound(index, upper_bound);
559}
560
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800561double ProblemImpl::GetParameterLowerBound(const double* values,
562 int index) const {
563 ParameterBlock* parameter_block = FindWithDefault(
564 parameter_block_map_, const_cast<double*>(values), nullptr);
565 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800566 LOG(FATAL) << "Parameter block not found: " << values
567 << ". You must add the parameter block to the problem before "
568 << "you can get the lower bound on one of its components.";
569 }
570 return parameter_block->LowerBound(index);
571}
572
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800573double ProblemImpl::GetParameterUpperBound(const double* values,
574 int index) const {
575 ParameterBlock* parameter_block = FindWithDefault(
576 parameter_block_map_, const_cast<double*>(values), nullptr);
577 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800578 LOG(FATAL) << "Parameter block not found: " << values
579 << ". You must add the parameter block to the problem before "
580 << "you can set an upper bound on one of its components.";
581 }
582 return parameter_block->UpperBound(index);
583}
584
585bool ProblemImpl::Evaluate(const Problem::EvaluateOptions& evaluate_options,
586 double* cost,
Austin Schuh3de38b02024-06-25 18:25:10 -0700587 std::vector<double>* residuals,
588 std::vector<double>* gradient,
Austin Schuh70cc9552019-01-21 19:46:48 -0800589 CRSMatrix* jacobian) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800590 if (cost == nullptr && residuals == nullptr && gradient == nullptr &&
591 jacobian == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800592 return true;
593 }
594
595 // If the user supplied residual blocks, then use them, otherwise
596 // take the residual blocks from the underlying program.
597 Program program;
598 *program.mutable_residual_blocks() =
599 ((evaluate_options.residual_blocks.size() > 0)
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800600 ? evaluate_options.residual_blocks
601 : program_->residual_blocks());
Austin Schuh70cc9552019-01-21 19:46:48 -0800602
Austin Schuh3de38b02024-06-25 18:25:10 -0700603 const std::vector<double*>& parameter_block_ptrs =
Austin Schuh70cc9552019-01-21 19:46:48 -0800604 evaluate_options.parameter_blocks;
605
Austin Schuh3de38b02024-06-25 18:25:10 -0700606 std::vector<ParameterBlock*> variable_parameter_blocks;
607 std::vector<ParameterBlock*>& parameter_blocks =
Austin Schuh70cc9552019-01-21 19:46:48 -0800608 *program.mutable_parameter_blocks();
609
610 if (parameter_block_ptrs.size() == 0) {
611 // The user did not provide any parameter blocks, so default to
612 // using all the parameter blocks in the order that they are in
613 // the underlying program object.
614 parameter_blocks = program_->parameter_blocks();
615 } else {
616 // The user supplied a vector of parameter blocks. Using this list
617 // requires a number of steps.
618
619 // 1. Convert double* into ParameterBlock*
620 parameter_blocks.resize(parameter_block_ptrs.size());
621 for (int i = 0; i < parameter_block_ptrs.size(); ++i) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800622 parameter_blocks[i] = FindWithDefault(
623 parameter_block_map_, parameter_block_ptrs[i], nullptr);
624 if (parameter_blocks[i] == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800625 LOG(FATAL) << "No known parameter block for "
626 << "Problem::Evaluate::Options.parameter_blocks[" << i << "]"
627 << " = " << parameter_block_ptrs[i];
628 }
629 }
630
631 // 2. The user may have only supplied a subset of parameter
632 // blocks, so identify the ones that are not supplied by the user
633 // and are NOT constant. These parameter blocks are stored in
634 // variable_parameter_blocks.
635 //
636 // To ensure that the parameter blocks are not included in the
637 // columns of the jacobian, we need to make sure that they are
638 // constant during evaluation and then make them variable again
639 // after we are done.
Austin Schuh3de38b02024-06-25 18:25:10 -0700640 std::vector<ParameterBlock*> all_parameter_blocks(
641 program_->parameter_blocks());
642 std::vector<ParameterBlock*> included_parameter_blocks(
Austin Schuh70cc9552019-01-21 19:46:48 -0800643 program.parameter_blocks());
644
Austin Schuh3de38b02024-06-25 18:25:10 -0700645 std::vector<ParameterBlock*> excluded_parameter_blocks;
646 std::sort(all_parameter_blocks.begin(), all_parameter_blocks.end());
647 std::sort(included_parameter_blocks.begin(),
648 included_parameter_blocks.end());
Austin Schuh70cc9552019-01-21 19:46:48 -0800649 set_difference(all_parameter_blocks.begin(),
650 all_parameter_blocks.end(),
651 included_parameter_blocks.begin(),
652 included_parameter_blocks.end(),
653 back_inserter(excluded_parameter_blocks));
654
655 variable_parameter_blocks.reserve(excluded_parameter_blocks.size());
Austin Schuh3de38b02024-06-25 18:25:10 -0700656 for (auto* parameter_block : excluded_parameter_blocks) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800657 if (!parameter_block->IsConstant()) {
658 variable_parameter_blocks.push_back(parameter_block);
659 parameter_block->SetConstant();
660 }
661 }
662 }
663
664 // Setup the Parameter indices and offsets before an evaluator can
665 // be constructed and used.
666 program.SetParameterOffsetsAndIndex();
667
668 Evaluator::Options evaluator_options;
669
Austin Schuh3de38b02024-06-25 18:25:10 -0700670 // Even though using SPARSE_NORMAL_CHOLESKY requires a sparse linear algebra
671 // library here it just being used for telling the evaluator to use a
672 // SparseRowCompressedMatrix for the jacobian. This is because the Evaluator
673 // decides the storage for the Jacobian based on the type of linear solver
674 // being used.
Austin Schuh70cc9552019-01-21 19:46:48 -0800675 evaluator_options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
Austin Schuh70cc9552019-01-21 19:46:48 -0800676 evaluator_options.num_threads = evaluate_options.num_threads;
Austin Schuh70cc9552019-01-21 19:46:48 -0800677
678 // The main thread also does work so we only need to launch num_threads - 1.
679 context_impl_->EnsureMinimumThreads(evaluator_options.num_threads - 1);
680 evaluator_options.context = context_impl_;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800681 evaluator_options.evaluation_callback =
682 program_->mutable_evaluation_callback();
Austin Schuh70cc9552019-01-21 19:46:48 -0800683 std::unique_ptr<Evaluator> evaluator(
684 new ProgramEvaluator<ScratchEvaluatePreparer,
685 CompressedRowJacobianWriter>(evaluator_options,
686 &program));
687
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800688 if (residuals != nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800689 residuals->resize(evaluator->NumResiduals());
690 }
691
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800692 if (gradient != nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800693 gradient->resize(evaluator->NumEffectiveParameters());
694 }
695
696 std::unique_ptr<CompressedRowSparseMatrix> tmp_jacobian;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800697 if (jacobian != nullptr) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700698 tmp_jacobian.reset(down_cast<CompressedRowSparseMatrix*>(
699 evaluator->CreateJacobian().release()));
Austin Schuh70cc9552019-01-21 19:46:48 -0800700 }
701
702 // Point the state pointers to the user state pointers. This is
703 // needed so that we can extract a parameter vector which is then
704 // passed to Evaluator::Evaluate.
705 program.SetParameterBlockStatePtrsToUserStatePtrs();
706
707 // Copy the value of the parameter blocks into a vector, since the
708 // Evaluate::Evaluate method needs its input as such. The previous
709 // call to SetParameterBlockStatePtrsToUserStatePtrs ensures that
710 // these values are the ones corresponding to the actual state of
711 // the parameter blocks, rather than the temporary state pointer
712 // used for evaluation.
713 Vector parameters(program.NumParameters());
714 program.ParameterBlocksToStateVector(parameters.data());
715
716 double tmp_cost = 0;
717
718 Evaluator::EvaluateOptions evaluator_evaluate_options;
719 evaluator_evaluate_options.apply_loss_function =
720 evaluate_options.apply_loss_function;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800721 bool status =
722 evaluator->Evaluate(evaluator_evaluate_options,
723 parameters.data(),
724 &tmp_cost,
725 residuals != nullptr ? &(*residuals)[0] : nullptr,
726 gradient != nullptr ? &(*gradient)[0] : nullptr,
727 tmp_jacobian.get());
Austin Schuh70cc9552019-01-21 19:46:48 -0800728
729 // Make the parameter blocks that were temporarily marked constant,
730 // variable again.
Austin Schuh3de38b02024-06-25 18:25:10 -0700731 for (auto* parameter_block : variable_parameter_blocks) {
732 parameter_block->SetVarying();
Austin Schuh70cc9552019-01-21 19:46:48 -0800733 }
734
735 if (status) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800736 if (cost != nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800737 *cost = tmp_cost;
738 }
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800739 if (jacobian != nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800740 tmp_jacobian->ToCRSMatrix(jacobian);
741 }
742 }
743
744 program_->SetParameterBlockStatePtrsToUserStatePtrs();
745 program_->SetParameterOffsetsAndIndex();
746 return status;
747}
748
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800749bool ProblemImpl::EvaluateResidualBlock(ResidualBlock* residual_block,
750 bool apply_loss_function,
751 bool new_point,
752 double* cost,
753 double* residuals,
754 double** jacobians) const {
755 auto evaluation_callback = program_->mutable_evaluation_callback();
756 if (evaluation_callback) {
757 evaluation_callback->PrepareForEvaluation(jacobians != nullptr, new_point);
758 }
759
760 ParameterBlock* const* parameter_blocks = residual_block->parameter_blocks();
761 const int num_parameter_blocks = residual_block->NumParameterBlocks();
762 for (int i = 0; i < num_parameter_blocks; ++i) {
763 ParameterBlock* parameter_block = parameter_blocks[i];
764 if (parameter_block->IsConstant()) {
765 if (jacobians != nullptr && jacobians[i] != nullptr) {
766 LOG(ERROR) << "Jacobian requested for parameter block : " << i
767 << ". But the parameter block is marked constant.";
768 return false;
769 }
770 } else {
771 CHECK(parameter_block->SetState(parameter_block->user_state()))
772 << "Congratulations, you found a Ceres bug! Please report this error "
773 << "to the developers.";
774 }
775 }
776
777 double dummy_cost = 0.0;
778 FixedArray<double, 32> scratch(
779 residual_block->NumScratchDoublesForEvaluate());
780 return residual_block->Evaluate(apply_loss_function,
781 cost ? cost : &dummy_cost,
782 residuals,
783 jacobians,
784 scratch.data());
785}
786
Austin Schuh70cc9552019-01-21 19:46:48 -0800787int ProblemImpl::NumParameterBlocks() const {
788 return program_->NumParameterBlocks();
789}
790
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800791int ProblemImpl::NumParameters() const { return program_->NumParameters(); }
Austin Schuh70cc9552019-01-21 19:46:48 -0800792
793int ProblemImpl::NumResidualBlocks() const {
794 return program_->NumResidualBlocks();
795}
796
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800797int ProblemImpl::NumResiduals() const { return program_->NumResiduals(); }
Austin Schuh70cc9552019-01-21 19:46:48 -0800798
799int ProblemImpl::ParameterBlockSize(const double* values) const {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800800 ParameterBlock* parameter_block = FindWithDefault(
801 parameter_block_map_, const_cast<double*>(values), nullptr);
802 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800803 LOG(FATAL) << "Parameter block not found: " << values
804 << ". You must add the parameter block to the problem before "
805 << "you can get its size.";
806 }
807
808 return parameter_block->Size();
809}
810
Austin Schuh3de38b02024-06-25 18:25:10 -0700811int ProblemImpl::ParameterBlockTangentSize(const double* values) const {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800812 ParameterBlock* parameter_block = FindWithDefault(
813 parameter_block_map_, const_cast<double*>(values), nullptr);
814 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800815 LOG(FATAL) << "Parameter block not found: " << values
816 << ". You must add the parameter block to the problem before "
Austin Schuh3de38b02024-06-25 18:25:10 -0700817 << "you can get its tangent size.";
Austin Schuh70cc9552019-01-21 19:46:48 -0800818 }
819
Austin Schuh3de38b02024-06-25 18:25:10 -0700820 return parameter_block->TangentSize();
Austin Schuh70cc9552019-01-21 19:46:48 -0800821}
822
Austin Schuh3de38b02024-06-25 18:25:10 -0700823bool ProblemImpl::HasParameterBlock(const double* values) const {
824 return (parameter_block_map_.find(const_cast<double*>(values)) !=
Austin Schuh70cc9552019-01-21 19:46:48 -0800825 parameter_block_map_.end());
826}
827
Austin Schuh3de38b02024-06-25 18:25:10 -0700828void ProblemImpl::GetParameterBlocks(
829 std::vector<double*>* parameter_blocks) const {
Austin Schuh70cc9552019-01-21 19:46:48 -0800830 CHECK(parameter_blocks != nullptr);
831 parameter_blocks->resize(0);
832 parameter_blocks->reserve(parameter_block_map_.size());
833 for (const auto& entry : parameter_block_map_) {
834 parameter_blocks->push_back(entry.first);
835 }
836}
837
838void ProblemImpl::GetResidualBlocks(
Austin Schuh3de38b02024-06-25 18:25:10 -0700839 std::vector<ResidualBlockId>* residual_blocks) const {
Austin Schuh70cc9552019-01-21 19:46:48 -0800840 CHECK(residual_blocks != nullptr);
841 *residual_blocks = program().residual_blocks();
842}
843
844void ProblemImpl::GetParameterBlocksForResidualBlock(
845 const ResidualBlockId residual_block,
Austin Schuh3de38b02024-06-25 18:25:10 -0700846 std::vector<double*>* parameter_blocks) const {
Austin Schuh70cc9552019-01-21 19:46:48 -0800847 int num_parameter_blocks = residual_block->NumParameterBlocks();
848 CHECK(parameter_blocks != nullptr);
849 parameter_blocks->resize(num_parameter_blocks);
850 for (int i = 0; i < num_parameter_blocks; ++i) {
851 (*parameter_blocks)[i] =
852 residual_block->parameter_blocks()[i]->mutable_user_state();
853 }
854}
855
856const CostFunction* ProblemImpl::GetCostFunctionForResidualBlock(
857 const ResidualBlockId residual_block) const {
858 return residual_block->cost_function();
859}
860
861const LossFunction* ProblemImpl::GetLossFunctionForResidualBlock(
862 const ResidualBlockId residual_block) const {
863 return residual_block->loss_function();
864}
865
866void ProblemImpl::GetResidualBlocksForParameterBlock(
Austin Schuh3de38b02024-06-25 18:25:10 -0700867 const double* values, std::vector<ResidualBlockId>* residual_blocks) const {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800868 ParameterBlock* parameter_block = FindWithDefault(
869 parameter_block_map_, const_cast<double*>(values), nullptr);
870 if (parameter_block == nullptr) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800871 LOG(FATAL) << "Parameter block not found: " << values
872 << ". You must add the parameter block to the problem before "
873 << "you can get the residual blocks that depend on it.";
874 }
875
876 if (options_.enable_fast_removal) {
877 // In this case the residual blocks that depend on the parameter block are
878 // stored in the parameter block already, so just copy them out.
879 CHECK(residual_blocks != nullptr);
880 residual_blocks->resize(parameter_block->mutable_residual_blocks()->size());
881 std::copy(parameter_block->mutable_residual_blocks()->begin(),
882 parameter_block->mutable_residual_blocks()->end(),
883 residual_blocks->begin());
884 return;
885 }
886
887 // Find residual blocks that depend on the parameter block.
888 CHECK(residual_blocks != nullptr);
889 residual_blocks->clear();
890 const int num_residual_blocks = NumResidualBlocks();
891 for (int i = 0; i < num_residual_blocks; ++i) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800892 ResidualBlock* residual_block = (*(program_->mutable_residual_blocks()))[i];
Austin Schuh70cc9552019-01-21 19:46:48 -0800893 const int num_parameter_blocks = residual_block->NumParameterBlocks();
894 for (int j = 0; j < num_parameter_blocks; ++j) {
895 if (residual_block->parameter_blocks()[j] == parameter_block) {
896 residual_blocks->push_back(residual_block);
897 // The parameter blocks are guaranteed unique.
898 break;
899 }
900 }
901 }
902}
903
Austin Schuh3de38b02024-06-25 18:25:10 -0700904} // namespace ceres::internal