// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2015 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal@google.com (Sameer Agarwal)
//
// Generic loop for line search based optimization algorithms.
//
// This is primarily inspired by the minFunc package written by Mark
// Schmidt.
//
// http://www.di.ens.fr/~mschmidt/Software/minFunc.html
//
// For details on the theory and implementation see "Numerical
// Optimization" by Nocedal & Wright.

41#include "ceres/line_search_minimizer.h"
42
43#include <algorithm>
44#include <cstdlib>
45#include <cmath>
46#include <memory>
47#include <string>
48#include <vector>
49
50#include "Eigen/Dense"
51#include "ceres/array_utils.h"
52#include "ceres/evaluator.h"
53#include "ceres/internal/eigen.h"
54#include "ceres/internal/port.h"
55#include "ceres/line_search.h"
56#include "ceres/line_search_direction.h"
57#include "ceres/stringprintf.h"
58#include "ceres/types.h"
59#include "ceres/wall_time.h"
60#include "glog/logging.h"
61
62namespace ceres {
63namespace internal {
64namespace {
65
66bool EvaluateGradientNorms(Evaluator* evaluator,
67 const Vector& x,
68 LineSearchMinimizer::State* state,
69 std::string* message) {
70 Vector negative_gradient = -state->gradient;
71 Vector projected_gradient_step(x.size());
72 if (!evaluator->Plus(
73 x.data(), negative_gradient.data(), projected_gradient_step.data())) {
74 *message = "projected_gradient_step = Plus(x, -gradient) failed.";
75 return false;
76 }
77
78 state->gradient_squared_norm = (x - projected_gradient_step).squaredNorm();
79 state->gradient_max_norm =
80 (x - projected_gradient_step).lpNorm<Eigen::Infinity>();
81 return true;
82}
83
84} // namespace
85
86void LineSearchMinimizer::Minimize(const Minimizer::Options& options,
87 double* parameters,
88 Solver::Summary* summary) {
89 const bool is_not_silent = !options.is_silent;
90 double start_time = WallTimeInSeconds();
91 double iteration_start_time = start_time;
92
93 CHECK(options.evaluator != nullptr);
94 Evaluator* evaluator = options.evaluator.get();
95 const int num_parameters = evaluator->NumParameters();
96 const int num_effective_parameters = evaluator->NumEffectiveParameters();
97
98 summary->termination_type = NO_CONVERGENCE;
99 summary->num_successful_steps = 0;
100 summary->num_unsuccessful_steps = 0;
101
102 VectorRef x(parameters, num_parameters);
103
104 State current_state(num_parameters, num_effective_parameters);
105 State previous_state(num_parameters, num_effective_parameters);
106
107 IterationSummary iteration_summary;
108 iteration_summary.iteration = 0;
109 iteration_summary.step_is_valid = false;
110 iteration_summary.step_is_successful = false;
111 iteration_summary.cost_change = 0.0;
112 iteration_summary.gradient_max_norm = 0.0;
113 iteration_summary.gradient_norm = 0.0;
114 iteration_summary.step_norm = 0.0;
115 iteration_summary.linear_solver_iterations = 0;
116 iteration_summary.step_solver_time_in_seconds = 0;
117
118 // Do initial cost and gradient evaluation.
119 if (!evaluator->Evaluate(x.data(),
120 &(current_state.cost),
121 NULL,
122 current_state.gradient.data(),
123 NULL)) {
124 summary->termination_type = FAILURE;
125 summary->message = "Initial cost and jacobian evaluation failed.";
126 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
127 return;
128 }
129
130 if (!EvaluateGradientNorms(evaluator, x, &current_state, &summary->message)) {
131 summary->termination_type = FAILURE;
132 summary->message = "Initial cost and jacobian evaluation failed. "
133 "More details: " + summary->message;
134 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
135 return;
136 }
137
138 summary->initial_cost = current_state.cost + summary->fixed_cost;
139 iteration_summary.cost = current_state.cost + summary->fixed_cost;
140
141 iteration_summary.gradient_norm = sqrt(current_state.gradient_squared_norm);
142 iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
143 if (iteration_summary.gradient_max_norm <= options.gradient_tolerance) {
144 summary->message = StringPrintf("Gradient tolerance reached. "
145 "Gradient max norm: %e <= %e",
146 iteration_summary.gradient_max_norm,
147 options.gradient_tolerance);
148 summary->termination_type = CONVERGENCE;
149 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
150 return;
151 }
152
153 iteration_summary.iteration_time_in_seconds =
154 WallTimeInSeconds() - iteration_start_time;
155 iteration_summary.cumulative_time_in_seconds =
156 WallTimeInSeconds() - start_time
157 + summary->preprocessor_time_in_seconds;
158 summary->iterations.push_back(iteration_summary);
159
160 LineSearchDirection::Options line_search_direction_options;
161 line_search_direction_options.num_parameters = num_effective_parameters;
162 line_search_direction_options.type = options.line_search_direction_type;
163 line_search_direction_options.nonlinear_conjugate_gradient_type =
164 options.nonlinear_conjugate_gradient_type;
165 line_search_direction_options.max_lbfgs_rank = options.max_lbfgs_rank;
166 line_search_direction_options.use_approximate_eigenvalue_bfgs_scaling =
167 options.use_approximate_eigenvalue_bfgs_scaling;
168 std::unique_ptr<LineSearchDirection> line_search_direction(
169 LineSearchDirection::Create(line_search_direction_options));
170
171 LineSearchFunction line_search_function(evaluator);
172
173 LineSearch::Options line_search_options;
174 line_search_options.interpolation_type =
175 options.line_search_interpolation_type;
176 line_search_options.min_step_size = options.min_line_search_step_size;
177 line_search_options.sufficient_decrease =
178 options.line_search_sufficient_function_decrease;
179 line_search_options.max_step_contraction =
180 options.max_line_search_step_contraction;
181 line_search_options.min_step_contraction =
182 options.min_line_search_step_contraction;
183 line_search_options.max_num_iterations =
184 options.max_num_line_search_step_size_iterations;
185 line_search_options.sufficient_curvature_decrease =
186 options.line_search_sufficient_curvature_decrease;
187 line_search_options.max_step_expansion =
188 options.max_line_search_step_expansion;
189 line_search_options.is_silent = options.is_silent;
190 line_search_options.function = &line_search_function;
191
192 std::unique_ptr<LineSearch>
193 line_search(LineSearch::Create(options.line_search_type,
194 line_search_options,
195 &summary->message));
196 if (line_search.get() == NULL) {
197 summary->termination_type = FAILURE;
198 LOG_IF(ERROR, is_not_silent) << "Terminating: " << summary->message;
199 return;
200 }
201
202 LineSearch::Summary line_search_summary;
203 int num_line_search_direction_restarts = 0;
204
205 while (true) {
206 if (!RunCallbacks(options, iteration_summary, summary)) {
207 break;
208 }
209
210 iteration_start_time = WallTimeInSeconds();
211 if (iteration_summary.iteration >= options.max_num_iterations) {
212 summary->message = "Maximum number of iterations reached.";
213 summary->termination_type = NO_CONVERGENCE;
214 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
215 break;
216 }
217
218 const double total_solver_time = iteration_start_time - start_time +
219 summary->preprocessor_time_in_seconds;
220 if (total_solver_time >= options.max_solver_time_in_seconds) {
221 summary->message = "Maximum solver time reached.";
222 summary->termination_type = NO_CONVERGENCE;
223 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
224 break;
225 }
226
227 iteration_summary = IterationSummary();
228 iteration_summary.iteration = summary->iterations.back().iteration + 1;
229 iteration_summary.step_is_valid = false;
230 iteration_summary.step_is_successful = false;
231
232 bool line_search_status = true;
233 if (iteration_summary.iteration == 1) {
234 current_state.search_direction = -current_state.gradient;
235 } else {
236 line_search_status = line_search_direction->NextDirection(
237 previous_state,
238 current_state,
239 &current_state.search_direction);
240 }
241
242 if (!line_search_status &&
243 num_line_search_direction_restarts >=
244 options.max_num_line_search_direction_restarts) {
245 // Line search direction failed to generate a new direction, and we
246 // have already reached our specified maximum number of restarts,
247 // terminate optimization.
248 summary->message =
249 StringPrintf("Line search direction failure: specified "
250 "max_num_line_search_direction_restarts: %d reached.",
251 options.max_num_line_search_direction_restarts);
252 summary->termination_type = FAILURE;
253 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
254 break;
255 } else if (!line_search_status) {
256 // Restart line search direction with gradient descent on first iteration
257 // as we have not yet reached our maximum number of restarts.
258 CHECK_LT(num_line_search_direction_restarts,
259 options.max_num_line_search_direction_restarts);
260
261 ++num_line_search_direction_restarts;
262 LOG_IF(WARNING, is_not_silent)
263 << "Line search direction algorithm: "
264 << LineSearchDirectionTypeToString(
265 options.line_search_direction_type)
266 << ", failed to produce a valid new direction at "
267 << "iteration: " << iteration_summary.iteration
268 << ". Restarting, number of restarts: "
269 << num_line_search_direction_restarts << " / "
270 << options.max_num_line_search_direction_restarts
271 << " [max].";
272 line_search_direction.reset(
273 LineSearchDirection::Create(line_search_direction_options));
274 current_state.search_direction = -current_state.gradient;
275 }
276
277 line_search_function.Init(x, current_state.search_direction);
278 current_state.directional_derivative =
279 current_state.gradient.dot(current_state.search_direction);
280
281 // TODO(sameeragarwal): Refactor this into its own object and add
282 // explanations for the various choices.
283 //
284 // Note that we use !line_search_status to ensure that we treat cases when
285 // we restarted the line search direction equivalently to the first
286 // iteration.
287 const double initial_step_size =
288 (iteration_summary.iteration == 1 || !line_search_status)
289 ? std::min(1.0, 1.0 / current_state.gradient_max_norm)
290 : std::min(1.0, 2.0 * (current_state.cost - previous_state.cost) /
291 current_state.directional_derivative);
292 // By definition, we should only ever go forwards along the specified search
293 // direction in a line search, most likely cause for this being violated
294 // would be a numerical failure in the line search direction calculation.
295 if (initial_step_size < 0.0) {
296 summary->message =
297 StringPrintf("Numerical failure in line search, initial_step_size is "
298 "negative: %.5e, directional_derivative: %.5e, "
299 "(current_cost - previous_cost): %.5e",
300 initial_step_size, current_state.directional_derivative,
301 (current_state.cost - previous_state.cost));
302 summary->termination_type = FAILURE;
303 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
304 break;
305 }
306
307 line_search->Search(initial_step_size,
308 current_state.cost,
309 current_state.directional_derivative,
310 &line_search_summary);
311 if (!line_search_summary.success) {
312 summary->message =
313 StringPrintf("Numerical failure in line search, failed to find "
314 "a valid step size, (did not run out of iterations) "
315 "using initial_step_size: %.5e, initial_cost: %.5e, "
316 "initial_gradient: %.5e.",
317 initial_step_size, current_state.cost,
318 current_state.directional_derivative);
319 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
320 summary->termination_type = FAILURE;
321 break;
322 }
323
324 const FunctionSample& optimal_point = line_search_summary.optimal_point;
325 CHECK(optimal_point.vector_x_is_valid)
326 << "Congratulations, you found a bug in Ceres. Please report it.";
327 current_state.step_size = optimal_point.x;
328 previous_state = current_state;
329 iteration_summary.step_solver_time_in_seconds =
330 WallTimeInSeconds() - iteration_start_time;
331
332 if (optimal_point.vector_gradient_is_valid) {
333 current_state.cost = optimal_point.value;
334 current_state.gradient = optimal_point.vector_gradient;
335 } else {
336 Evaluator::EvaluateOptions evaluate_options;
337 evaluate_options.new_evaluation_point = false;
338 if (!evaluator->Evaluate(evaluate_options,
339 optimal_point.vector_x.data(),
340 &(current_state.cost),
341 NULL,
342 current_state.gradient.data(),
343 NULL)) {
344 summary->termination_type = FAILURE;
345 summary->message = "Cost and jacobian evaluation failed.";
346 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
347 return;
348 }
349 }
350
351 if (!EvaluateGradientNorms(evaluator,
352 optimal_point.vector_x,
353 &current_state,
354 &summary->message)) {
355 summary->termination_type = FAILURE;
356 summary->message =
357 "Step failed to evaluate. This should not happen as the step was "
358 "valid when it was selected by the line search. More details: " +
359 summary->message;
360 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
361 break;
362 }
363
364 // Compute the norm of the step in the ambient space.
365 iteration_summary.step_norm = (optimal_point.vector_x - x).norm();
366 const double x_norm = x.norm();
367 x = optimal_point.vector_x;
368
369 iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
370 iteration_summary.gradient_norm = sqrt(current_state.gradient_squared_norm);
371 iteration_summary.cost_change = previous_state.cost - current_state.cost;
372 iteration_summary.cost = current_state.cost + summary->fixed_cost;
373
374 iteration_summary.step_is_valid = true;
375 iteration_summary.step_is_successful = true;
376 iteration_summary.step_size = current_state.step_size;
377 iteration_summary.line_search_function_evaluations =
378 line_search_summary.num_function_evaluations;
379 iteration_summary.line_search_gradient_evaluations =
380 line_search_summary.num_gradient_evaluations;
381 iteration_summary.line_search_iterations =
382 line_search_summary.num_iterations;
383 iteration_summary.iteration_time_in_seconds =
384 WallTimeInSeconds() - iteration_start_time;
385 iteration_summary.cumulative_time_in_seconds =
386 WallTimeInSeconds() - start_time
387 + summary->preprocessor_time_in_seconds;
388 summary->iterations.push_back(iteration_summary);
389
390 // Iterations inside the line search algorithm are considered
391 // 'steps' in the broader context, to distinguish these inner
392 // iterations from from the outer iterations of the line search
393 // minimizer. The number of line search steps is the total number
394 // of inner line search iterations (or steps) across the entire
395 // minimization.
396 summary->num_line_search_steps += line_search_summary.num_iterations;
397 summary->line_search_cost_evaluation_time_in_seconds +=
398 line_search_summary.cost_evaluation_time_in_seconds;
399 summary->line_search_gradient_evaluation_time_in_seconds +=
400 line_search_summary.gradient_evaluation_time_in_seconds;
401 summary->line_search_polynomial_minimization_time_in_seconds +=
402 line_search_summary.polynomial_minimization_time_in_seconds;
403 summary->line_search_total_time_in_seconds +=
404 line_search_summary.total_time_in_seconds;
405 ++summary->num_successful_steps;
406
407 const double step_size_tolerance = options.parameter_tolerance *
408 (x_norm + options.parameter_tolerance);
409 if (iteration_summary.step_norm <= step_size_tolerance) {
410 summary->message =
411 StringPrintf("Parameter tolerance reached. "
412 "Relative step_norm: %e <= %e.",
413 (iteration_summary.step_norm /
414 (x_norm + options.parameter_tolerance)),
415 options.parameter_tolerance);
416 summary->termination_type = CONVERGENCE;
417 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
418 return;
419 }
420
421 if (iteration_summary.gradient_max_norm <= options.gradient_tolerance) {
422 summary->message = StringPrintf("Gradient tolerance reached. "
423 "Gradient max norm: %e <= %e",
424 iteration_summary.gradient_max_norm,
425 options.gradient_tolerance);
426 summary->termination_type = CONVERGENCE;
427 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
428 break;
429 }
430
431 const double absolute_function_tolerance =
432 options.function_tolerance * previous_state.cost;
433 if (fabs(iteration_summary.cost_change) <= absolute_function_tolerance) {
434 summary->message =
435 StringPrintf("Function tolerance reached. "
436 "|cost_change|/cost: %e <= %e",
437 fabs(iteration_summary.cost_change) /
438 previous_state.cost,
439 options.function_tolerance);
440 summary->termination_type = CONVERGENCE;
441 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
442 break;
443 }
444 }
445}
446
447} // namespace internal
448} // namespace ceres