blob: 977b69d5319ea102177a5693ebc3774ab60b3eb2 [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2017 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30//
31// The National Institute of Standards and Technology has released a
32// set of problems to test non-linear least squares solvers.
33//
34// More information about the background on these problems and
35// suggested evaluation methodology can be found at:
36//
37// http://www.itl.nist.gov/div898/strd/nls/nls_info.shtml
38//
39// The problem data themselves can be found at
40//
41// http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml
42//
43// The problems are divided into three levels of difficulty, Easy,
44// Medium and Hard. For each problem there are two starting guesses,
45// the first one far away from the global minimum and the second
46// closer to it.
47//
// A problem is considered successfully solved if every component of
// the solution matches the globally optimal solution in at least 4
// digits.
51//
52// This dataset was used for an evaluation of Non-linear least squares
53// solvers:
54//
55// P. F. Mondragon & B. Borchers, A Comparison of Nonlinear Regression
56// Codes, Journal of Modern Applied Statistical Methods, 4(1):343-351,
57// 2005.
58//
59// The results from Mondragon & Borchers can be summarized as
60// Excel Gnuplot GaussFit HBN MinPack
61// Average LRE 2.3 4.3 4.0 6.8 4.4
62// Winner 1 5 12 29 12
63//
// where the Winner row counts the number of problems for which each
// solver had the highest LRE.
66
67// In this file, we implement the same evaluation methodology using
68// Ceres. Currently using Levenberg-Marquardt with DENSE_QR, we get
69//
70// Excel Gnuplot GaussFit HBN MinPack Ceres
71// Average LRE 2.3 4.3 4.0 6.8 4.4 9.4
72// Winner 0 0 5 11 2 41
73
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <string>
#include <vector>
77
78#include "Eigen/Core"
79#include "ceres/ceres.h"
80#include "ceres/tiny_solver.h"
81#include "ceres/tiny_solver_cost_function_adapter.h"
82#include "gflags/gflags.h"
83#include "glog/logging.h"
84
85DEFINE_bool(use_tiny_solver, false, "Use TinySolver instead of Ceres::Solver");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080086DEFINE_string(nist_data_dir,
87 "",
88 "Directory containing the NIST non-linear regression examples");
89DEFINE_string(minimizer,
90 "trust_region",
Austin Schuh70cc9552019-01-21 19:46:48 -080091 "Minimizer type to use, choices are: line_search & trust_region");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080092DEFINE_string(trust_region_strategy,
93 "levenberg_marquardt",
Austin Schuh70cc9552019-01-21 19:46:48 -080094 "Options are: levenberg_marquardt, dogleg");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080095DEFINE_string(dogleg,
96 "traditional_dogleg",
Austin Schuh70cc9552019-01-21 19:46:48 -080097 "Options are: traditional_dogleg, subspace_dogleg");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080098DEFINE_string(linear_solver,
99 "dense_qr",
100 "Options are: sparse_cholesky, dense_qr, dense_normal_cholesky "
101 "and cgnr");
102DEFINE_string(preconditioner, "jacobi", "Options are: identity, jacobi");
103DEFINE_string(line_search,
104 "wolfe",
Austin Schuh70cc9552019-01-21 19:46:48 -0800105 "Line search algorithm to use, choices are: armijo and wolfe.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800106DEFINE_string(line_search_direction,
107 "lbfgs",
Austin Schuh70cc9552019-01-21 19:46:48 -0800108 "Line search direction algorithm to use, choices: lbfgs, bfgs");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800109DEFINE_int32(max_line_search_iterations,
110 20,
Austin Schuh70cc9552019-01-21 19:46:48 -0800111 "Maximum number of iterations for each line search.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800112DEFINE_int32(max_line_search_restarts,
113 10,
Austin Schuh70cc9552019-01-21 19:46:48 -0800114 "Maximum number of restarts of line search direction algorithm.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800115DEFINE_string(line_search_interpolation,
116 "cubic",
117 "Degree of polynomial aproximation in line search, choices are: "
118 "bisection, quadratic & cubic.");
119DEFINE_int32(lbfgs_rank,
120 20,
Austin Schuh70cc9552019-01-21 19:46:48 -0800121 "Rank of L-BFGS inverse Hessian approximation in line search.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800122DEFINE_bool(approximate_eigenvalue_bfgs_scaling,
123 false,
Austin Schuh70cc9552019-01-21 19:46:48 -0800124 "Use approximate eigenvalue scaling in (L)BFGS line search.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800125DEFINE_double(sufficient_decrease,
126 1.0e-4,
Austin Schuh70cc9552019-01-21 19:46:48 -0800127 "Line search Armijo sufficient (function) decrease factor.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800128DEFINE_double(sufficient_curvature_decrease,
129 0.9,
Austin Schuh70cc9552019-01-21 19:46:48 -0800130 "Line search Wolfe sufficient curvature decrease factor.");
131DEFINE_int32(num_iterations, 10000, "Number of iterations");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800132DEFINE_bool(nonmonotonic_steps,
133 false,
134 "Trust region algorithm can use nonmonotic steps");
Austin Schuh70cc9552019-01-21 19:46:48 -0800135DEFINE_double(initial_trust_region_radius, 1e4, "Initial trust region radius");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800136DEFINE_bool(use_numeric_diff,
137 false,
Austin Schuh70cc9552019-01-21 19:46:48 -0800138 "Use numeric differentiation instead of automatic "
139 "differentiation.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800140DEFINE_string(numeric_diff_method,
141 "ridders",
142 "When using numeric differentiation, selects algorithm. Options "
143 "are: central, forward, ridders.");
144DEFINE_double(ridders_step_size,
145 1e-9,
146 "Initial step size for Ridders numeric differentiation.");
147DEFINE_int32(ridders_extrapolations,
148 3,
149 "Maximal number of Ridders extrapolations.");
Austin Schuh70cc9552019-01-21 19:46:48 -0800150
151namespace ceres {
152namespace examples {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800153namespace {
Austin Schuh70cc9552019-01-21 19:46:48 -0800154
155using Eigen::Dynamic;
156using Eigen::RowMajor;
157typedef Eigen::Matrix<double, Dynamic, 1> Vector;
158typedef Eigen::Matrix<double, Dynamic, Dynamic, RowMajor> Matrix;
159
160using std::atof;
161using std::atoi;
162using std::cout;
163using std::ifstream;
164using std::string;
165using std::vector;
166
// Splits `full` on every occurrence of `delim` and appends the resulting
// fields to `*result` (existing contents are kept). Empty fields, produced
// by leading, trailing or repeated delimiters, are dropped.
void SplitStringUsingChar(const std::string& full,
                          const char delim,
                          std::vector<std::string>* result) {
  const std::string::size_type length = full.size();
  std::string::size_type field_begin = 0;
  while (field_begin < length) {
    const std::string::size_type hit = full.find(delim, field_begin);
    const std::string::size_type field_end =
        (hit == std::string::npos) ? length : hit;
    if (field_end > field_begin) {
      result->push_back(full.substr(field_begin, field_end - field_begin));
    }
    // Step past the delimiter (or past the end when none was found).
    field_begin = field_end + 1;
  }
}
186
187bool GetAndSplitLine(ifstream& ifs, vector<string>* pieces) {
188 pieces->clear();
189 char buf[256];
190 ifs.getline(buf, 256);
191 SplitStringUsingChar(string(buf), ' ', pieces);
192 return true;
193}
194
// Discards the next `num_lines` lines of `ifs` (nothing if num_lines <= 0).
//
// Lines of any length are skipped completely, unlike a fixed-size getline
// buffer, which fails on long lines and stops skipping. Hitting end of input
// leaves the stream in its eof/fail state for the caller's next read.
// Accepts any std::istream (an ifstream still binds to it).
void SkipLines(std::istream& ifs, int num_lines) {
  for (int i = 0; i < num_lines; ++i) {
    ifs.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
  }
}
201
202class NISTProblem {
203 public:
204 explicit NISTProblem(const string& filename) {
205 ifstream ifs(filename.c_str(), ifstream::in);
206 CHECK(ifs) << "Unable to open : " << filename;
207
208 vector<string> pieces;
209 SkipLines(ifs, 24);
210 GetAndSplitLine(ifs, &pieces);
211 const int kNumResponses = atoi(pieces[1].c_str());
212
213 GetAndSplitLine(ifs, &pieces);
214 const int kNumPredictors = atoi(pieces[0].c_str());
215
216 GetAndSplitLine(ifs, &pieces);
217 const int kNumObservations = atoi(pieces[0].c_str());
218
219 SkipLines(ifs, 4);
220 GetAndSplitLine(ifs, &pieces);
221 const int kNumParameters = atoi(pieces[0].c_str());
222 SkipLines(ifs, 8);
223
224 // Get the first line of initial and final parameter values to
225 // determine the number of tries.
226 GetAndSplitLine(ifs, &pieces);
227 const int kNumTries = pieces.size() - 4;
228
229 predictor_.resize(kNumObservations, kNumPredictors);
230 response_.resize(kNumObservations, kNumResponses);
231 initial_parameters_.resize(kNumTries, kNumParameters);
232 final_parameters_.resize(1, kNumParameters);
233
234 // Parse the line for parameter b1.
235 int parameter_id = 0;
236 for (int i = 0; i < kNumTries; ++i) {
237 initial_parameters_(i, parameter_id) = atof(pieces[i + 2].c_str());
238 }
239 final_parameters_(0, parameter_id) = atof(pieces[2 + kNumTries].c_str());
240
241 // Parse the remaining parameter lines.
242 for (int parameter_id = 1; parameter_id < kNumParameters; ++parameter_id) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800243 GetAndSplitLine(ifs, &pieces);
244 // b2, b3, ....
245 for (int i = 0; i < kNumTries; ++i) {
246 initial_parameters_(i, parameter_id) = atof(pieces[i + 2].c_str());
247 }
248 final_parameters_(0, parameter_id) = atof(pieces[2 + kNumTries].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800249 }
250
251 // Certfied cost
252 SkipLines(ifs, 1);
253 GetAndSplitLine(ifs, &pieces);
254 certified_cost_ = atof(pieces[4].c_str()) / 2.0;
255
256 // Read the observations.
257 SkipLines(ifs, 18 - kNumParameters);
258 for (int i = 0; i < kNumObservations; ++i) {
259 GetAndSplitLine(ifs, &pieces);
260 // Response.
261 for (int j = 0; j < kNumResponses; ++j) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800262 response_(i, j) = atof(pieces[j].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800263 }
264
265 // Predictor variables.
266 for (int j = 0; j < kNumPredictors; ++j) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800267 predictor_(i, j) = atof(pieces[j + kNumResponses].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800268 }
269 }
270 }
271
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800272 Matrix initial_parameters(int start) const {
273 return initial_parameters_.row(start);
274 } // NOLINT
275 Matrix final_parameters() const { return final_parameters_; }
276 Matrix predictor() const { return predictor_; }
277 Matrix response() const { return response_; }
278 int predictor_size() const { return predictor_.cols(); }
279 int num_observations() const { return predictor_.rows(); }
280 int response_size() const { return response_.cols(); }
281 int num_parameters() const { return initial_parameters_.cols(); }
282 int num_starts() const { return initial_parameters_.rows(); }
283 double certified_cost() const { return certified_cost_; }
Austin Schuh70cc9552019-01-21 19:46:48 -0800284
285 private:
286 Matrix predictor_;
287 Matrix response_;
288 Matrix initial_parameters_;
289 Matrix final_parameters_;
290 double certified_cost_;
291};
292
// NIST_BEGIN(Name) and NIST_END bracket a model expression and together
// expand to a residual functor `Name` for a single-predictor NIST problem.
// The functor stores (does not copy) pointers to n predictor values `x`
// and n responses `y`. Its templated operator() sets
//   residual[i] = y[i] - (expression)
// for every observation and returns true. Inside the expression, `b` is the
// parameter block and `x` is the current predictor value as a T.
//
// Keep `//` comments out of the macro body: a trailing backslash would
// splice the following line into the comment.
#define NIST_BEGIN(CostFunctionName)                       \
  struct CostFunctionName {                                \
    const double* x_;                                      \
    const double* y_;                                      \
    const int n_;                                          \
    CostFunctionName(const double* const x,                \
                     const double* const y,                \
                     const int n)                          \
        : x_(x), y_(y), n_(n) {}                           \
    template <typename T>                                  \
    bool operator()(const T* const b, T* residual) const { \
      for (int obs = 0; obs < n_; ++obs) {                 \
        const T x(x_[obs]);                                \
        residual[obs] = y_[obs] - (

// clang-format off

#define NIST_END ); } return true; }};
311
// Residual functors for the single-predictor NIST problems. Each
// NIST_BEGIN(Name) ... NIST_END pair generates a struct `Name` whose
// residual i is y[i] - (model expression), with `b` the parameter block and
// `x` the i-th predictor value. The comment above each model quotes the
// NIST formula; NIST's b1, b2, ... correspond to b[0], b[1], ...

// y = b1 * (b2+x)**(-1/b3) + e
// Used for Bennett5.dat (the struct name is spelled with a single 't').
NIST_BEGIN(Bennet5)
  b[0] * pow(b[1] + x, -1.0 / b[2])
NIST_END

// y = b1*(1-exp[-b2*x]) + e
NIST_BEGIN(BoxBOD)
  b[0] * (1.0 - exp(-b[1] * x))
NIST_END

// y = exp[-b1*x]/(b2+b3*x) + e
// Shared by Chwirut1 and Chwirut2.
NIST_BEGIN(Chwirut)
  exp(-b[0] * x) / (b[1] + b[2] * x)
NIST_END

// y = b1*x**b2 + e
NIST_BEGIN(DanWood)
  b[0] * pow(x, b[1])
NIST_END

// y = b1*exp( -b2*x ) + b3*exp( -(x-b4)**2 / b5**2 )
// + b6*exp( -(x-b7)**2 / b8**2 ) + e
// Shared by Gauss1, Gauss2 and Gauss3. Each (x-bk)**2 / bl**2 term is
// evaluated as pow((x - bk) / bl, 2).
NIST_BEGIN(Gauss)
  b[0] * exp(-b[1] * x) +
  b[2] * exp(-pow((x - b[3])/b[4], 2)) +
  b[5] * exp(-pow((x - b[6])/b[7], 2))
NIST_END

// y = b1*exp(-b2*x) + b3*exp(-b4*x) + b5*exp(-b6*x) + e
// Shared by Lanczos1, Lanczos2 and Lanczos3.
NIST_BEGIN(Lanczos)
  b[0] * exp(-b[1] * x) + b[2] * exp(-b[3] * x) + b[4] * exp(-b[5] * x)
NIST_END

// y = (b1+b2*x+b3*x**2+b4*x**3) /
// (1+b5*x+b6*x**2+b7*x**3) + e
// Powers of x are written as repeated products.
NIST_BEGIN(Hahn1)
  (b[0] + b[1] * x + b[2] * x * x + b[3] * x * x * x) /
  (1.0 + b[4] * x + b[5] * x * x + b[6] * x * x * x)
NIST_END

// y = (b1 + b2*x + b3*x**2) /
// (1 + b4*x + b5*x**2) + e
NIST_BEGIN(Kirby2)
  (b[0] + b[1] * x + b[2] * x * x) /
  (1.0 + b[3] * x + b[4] * x * x)
NIST_END

// y = b1*(x**2+x*b2) / (x**2+x*b3+b4) + e
NIST_BEGIN(MGH09)
  b[0] * (x * x + x * b[1]) / (x * x + x * b[2] + b[3])
NIST_END

// y = b1 * exp[b2/(x+b3)] + e
NIST_BEGIN(MGH10)
  b[0] * exp(b[1] / (x + b[2]))
NIST_END

// y = b1 + b2*exp[-x*b4] + b3*exp[-x*b5] + e
NIST_BEGIN(MGH17)
  b[0] + b[1] * exp(-x * b[3]) + b[2] * exp(-x * b[4])
NIST_END

// y = b1*(1-exp[-b2*x]) + e
NIST_BEGIN(Misra1a)
  b[0] * (1.0 - exp(-b[1] * x))
NIST_END

// y = b1 * (1-(1+b2*x/2)**(-2)) + e
// The (...)**(-2) power is written as 1 / (a * a).
NIST_BEGIN(Misra1b)
  b[0] * (1.0 - 1.0/ ((1.0 + b[1] * x / 2.0) * (1.0 + b[1] * x / 2.0)))  // NOLINT
NIST_END

// y = b1 * (1-(1+2*b2*x)**(-.5)) + e
NIST_BEGIN(Misra1c)
  b[0] * (1.0 - pow(1.0 + 2.0 * b[1] * x, -0.5))
NIST_END

// y = b1*b2*x*((1+b2*x)**(-1)) + e
NIST_BEGIN(Misra1d)
  b[0] * b[1] * x / (1.0 + b[1] * x)
NIST_END
393
394const double kPi = 3.141592653589793238462643383279;
395// pi = 3.141592653589793238462643383279E0
396// y = b1 - b2*x - arctan[b3/(x-b4)]/pi + e
397NIST_BEGIN(Roszman1)
398 b[0] - b[1] * x - atan2(b[2], (x - b[3])) / kPi
399NIST_END
400
// y = b1 / (1+exp[b2-b3*x]) + e
NIST_BEGIN(Rat42)
  b[0] / (1.0 + exp(b[1] - b[2] * x))
NIST_END

// y = b1 / ((1+exp[b2-b3*x])**(1/b4)) + e
NIST_BEGIN(Rat43)
  b[0] / pow(1.0 + exp(b[1] - b[2] * x), 1.0 / b[3])
NIST_END

// y = (b1 + b2*x + b3*x**2 + b4*x**3) /
// (1 + b5*x + b6*x**2 + b7*x**3) + e
// Same rational form as Hahn1; powers of x are written as repeated products.
NIST_BEGIN(Thurber)
  (b[0] + b[1] * x + b[2] * x * x + b[3] * x * x * x) /
  (1.0 + b[4] * x + b[5] * x * x + b[6] * x * x * x)
NIST_END

// y = b1 + b2*cos( 2*pi*x/12 ) + b3*sin( 2*pi*x/12 )
// + b5*cos( 2*pi*x/b4 ) + b6*sin( 2*pi*x/b4 )
// + b8*cos( 2*pi*x/b7 ) + b9*sin( 2*pi*x/b7 ) + e
// The first cos/sin pair has the fixed period 12; the periods of the other
// two pairs are the fitted parameters b4 and b7 (b[3] and b[6]).
NIST_BEGIN(ENSO)
  b[0] + b[1] * cos(2.0 * kPi * x / 12.0) +
  b[2] * sin(2.0 * kPi * x / 12.0) +
  b[4] * cos(2.0 * kPi * x / b[3]) +
  b[5] * sin(2.0 * kPi * x / b[3]) +
  b[7] * cos(2.0 * kPi * x / b[6]) +
  b[8] * sin(2.0 * kPi * x / b[6])
NIST_END

// y = (b1/b2) * exp[-0.5*((x-b3)/b2)**2] + e
// (b1/b2) is evaluated left to right as b[0] / b[1] * exp(...).
NIST_BEGIN(Eckerle4)
  b[0] / b[1] * exp(-0.5 * pow((x - b[2])/b[1], 2))
NIST_END
434
// Residual functor for the Nelson problem. It is the only model here with
// two predictor variables, and NIST fits the logarithm of the response:
//
//   log[y] = b1 - b2*x1 * exp[-b3*x2] + e
//
// `x` points to n (x1, x2) pairs stored back to back, as in the data() of
// a row-major n x 2 predictor matrix. `y` points to the n raw responses;
// their logarithm is taken here. The pointers are stored, not copied.
struct Nelson {
 public:
  Nelson(const double* const x, const double* const y, const int n)
      : x_(x), y_(y), n_(n) {}

  template <typename T>
  bool operator()(const T* const b, T* residual) const {
    const double* pair = x_;
    for (int k = 0; k < n_; ++k, pair += 2) {
      const double x1 = pair[0];
      const double x2 = pair[1];
      residual[k] = log(y_[k]) - (b[0] - b[1] * x1 * exp(-b[2] * x2));
    }
    return true;
  }

 private:
  const double* x_;
  const double* y_;
  const int n_;
};
454
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800455// clang-format on
456
Austin Schuh70cc9552019-01-21 19:46:48 -0800457static void SetNumericDiffOptions(ceres::NumericDiffOptions* options) {
458 options->max_num_ridders_extrapolations = FLAGS_ridders_extrapolations;
459 options->ridders_relative_initial_step_size = FLAGS_ridders_step_size;
460}
461
462void SetMinimizerOptions(ceres::Solver::Options* options) {
463 CHECK(
464 ceres::StringToMinimizerType(FLAGS_minimizer, &options->minimizer_type));
465 CHECK(ceres::StringToLinearSolverType(FLAGS_linear_solver,
466 &options->linear_solver_type));
467 CHECK(ceres::StringToPreconditionerType(FLAGS_preconditioner,
468 &options->preconditioner_type));
469 CHECK(ceres::StringToTrustRegionStrategyType(
470 FLAGS_trust_region_strategy, &options->trust_region_strategy_type));
471 CHECK(ceres::StringToDoglegType(FLAGS_dogleg, &options->dogleg_type));
472 CHECK(ceres::StringToLineSearchDirectionType(
473 FLAGS_line_search_direction, &options->line_search_direction_type));
474 CHECK(ceres::StringToLineSearchType(FLAGS_line_search,
475 &options->line_search_type));
476 CHECK(ceres::StringToLineSearchInterpolationType(
477 FLAGS_line_search_interpolation,
478 &options->line_search_interpolation_type));
479
480 options->max_num_iterations = FLAGS_num_iterations;
481 options->use_nonmonotonic_steps = FLAGS_nonmonotonic_steps;
482 options->initial_trust_region_radius = FLAGS_initial_trust_region_radius;
483 options->max_lbfgs_rank = FLAGS_lbfgs_rank;
484 options->line_search_sufficient_function_decrease = FLAGS_sufficient_decrease;
485 options->line_search_sufficient_curvature_decrease =
486 FLAGS_sufficient_curvature_decrease;
487 options->max_num_line_search_step_size_iterations =
488 FLAGS_max_line_search_iterations;
489 options->max_num_line_search_direction_restarts =
490 FLAGS_max_line_search_restarts;
491 options->use_approximate_eigenvalue_bfgs_scaling =
492 FLAGS_approximate_eigenvalue_bfgs_scaling;
493 options->function_tolerance = std::numeric_limits<double>::epsilon();
494 options->gradient_tolerance = std::numeric_limits<double>::epsilon();
495 options->parameter_tolerance = std::numeric_limits<double>::epsilon();
496}
497
// Joins `dirname` and `basename` with the platform path separator.
//
// `basename` is returned unchanged when `dirname` is empty or `basename`
// already starts with the separator. No separator is added when `dirname`
// already ends with one.
std::string JoinPath(const std::string& dirname, const std::string& basename) {
#ifdef _WIN32
  static const char kSeparator = '\\';
#else
  static const char kSeparator = '/';
#endif  // _WIN32

  const bool basename_is_absolute =
      !basename.empty() && basename.front() == kSeparator;
  if (dirname.empty() || basename_is_absolute) {
    return basename;
  }

  std::string joined = dirname;
  if (joined.back() != kSeparator) {
    joined.push_back(kSeparator);
  }
  joined += basename;
  return joined;
}
513
514template <typename Model, int num_parameters>
515CostFunction* CreateCostFunction(const Matrix& predictor,
516 const Matrix& response,
517 const int num_observations) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800518 Model* model = new Model(predictor.data(), response.data(), num_observations);
Austin Schuh70cc9552019-01-21 19:46:48 -0800519 ceres::CostFunction* cost_function = NULL;
520 if (FLAGS_use_numeric_diff) {
521 ceres::NumericDiffOptions options;
522 SetNumericDiffOptions(&options);
523 if (FLAGS_numeric_diff_method == "central") {
524 cost_function = new NumericDiffCostFunction<Model,
525 ceres::CENTRAL,
526 ceres::DYNAMIC,
527 num_parameters>(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800528 model, ceres::TAKE_OWNERSHIP, num_observations, options);
Austin Schuh70cc9552019-01-21 19:46:48 -0800529 } else if (FLAGS_numeric_diff_method == "forward") {
530 cost_function = new NumericDiffCostFunction<Model,
531 ceres::FORWARD,
532 ceres::DYNAMIC,
533 num_parameters>(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800534 model, ceres::TAKE_OWNERSHIP, num_observations, options);
Austin Schuh70cc9552019-01-21 19:46:48 -0800535 } else if (FLAGS_numeric_diff_method == "ridders") {
536 cost_function = new NumericDiffCostFunction<Model,
537 ceres::RIDDERS,
538 ceres::DYNAMIC,
539 num_parameters>(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800540 model, ceres::TAKE_OWNERSHIP, num_observations, options);
Austin Schuh70cc9552019-01-21 19:46:48 -0800541 } else {
542 LOG(ERROR) << "Invalid numeric diff method specified";
543 return 0;
544 }
545 } else {
546 cost_function =
547 new ceres::AutoDiffCostFunction<Model, ceres::DYNAMIC, num_parameters>(
548 model, num_observations);
549 }
550 return cost_function;
551}
552
553double ComputeLRE(const Matrix& expected, const Matrix& actual) {
554 // Compute the LRE by comparing each component of the solution
555 // with the ground truth, and taking the minimum.
556 const double kMaxNumSignificantDigits = 11;
557 double log_relative_error = kMaxNumSignificantDigits + 1;
558 for (int i = 0; i < expected.cols(); ++i) {
559 const double tmp_lre = -std::log10(std::fabs(expected(i) - actual(i)) /
560 std::fabs(expected(i)));
561 // The maximum LRE is capped at 11 - the precision at which the
562 // ground truth is known.
563 //
564 // The minimum LRE is capped at 0 - no digits match between the
565 // computed solution and the ground truth.
566 log_relative_error =
567 std::min(log_relative_error,
568 std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
569 }
570 return log_relative_error;
571}
572
573template <typename Model, int num_parameters>
574int RegressionDriver(const string& filename) {
575 NISTProblem nist_problem(JoinPath(FLAGS_nist_data_dir, filename));
576 CHECK_EQ(num_parameters, nist_problem.num_parameters());
577
578 Matrix predictor = nist_problem.predictor();
579 Matrix response = nist_problem.response();
580 Matrix final_parameters = nist_problem.final_parameters();
581
582 printf("%s\n", filename.c_str());
583
584 // Each NIST problem comes with multiple starting points, so we
585 // construct the problem from scratch for each case and solve it.
586 int num_success = 0;
587 for (int start = 0; start < nist_problem.num_starts(); ++start) {
588 Matrix initial_parameters = nist_problem.initial_parameters(start);
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800589 ceres::CostFunction* cost_function =
590 CreateCostFunction<Model, num_parameters>(
591 predictor, response, nist_problem.num_observations());
Austin Schuh70cc9552019-01-21 19:46:48 -0800592
593 double initial_cost;
594 double final_cost;
595
596 if (!FLAGS_use_tiny_solver) {
597 ceres::Problem problem;
598 problem.AddResidualBlock(cost_function, NULL, initial_parameters.data());
599 ceres::Solver::Summary summary;
600 ceres::Solver::Options options;
601 SetMinimizerOptions(&options);
602 Solve(options, &problem, &summary);
603 initial_cost = summary.initial_cost;
604 final_cost = summary.final_cost;
605 } else {
606 ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters> cfa(
607 *cost_function);
608 typedef ceres::TinySolver<
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800609 ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters>>
Austin Schuh70cc9552019-01-21 19:46:48 -0800610 Solver;
611 Solver solver;
612 solver.options.max_num_iterations = FLAGS_num_iterations;
613 solver.options.gradient_tolerance =
614 std::numeric_limits<double>::epsilon();
615 solver.options.parameter_tolerance =
616 std::numeric_limits<double>::epsilon();
617
618 Eigen::Matrix<double, num_parameters, 1> x;
619 x = initial_parameters.transpose();
620 typename Solver::Summary summary = solver.Solve(cfa, &x);
621 initial_parameters = x;
622 initial_cost = summary.initial_cost;
623 final_cost = summary.final_cost;
624 delete cost_function;
625 }
626
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800627 const double log_relative_error =
628 ComputeLRE(nist_problem.final_parameters(), initial_parameters);
Austin Schuh70cc9552019-01-21 19:46:48 -0800629 const int kMinNumMatchingDigits = 4;
630 if (log_relative_error > kMinNumMatchingDigits) {
631 ++num_success;
632 }
633
634 printf(
635 "start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
636 "certified cost: %e\n",
637 start + 1,
638 log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
639 log_relative_error,
640 initial_cost,
641 final_cost,
642 nist_problem.certified_cost());
643 }
644 return num_success;
645}
646
Austin Schuh70cc9552019-01-21 19:46:48 -0800647void SolveNISTProblems() {
648 if (FLAGS_nist_data_dir.empty()) {
649 LOG(FATAL) << "Must specify the directory containing the NIST problems";
650 }
651
652 cout << "Lower Difficulty\n";
653 int easy_success = 0;
654 easy_success += RegressionDriver<Misra1a, 2>("Misra1a.dat");
655 easy_success += RegressionDriver<Chwirut, 3>("Chwirut1.dat");
656 easy_success += RegressionDriver<Chwirut, 3>("Chwirut2.dat");
657 easy_success += RegressionDriver<Lanczos, 6>("Lanczos3.dat");
658 easy_success += RegressionDriver<Gauss, 8>("Gauss1.dat");
659 easy_success += RegressionDriver<Gauss, 8>("Gauss2.dat");
660 easy_success += RegressionDriver<DanWood, 2>("DanWood.dat");
661 easy_success += RegressionDriver<Misra1b, 2>("Misra1b.dat");
662
663 cout << "\nMedium Difficulty\n";
664 int medium_success = 0;
665 medium_success += RegressionDriver<Kirby2, 5>("Kirby2.dat");
666 medium_success += RegressionDriver<Hahn1, 7>("Hahn1.dat");
667 medium_success += RegressionDriver<Nelson, 3>("Nelson.dat");
668 medium_success += RegressionDriver<MGH17, 5>("MGH17.dat");
669 medium_success += RegressionDriver<Lanczos, 6>("Lanczos1.dat");
670 medium_success += RegressionDriver<Lanczos, 6>("Lanczos2.dat");
671 medium_success += RegressionDriver<Gauss, 8>("Gauss3.dat");
672 medium_success += RegressionDriver<Misra1c, 2>("Misra1c.dat");
673 medium_success += RegressionDriver<Misra1d, 2>("Misra1d.dat");
674 medium_success += RegressionDriver<Roszman1, 4>("Roszman1.dat");
675 medium_success += RegressionDriver<ENSO, 9>("ENSO.dat");
676
677 cout << "\nHigher Difficulty\n";
678 int hard_success = 0;
679 hard_success += RegressionDriver<MGH09, 4>("MGH09.dat");
680 hard_success += RegressionDriver<Thurber, 7>("Thurber.dat");
681 hard_success += RegressionDriver<BoxBOD, 2>("BoxBOD.dat");
682 hard_success += RegressionDriver<Rat42, 3>("Rat42.dat");
683 hard_success += RegressionDriver<MGH10, 3>("MGH10.dat");
684 hard_success += RegressionDriver<Eckerle4, 3>("Eckerle4.dat");
685 hard_success += RegressionDriver<Rat43, 4>("Rat43.dat");
686 hard_success += RegressionDriver<Bennet5, 3>("Bennett5.dat");
687
688 cout << "\n";
689 cout << "Easy : " << easy_success << "/16\n";
690 cout << "Medium : " << medium_success << "/22\n";
691 cout << "Hard : " << hard_success << "/16\n";
692 cout << "Total : " << easy_success + medium_success + hard_success
693 << "/54\n";
694}
695
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800696} // namespace
Austin Schuh70cc9552019-01-21 19:46:48 -0800697} // namespace examples
698} // namespace ceres
699
700int main(int argc, char** argv) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800701 GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true);
Austin Schuh70cc9552019-01-21 19:46:48 -0800702 google::InitGoogleLogging(argv[0]);
703 ceres::examples::SolveNISTProblems();
704 return 0;
705}