blob: 2eb6d627167d3bef63aefe77694fe36a1dffb7ec [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2015 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Authors: keir@google.com (Keir Mierle),
30// dgossow@google.com (David Gossow)
31
32#include "ceres/gradient_checking_cost_function.h"
33
34#include <algorithm>
35#include <cmath>
36#include <cstdint>
37#include <numeric>
38#include <string>
39#include <vector>
40
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080041#include "ceres/dynamic_numeric_diff_cost_function.h"
Austin Schuh70cc9552019-01-21 19:46:48 -080042#include "ceres/gradient_checker.h"
43#include "ceres/internal/eigen.h"
44#include "ceres/parameter_block.h"
45#include "ceres/problem.h"
46#include "ceres/problem_impl.h"
47#include "ceres/program.h"
48#include "ceres/residual_block.h"
Austin Schuh70cc9552019-01-21 19:46:48 -080049#include "ceres/stringprintf.h"
50#include "ceres/types.h"
51#include "glog/logging.h"
52
53namespace ceres {
54namespace internal {
55
56using std::abs;
57using std::max;
58using std::string;
59using std::vector;
60
61namespace {
62
63class GradientCheckingCostFunction : public CostFunction {
64 public:
65 GradientCheckingCostFunction(
66 const CostFunction* function,
67 const std::vector<const LocalParameterization*>* local_parameterizations,
68 const NumericDiffOptions& options,
69 double relative_precision,
70 const string& extra_info,
71 GradientCheckingIterationCallback* callback)
72 : function_(function),
73 gradient_checker_(function, local_parameterizations, options),
74 relative_precision_(relative_precision),
75 extra_info_(extra_info),
76 callback_(callback) {
77 CHECK(callback_ != nullptr);
78 const vector<int32_t>& parameter_block_sizes =
79 function->parameter_block_sizes();
80 *mutable_parameter_block_sizes() = parameter_block_sizes;
81 set_num_residuals(function->num_residuals());
82 }
83
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080084 virtual ~GradientCheckingCostFunction() {}
Austin Schuh70cc9552019-01-21 19:46:48 -080085
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080086 bool Evaluate(double const* const* parameters,
87 double* residuals,
88 double** jacobians) const final {
Austin Schuh70cc9552019-01-21 19:46:48 -080089 if (!jacobians) {
90 // Nothing to check in this case; just forward.
91 return function_->Evaluate(parameters, residuals, NULL);
92 }
93
94 GradientChecker::ProbeResults results;
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080095 bool okay =
96 gradient_checker_.Probe(parameters, relative_precision_, &results);
Austin Schuh70cc9552019-01-21 19:46:48 -080097
98 // If the cost function returned false, there's nothing we can say about
99 // the gradients.
100 if (results.return_value == false) {
101 return false;
102 }
103
104 // Copy the residuals.
105 const int num_residuals = function_->num_residuals();
106 MatrixRef(residuals, num_residuals, 1) = results.residuals;
107
108 // Copy the original jacobian blocks into the jacobians array.
109 const vector<int32_t>& block_sizes = function_->parameter_block_sizes();
110 for (int k = 0; k < block_sizes.size(); k++) {
111 if (jacobians[k] != NULL) {
112 MatrixRef(jacobians[k],
113 results.jacobians[k].rows(),
114 results.jacobians[k].cols()) = results.jacobians[k];
115 }
116 }
117
118 if (!okay) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800119 std::string error_log =
120 "Gradient Error detected!\nExtra info for this residual: " +
121 extra_info_ + "\n" + results.error_log;
Austin Schuh70cc9552019-01-21 19:46:48 -0800122 callback_->SetGradientErrorDetected(error_log);
123 }
124 return true;
125 }
126
127 private:
128 const CostFunction* function_;
129 GradientChecker gradient_checker_;
130 double relative_precision_;
131 string extra_info_;
132 GradientCheckingIterationCallback* callback_;
133};
134
135} // namespace
136
137GradientCheckingIterationCallback::GradientCheckingIterationCallback()
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800138 : gradient_error_detected_(false) {}
Austin Schuh70cc9552019-01-21 19:46:48 -0800139
140CallbackReturnType GradientCheckingIterationCallback::operator()(
141 const IterationSummary& summary) {
142 if (gradient_error_detected_) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800143 LOG(ERROR) << "Gradient error detected. Terminating solver.";
Austin Schuh70cc9552019-01-21 19:46:48 -0800144 return SOLVER_ABORT;
145 }
146 return SOLVER_CONTINUE;
147}
148void GradientCheckingIterationCallback::SetGradientErrorDetected(
149 std::string& error_log) {
150 std::lock_guard<std::mutex> l(mutex_);
151 gradient_error_detected_ = true;
152 error_log_ += "\n" + error_log;
153}
154
155CostFunction* CreateGradientCheckingCostFunction(
156 const CostFunction* cost_function,
157 const std::vector<const LocalParameterization*>* local_parameterizations,
158 double relative_step_size,
159 double relative_precision,
160 const std::string& extra_info,
161 GradientCheckingIterationCallback* callback) {
162 NumericDiffOptions numeric_diff_options;
163 numeric_diff_options.relative_step_size = relative_step_size;
164
165 return new GradientCheckingCostFunction(cost_function,
166 local_parameterizations,
167 numeric_diff_options,
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800168 relative_precision,
169 extra_info,
Austin Schuh70cc9552019-01-21 19:46:48 -0800170 callback);
171}
172
173ProblemImpl* CreateGradientCheckingProblemImpl(
174 ProblemImpl* problem_impl,
175 double relative_step_size,
176 double relative_precision,
177 GradientCheckingIterationCallback* callback) {
178 CHECK(callback != nullptr);
179 // We create new CostFunctions by wrapping the original CostFunction
180 // in a gradient checking CostFunction. So its okay for the
181 // ProblemImpl to take ownership of it and destroy it. The
182 // LossFunctions and LocalParameterizations are reused and since
183 // they are owned by problem_impl, gradient_checking_problem_impl
184 // should not take ownership of it.
185 Problem::Options gradient_checking_problem_options;
186 gradient_checking_problem_options.cost_function_ownership = TAKE_OWNERSHIP;
187 gradient_checking_problem_options.loss_function_ownership =
188 DO_NOT_TAKE_OWNERSHIP;
189 gradient_checking_problem_options.local_parameterization_ownership =
190 DO_NOT_TAKE_OWNERSHIP;
191 gradient_checking_problem_options.context = problem_impl->context();
192
193 NumericDiffOptions numeric_diff_options;
194 numeric_diff_options.relative_step_size = relative_step_size;
195
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800196 ProblemImpl* gradient_checking_problem_impl =
197 new ProblemImpl(gradient_checking_problem_options);
Austin Schuh70cc9552019-01-21 19:46:48 -0800198
199 Program* program = problem_impl->mutable_program();
200
201 // For every ParameterBlock in problem_impl, create a new parameter
202 // block with the same local parameterization and constancy.
203 const vector<ParameterBlock*>& parameter_blocks = program->parameter_blocks();
204 for (int i = 0; i < parameter_blocks.size(); ++i) {
205 ParameterBlock* parameter_block = parameter_blocks[i];
206 gradient_checking_problem_impl->AddParameterBlock(
207 parameter_block->mutable_user_state(),
208 parameter_block->Size(),
209 parameter_block->mutable_local_parameterization());
210
211 if (parameter_block->IsConstant()) {
212 gradient_checking_problem_impl->SetParameterBlockConstant(
213 parameter_block->mutable_user_state());
214 }
215
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800216 for (int i = 0; i < parameter_block->Size(); ++i) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800217 gradient_checking_problem_impl->SetParameterUpperBound(
218 parameter_block->mutable_user_state(),
219 i,
220 parameter_block->UpperBound(i));
221 gradient_checking_problem_impl->SetParameterLowerBound(
222 parameter_block->mutable_user_state(),
223 i,
224 parameter_block->LowerBound(i));
225 }
226 }
227
228 // For every ResidualBlock in problem_impl, create a new
229 // ResidualBlock by wrapping its CostFunction inside a
230 // GradientCheckingCostFunction.
231 const vector<ResidualBlock*>& residual_blocks = program->residual_blocks();
232 for (int i = 0; i < residual_blocks.size(); ++i) {
233 ResidualBlock* residual_block = residual_blocks[i];
234
235 // Build a human readable string which identifies the
236 // ResidualBlock. This is used by the GradientCheckingCostFunction
237 // when logging debugging information.
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800238 string extra_info =
239 StringPrintf("Residual block id %d; depends on parameters [", i);
Austin Schuh70cc9552019-01-21 19:46:48 -0800240 vector<double*> parameter_blocks;
241 vector<const LocalParameterization*> local_parameterizations;
242 parameter_blocks.reserve(residual_block->NumParameterBlocks());
243 local_parameterizations.reserve(residual_block->NumParameterBlocks());
244 for (int j = 0; j < residual_block->NumParameterBlocks(); ++j) {
245 ParameterBlock* parameter_block = residual_block->parameter_blocks()[j];
246 parameter_blocks.push_back(parameter_block->mutable_user_state());
247 StringAppendF(&extra_info, "%p", parameter_block->mutable_user_state());
248 extra_info += (j < residual_block->NumParameterBlocks() - 1) ? ", " : "]";
249 local_parameterizations.push_back(problem_impl->GetParameterization(
250 parameter_block->mutable_user_state()));
251 }
252
253 // Wrap the original CostFunction in a GradientCheckingCostFunction.
254 CostFunction* gradient_checking_cost_function =
255 new GradientCheckingCostFunction(residual_block->cost_function(),
256 &local_parameterizations,
257 numeric_diff_options,
258 relative_precision,
259 extra_info,
260 callback);
261
262 // The const_cast is necessary because
263 // ProblemImpl::AddResidualBlock can potentially take ownership of
264 // the LossFunction, but in this case we are guaranteed that this
265 // will not be the case, so this const_cast is harmless.
266 gradient_checking_problem_impl->AddResidualBlock(
267 gradient_checking_cost_function,
268 const_cast<LossFunction*>(residual_block->loss_function()),
269 parameter_blocks.data(),
270 static_cast<int>(parameter_blocks.size()));
271 }
272
273 // Normally, when a problem is given to the solver, we guarantee
274 // that the state pointers for each parameter block point to the
275 // user provided data. Since we are creating this new problem from a
276 // problem given to us at an arbitrary stage of the solve, we cannot
277 // depend on this being the case, so we explicitly call
278 // SetParameterBlockStatePtrsToUserStatePtrs to ensure that this is
279 // the case.
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800280 gradient_checking_problem_impl->mutable_program()
Austin Schuh70cc9552019-01-21 19:46:48 -0800281 ->SetParameterBlockStatePtrsToUserStatePtrs();
282
283 return gradient_checking_problem_impl;
284}
285
Austin Schuh70cc9552019-01-21 19:46:48 -0800286} // namespace internal
287} // namespace ceres