// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Authors: keir@google.com (Keir Mierle),
//          dgossow@google.com (David Gossow)

#include "ceres/gradient_checking_cost_function.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "ceres/dynamic_numeric_diff_cost_function.h"
#include "ceres/gradient_checker.h"
#include "ceres/internal/eigen.h"
#include "ceres/parameter_block.h"
#include "ceres/problem.h"
#include "ceres/problem_impl.h"
#include "ceres/program.h"
#include "ceres/residual_block.h"
#include "ceres/stringprintf.h"
#include "ceres/types.h"
#include "glog/logging.h"

namespace ceres::internal {

namespace {

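// Wraps an existing CostFunction and, whenever Evaluate() is asked for
// jacobians, uses GradientChecker to compare the wrapped function's jacobians
// against numerically differenced ones. The original residuals and jacobians
// are forwarded unchanged; any mismatch beyond relative_precision is reported
// to the GradientCheckingIterationCallback.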
class GradientCheckingCostFunction final : public CostFunction {
 public:
  GradientCheckingCostFunction(const CostFunction* function,
                               const std::vector<const Manifold*>* manifolds,
                               const NumericDiffOptions& options,
                               double relative_precision,
                               std::string extra_info,
                               GradientCheckingIterationCallback* callback)
      : function_(function),
        gradient_checker_(function, manifolds, options),
        relative_precision_(relative_precision),
        extra_info_(std::move(extra_info)),
        callback_(callback) {
    CHECK(callback_ != nullptr);
    const std::vector<int32_t>& parameter_block_sizes =
        function->parameter_block_sizes();
    *mutable_parameter_block_sizes() = parameter_block_sizes;
    set_num_residuals(function->num_residuals());
  }

  bool Evaluate(double const* const* parameters,
                double* residuals,
                double** jacobians) const final {
    if (!jacobians) {
      // Nothing to check in this case; just forward.
      return function_->Evaluate(parameters, residuals, nullptr);
    }

    GradientChecker::ProbeResults results;
    bool okay =
        gradient_checker_.Probe(parameters, relative_precision_, &results);

    // If the cost function returned false, there's nothing we can say about
    // the gradients.
    if (results.return_value == false) {
      return false;
    }

    // Copy the residuals.
    const int num_residuals = function_->num_residuals();
    MatrixRef(residuals, num_residuals, 1) = results.residuals;

    // Copy the original jacobian blocks into the jacobians array.
    const std::vector<int32_t>& block_sizes =
        function_->parameter_block_sizes();
    for (int k = 0; k < block_sizes.size(); k++) {
      if (jacobians[k] != nullptr) {
        MatrixRef(jacobians[k],
                  results.jacobians[k].rows(),
                  results.jacobians[k].cols()) = results.jacobians[k];
      }
    }

    if (!okay) {
      std::string error_log =
          "Gradient Error detected!\nExtra info for this residual: " +
          extra_info_ + "\n" + results.error_log;
      callback_->SetGradientErrorDetected(error_log);
    }
    return true;
  }

 private:
  const CostFunction* function_;
  GradientChecker gradient_checker_;
  double relative_precision_;
  std::string extra_info_;
  GradientCheckingIterationCallback* callback_;
};

}  // namespace

GradientCheckingIterationCallback::GradientCheckingIterationCallback()
    : gradient_error_detected_(false) {}

CallbackReturnType GradientCheckingIterationCallback::operator()(
    const IterationSummary& /*summary*/) {
  if (gradient_error_detected_) {
    LOG(ERROR) << "Gradient error detected. Terminating solver.";
    return SOLVER_ABORT;
  }
  return SOLVER_CONTINUE;
}

void GradientCheckingIterationCallback::SetGradientErrorDetected(
    std::string& error_log) {
  std::lock_guard<std::mutex> l(mutex_);
  gradient_error_detected_ = true;
  error_log_ += "\n" + error_log;
}

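// A minimal usage sketch for CreateGradientCheckingCostFunction below. This is
// hypothetical caller code: MyCostFunction (1 residual, one 2-dimensional
// Euclidean parameter block) and the step size / precision values are
// illustrative, not part of this file.
//
//   MyCostFunction my_cost_function;
//   GradientCheckingIterationCallback callback;
//   std::vector<const Manifold*> manifolds = {nullptr};  // No manifold, i.e.
//                                                        // Euclidean block.
//   std::unique_ptr<CostFunction> checked =
//       CreateGradientCheckingCostFunction(&my_cost_function,
//                                          &manifolds,
//                                          /*relative_step_size=*/1e-6,
//                                          /*relative_precision=*/1e-8,
//                                          "my residual",
//                                          &callback);
//
//   double x[2] = {1.0, 2.0};
//   double* parameters[] = {x};
//   double residuals[1];
//   double jacobian[2];  // 1 residual x 2 parameters, row-major.
//   double* jacobians[] = {jacobian};
//   // Evaluates my_cost_function and probes its jacobian; a mismatch is
//   // recorded in the callback, which aborts the solve the next time it is
//   // invoked as an iteration callback.
//   checked->Evaluate(parameters, residuals, jacobians);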
std::unique_ptr<CostFunction> CreateGradientCheckingCostFunction(
    const CostFunction* cost_function,
    const std::vector<const Manifold*>* manifolds,
    double relative_step_size,
    double relative_precision,
    const std::string& extra_info,
    GradientCheckingIterationCallback* callback) {
  NumericDiffOptions numeric_diff_options;
  numeric_diff_options.relative_step_size = relative_step_size;

  return std::make_unique<GradientCheckingCostFunction>(cost_function,
                                                        manifolds,
                                                        numeric_diff_options,
                                                        relative_precision,
                                                        extra_info,
                                                        callback);
}

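// Sketch of how the pieces in this file are typically combined, mirroring what
// the solver does when gradient checking is requested. This is hypothetical
// caller code and assumes Solver::Options fields named
// gradient_check_relative_step_size, gradient_check_relative_precision and
// callbacks, plus an existing `options` and `problem_impl`:
//
//   GradientCheckingIterationCallback callback;
//   std::unique_ptr<ProblemImpl> checked_problem =
//       CreateGradientCheckingProblemImpl(
//           problem_impl,
//           options.gradient_check_relative_step_size,
//           options.gradient_check_relative_precision,
//           &callback);
//   options.callbacks.push_back(&callback);
//   // ... run the minimizer on checked_problem instead of problem_impl so
//   // that the first detected jacobian error aborts the solve ...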
std::unique_ptr<ProblemImpl> CreateGradientCheckingProblemImpl(
    ProblemImpl* problem_impl,
    double relative_step_size,
    double relative_precision,
    GradientCheckingIterationCallback* callback) {
  CHECK(callback != nullptr);
  // We create new CostFunctions by wrapping the original CostFunction in a
  // gradient checking CostFunction, so it is okay for the ProblemImpl to take
  // ownership of them and destroy them. The LossFunctions and Manifolds are
  // reused, and since they are owned by problem_impl,
  // gradient_checking_problem_impl should not take ownership of them.
  Problem::Options gradient_checking_problem_options;
  gradient_checking_problem_options.cost_function_ownership = TAKE_OWNERSHIP;
  gradient_checking_problem_options.loss_function_ownership =
      DO_NOT_TAKE_OWNERSHIP;
  gradient_checking_problem_options.manifold_ownership = DO_NOT_TAKE_OWNERSHIP;
  gradient_checking_problem_options.context = problem_impl->context();

  NumericDiffOptions numeric_diff_options;
  numeric_diff_options.relative_step_size = relative_step_size;

  auto gradient_checking_problem_impl =
      std::make_unique<ProblemImpl>(gradient_checking_problem_options);

  Program* program = problem_impl->mutable_program();

  // For every ParameterBlock in problem_impl, create a new parameter block
  // with the same manifold and constancy.
  const std::vector<ParameterBlock*>& parameter_blocks =
      program->parameter_blocks();
  for (auto* parameter_block : parameter_blocks) {
    gradient_checking_problem_impl->AddParameterBlock(
        parameter_block->mutable_user_state(),
        parameter_block->Size(),
        parameter_block->mutable_manifold());

    if (parameter_block->IsConstant()) {
      gradient_checking_problem_impl->SetParameterBlockConstant(
          parameter_block->mutable_user_state());
    }

    for (int i = 0; i < parameter_block->Size(); ++i) {
      gradient_checking_problem_impl->SetParameterUpperBound(
          parameter_block->mutable_user_state(),
          i,
          parameter_block->UpperBound(i));
      gradient_checking_problem_impl->SetParameterLowerBound(
          parameter_block->mutable_user_state(),
          i,
          parameter_block->LowerBound(i));
    }
  }

  // For every ResidualBlock in problem_impl, create a new
  // ResidualBlock by wrapping its CostFunction inside a
  // GradientCheckingCostFunction.
  const std::vector<ResidualBlock*>& residual_blocks =
      program->residual_blocks();
  for (int i = 0; i < residual_blocks.size(); ++i) {
    ResidualBlock* residual_block = residual_blocks[i];

    // Build a human-readable string which identifies the
    // ResidualBlock. This is used by the GradientCheckingCostFunction
    // when logging debugging information.
    std::string extra_info =
        StringPrintf("Residual block id %d; depends on parameters [", i);
    std::vector<double*> parameter_blocks;
    std::vector<const Manifold*> manifolds;
    parameter_blocks.reserve(residual_block->NumParameterBlocks());
    manifolds.reserve(residual_block->NumParameterBlocks());
    for (int j = 0; j < residual_block->NumParameterBlocks(); ++j) {
      ParameterBlock* parameter_block = residual_block->parameter_blocks()[j];
      parameter_blocks.push_back(parameter_block->mutable_user_state());
      StringAppendF(&extra_info, "%p", parameter_block->mutable_user_state());
      extra_info += (j < residual_block->NumParameterBlocks() - 1) ? ", " : "]";
      manifolds.push_back(
          problem_impl->GetManifold(parameter_block->mutable_user_state()));
    }

    // Wrap the original CostFunction in a GradientCheckingCostFunction.
    CostFunction* gradient_checking_cost_function =
        new GradientCheckingCostFunction(residual_block->cost_function(),
                                         &manifolds,
                                         numeric_diff_options,
                                         relative_precision,
                                         extra_info,
                                         callback);

    // The const_cast is necessary because ProblemImpl::AddResidualBlock can
    // potentially take ownership of the LossFunction. Here the problem options
    // guarantee that it will not, so the const_cast is harmless.
260 gradient_checking_problem_impl->AddResidualBlock(
261 gradient_checking_cost_function,
262 const_cast<LossFunction*>(residual_block->loss_function()),
263 parameter_blocks.data(),
264 static_cast<int>(parameter_blocks.size()));
265 }
266
  // Normally, when a problem is given to the solver, we guarantee that the
  // state pointers for each parameter block point to the user-provided data.
  // Since we are creating this new problem from a problem given to us at an
  // arbitrary stage of the solve, we cannot depend on this being the case, so
  // we explicitly call SetParameterBlockStatePtrsToUserStatePtrs to ensure
  // that it is.
  gradient_checking_problem_impl->mutable_program()
      ->SetParameterBlockStatePtrsToUserStatePtrs();

  return gradient_checking_problem_impl;
}

}  // namespace ceres::internal