blob: b92c918b102b31cff58587bac396036260d8b815 [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
Austin Schuh3de38b02024-06-25 18:25:10 -07002// Copyright 2023 Google Inc. All rights reserved.
Austin Schuh70cc9552019-01-21 19:46:48 -08003// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30//
31// The National Institute of Standards and Technology has released a
32// set of problems to test non-linear least squares solvers.
33//
34// More information about the background on these problems and
35// suggested evaluation methodology can be found at:
36//
37// http://www.itl.nist.gov/div898/strd/nls/nls_info.shtml
38//
39// The problem data themselves can be found at
40//
41// http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml
42//
43// The problems are divided into three levels of difficulty, Easy,
44// Medium and Hard. For each problem there are two starting guesses,
45// the first one far away from the global minimum and the second
46// closer to it.
47//
// A problem is considered successfully solved if every component of
// the solution matches the globally optimal solution to at least 4
// digits.
51//
52// This dataset was used for an evaluation of Non-linear least squares
53// solvers:
54//
55// P. F. Mondragon & B. Borchers, A Comparison of Nonlinear Regression
56// Codes, Journal of Modern Applied Statistical Methods, 4(1):343-351,
57// 2005.
58//
59// The results from Mondragon & Borchers can be summarized as
60// Excel Gnuplot GaussFit HBN MinPack
61// Average LRE 2.3 4.3 4.0 6.8 4.4
62// Winner 1 5 12 29 12
63//
// where the row Winner counts the number of problems for which the
// solver had the highest LRE.
66
67// In this file, we implement the same evaluation methodology using
68// Ceres. Currently using Levenberg-Marquardt with DENSE_QR, we get
69//
70// Excel Gnuplot GaussFit HBN MinPack Ceres
71// Average LRE 2.3 4.3 4.0 6.8 4.4 9.4
72// Winner 0 0 5 11 2 41
73
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

#include "Eigen/Core"
#include "ceres/ceres.h"
#include "ceres/tiny_solver.h"
#include "ceres/tiny_solver_cost_function_adapter.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
87
88DEFINE_bool(use_tiny_solver, false, "Use TinySolver instead of Ceres::Solver");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080089DEFINE_string(nist_data_dir,
90 "",
91 "Directory containing the NIST non-linear regression examples");
92DEFINE_string(minimizer,
93 "trust_region",
Austin Schuh70cc9552019-01-21 19:46:48 -080094 "Minimizer type to use, choices are: line_search & trust_region");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080095DEFINE_string(trust_region_strategy,
96 "levenberg_marquardt",
Austin Schuh70cc9552019-01-21 19:46:48 -080097 "Options are: levenberg_marquardt, dogleg");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -080098DEFINE_string(dogleg,
99 "traditional_dogleg",
Austin Schuh70cc9552019-01-21 19:46:48 -0800100 "Options are: traditional_dogleg, subspace_dogleg");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800101DEFINE_string(linear_solver,
102 "dense_qr",
103 "Options are: sparse_cholesky, dense_qr, dense_normal_cholesky "
104 "and cgnr");
Austin Schuh3de38b02024-06-25 18:25:10 -0700105DEFINE_string(dense_linear_algebra_library,
106 "eigen",
107 "Options are: eigen, lapack, and cuda.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800108DEFINE_string(preconditioner, "jacobi", "Options are: identity, jacobi");
109DEFINE_string(line_search,
110 "wolfe",
Austin Schuh70cc9552019-01-21 19:46:48 -0800111 "Line search algorithm to use, choices are: armijo and wolfe.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800112DEFINE_string(line_search_direction,
113 "lbfgs",
Austin Schuh70cc9552019-01-21 19:46:48 -0800114 "Line search direction algorithm to use, choices: lbfgs, bfgs");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800115DEFINE_int32(max_line_search_iterations,
116 20,
Austin Schuh70cc9552019-01-21 19:46:48 -0800117 "Maximum number of iterations for each line search.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800118DEFINE_int32(max_line_search_restarts,
119 10,
Austin Schuh70cc9552019-01-21 19:46:48 -0800120 "Maximum number of restarts of line search direction algorithm.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800121DEFINE_string(line_search_interpolation,
122 "cubic",
Austin Schuh3de38b02024-06-25 18:25:10 -0700123 "Degree of polynomial approximation in line search, choices are: "
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800124 "bisection, quadratic & cubic.");
125DEFINE_int32(lbfgs_rank,
126 20,
Austin Schuh70cc9552019-01-21 19:46:48 -0800127 "Rank of L-BFGS inverse Hessian approximation in line search.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800128DEFINE_bool(approximate_eigenvalue_bfgs_scaling,
129 false,
Austin Schuh70cc9552019-01-21 19:46:48 -0800130 "Use approximate eigenvalue scaling in (L)BFGS line search.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800131DEFINE_double(sufficient_decrease,
132 1.0e-4,
Austin Schuh70cc9552019-01-21 19:46:48 -0800133 "Line search Armijo sufficient (function) decrease factor.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800134DEFINE_double(sufficient_curvature_decrease,
135 0.9,
Austin Schuh70cc9552019-01-21 19:46:48 -0800136 "Line search Wolfe sufficient curvature decrease factor.");
137DEFINE_int32(num_iterations, 10000, "Number of iterations");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800138DEFINE_bool(nonmonotonic_steps,
139 false,
140 "Trust region algorithm can use nonmonotic steps");
Austin Schuh70cc9552019-01-21 19:46:48 -0800141DEFINE_double(initial_trust_region_radius, 1e4, "Initial trust region radius");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800142DEFINE_bool(use_numeric_diff,
143 false,
Austin Schuh70cc9552019-01-21 19:46:48 -0800144 "Use numeric differentiation instead of automatic "
145 "differentiation.");
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800146DEFINE_string(numeric_diff_method,
147 "ridders",
148 "When using numeric differentiation, selects algorithm. Options "
149 "are: central, forward, ridders.");
150DEFINE_double(ridders_step_size,
151 1e-9,
152 "Initial step size for Ridders numeric differentiation.");
153DEFINE_int32(ridders_extrapolations,
154 3,
155 "Maximal number of Ridders extrapolations.");
Austin Schuh70cc9552019-01-21 19:46:48 -0800156
Austin Schuh3de38b02024-06-25 18:25:10 -0700157namespace ceres::examples {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800158namespace {
Austin Schuh70cc9552019-01-21 19:46:48 -0800159
160using Eigen::Dynamic;
161using Eigen::RowMajor;
Austin Schuh3de38b02024-06-25 18:25:10 -0700162using Vector = Eigen::Matrix<double, Dynamic, 1>;
163using Matrix = Eigen::Matrix<double, Dynamic, Dynamic, RowMajor>;
Austin Schuh70cc9552019-01-21 19:46:48 -0800164
// Splits `full` at every occurrence of `delim` and appends each non-empty
// token to `*result`, in order of appearance.
//
// Leading, trailing and repeated delimiters never produce empty tokens.
// Anything already stored in `*result` is left in place.
void SplitStringUsingChar(const std::string& full,
                          const char delim,
                          std::vector<std::string>* result) {
  std::string::size_type token_begin = full.find_first_not_of(delim);
  while (token_begin != std::string::npos) {
    const std::string::size_type token_end = full.find(delim, token_begin);
    if (token_end == std::string::npos) {
      // The last token runs up to the end of the input.
      result->push_back(full.substr(token_begin));
      break;
    }
    result->push_back(full.substr(token_begin, token_end - token_begin));
    // Jump over the delimiter run that follows this token.
    token_begin = full.find_first_not_of(delim, token_end);
  }
}
184
Austin Schuh3de38b02024-06-25 18:25:10 -0700185bool GetAndSplitLine(std::ifstream& ifs, std::vector<std::string>* pieces) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800186 pieces->clear();
187 char buf[256];
188 ifs.getline(buf, 256);
Austin Schuh3de38b02024-06-25 18:25:10 -0700189 SplitStringUsingChar(std::string(buf), ' ', pieces);
Austin Schuh70cc9552019-01-21 19:46:48 -0800190 return true;
191}
192
// Reads and discards up to `num_lines` lines from `ifs`. Stops early if the
// stream runs out of data. A non-positive `num_lines` skips nothing.
//
// Lines are read into a std::string so that their length is unbounded; a
// fixed size buffer would leave the stream in a failed state after the first
// over-long line.
void SkipLines(std::ifstream& ifs, int num_lines) {
  std::string discarded;
  for (int i = 0; i < num_lines; ++i) {
    if (!std::getline(ifs, discarded)) {
      break;
    }
  }
}
199
200class NISTProblem {
201 public:
Austin Schuh3de38b02024-06-25 18:25:10 -0700202 explicit NISTProblem(const std::string& filename) {
203 std::ifstream ifs(filename.c_str(), std::ifstream::in);
Austin Schuh70cc9552019-01-21 19:46:48 -0800204 CHECK(ifs) << "Unable to open : " << filename;
205
Austin Schuh3de38b02024-06-25 18:25:10 -0700206 std::vector<std::string> pieces;
Austin Schuh70cc9552019-01-21 19:46:48 -0800207 SkipLines(ifs, 24);
208 GetAndSplitLine(ifs, &pieces);
Austin Schuh3de38b02024-06-25 18:25:10 -0700209 const int kNumResponses = std::atoi(pieces[1].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800210
211 GetAndSplitLine(ifs, &pieces);
Austin Schuh3de38b02024-06-25 18:25:10 -0700212 const int kNumPredictors = std::atoi(pieces[0].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800213
214 GetAndSplitLine(ifs, &pieces);
Austin Schuh3de38b02024-06-25 18:25:10 -0700215 const int kNumObservations = std::atoi(pieces[0].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800216
217 SkipLines(ifs, 4);
218 GetAndSplitLine(ifs, &pieces);
Austin Schuh3de38b02024-06-25 18:25:10 -0700219 const int kNumParameters = std::atoi(pieces[0].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800220 SkipLines(ifs, 8);
221
222 // Get the first line of initial and final parameter values to
223 // determine the number of tries.
224 GetAndSplitLine(ifs, &pieces);
225 const int kNumTries = pieces.size() - 4;
226
227 predictor_.resize(kNumObservations, kNumPredictors);
228 response_.resize(kNumObservations, kNumResponses);
229 initial_parameters_.resize(kNumTries, kNumParameters);
230 final_parameters_.resize(1, kNumParameters);
231
232 // Parse the line for parameter b1.
233 int parameter_id = 0;
234 for (int i = 0; i < kNumTries; ++i) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700235 initial_parameters_(i, parameter_id) = std::atof(pieces[i + 2].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800236 }
Austin Schuh3de38b02024-06-25 18:25:10 -0700237 final_parameters_(0, parameter_id) =
238 std::atof(pieces[2 + kNumTries].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800239
240 // Parse the remaining parameter lines.
241 for (int parameter_id = 1; parameter_id < kNumParameters; ++parameter_id) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800242 GetAndSplitLine(ifs, &pieces);
243 // b2, b3, ....
244 for (int i = 0; i < kNumTries; ++i) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700245 initial_parameters_(i, parameter_id) = std::atof(pieces[i + 2].c_str());
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800246 }
Austin Schuh3de38b02024-06-25 18:25:10 -0700247 final_parameters_(0, parameter_id) =
248 std::atof(pieces[2 + kNumTries].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800249 }
250
Austin Schuh3de38b02024-06-25 18:25:10 -0700251 // Certified cost
Austin Schuh70cc9552019-01-21 19:46:48 -0800252 SkipLines(ifs, 1);
253 GetAndSplitLine(ifs, &pieces);
Austin Schuh3de38b02024-06-25 18:25:10 -0700254 certified_cost_ = std::atof(pieces[4].c_str()) / 2.0;
Austin Schuh70cc9552019-01-21 19:46:48 -0800255
256 // Read the observations.
257 SkipLines(ifs, 18 - kNumParameters);
258 for (int i = 0; i < kNumObservations; ++i) {
259 GetAndSplitLine(ifs, &pieces);
260 // Response.
261 for (int j = 0; j < kNumResponses; ++j) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700262 response_(i, j) = std::atof(pieces[j].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800263 }
264
265 // Predictor variables.
266 for (int j = 0; j < kNumPredictors; ++j) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700267 predictor_(i, j) = std::atof(pieces[j + kNumResponses].c_str());
Austin Schuh70cc9552019-01-21 19:46:48 -0800268 }
269 }
270 }
271
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800272 Matrix initial_parameters(int start) const {
273 return initial_parameters_.row(start);
274 } // NOLINT
275 Matrix final_parameters() const { return final_parameters_; }
276 Matrix predictor() const { return predictor_; }
277 Matrix response() const { return response_; }
278 int predictor_size() const { return predictor_.cols(); }
279 int num_observations() const { return predictor_.rows(); }
280 int response_size() const { return response_.cols(); }
281 int num_parameters() const { return initial_parameters_.cols(); }
282 int num_starts() const { return initial_parameters_.rows(); }
283 double certified_cost() const { return certified_cost_; }
Austin Schuh70cc9552019-01-21 19:46:48 -0800284
285 private:
286 Matrix predictor_;
287 Matrix response_;
288 Matrix initial_parameters_;
289 Matrix final_parameters_;
290 double certified_cost_;
291};
292
// NIST_BEGIN(Name) <expression> NIST_END defines a functor struct `Name`.
//
// The functor stores pointers to the predictor values (x_), the response
// values (y_) and the number of observations (n_). Its templated call
// operator fills residual[k] = y_[k] - <expression> for every observation k.
// Inside <expression>, `x` is the k-th predictor value converted to T and
// `b` is the parameter array.
#define NIST_BEGIN(CostFunctionName)                          \
  struct CostFunctionName {                                   \
    CostFunctionName(const double* const predictor,           \
                     const double* const response,            \
                     const int num_observations)              \
        : x_(predictor), y_(response), n_(num_observations) { \
    }                                                         \
    const double* x_;                                         \
    const double* y_;                                         \
    const int n_;                                             \
    template <typename T>                                     \
    bool operator()(const T* const b, T* residual) const {    \
      for (int k = 0; k < n_; ++k) {                          \
        const T x(x_[k]);                                     \
        residual[k] = y_[k] - (

// clang-format off

#define NIST_END ); } return true; } };

// Model: y = b1 * (b2+x)**(-1/b3) + e
NIST_BEGIN(Bennet5)
  b[0] * pow(b[1] + x, -1.0 / b[2])
NIST_END

// Model: y = b1*(1-exp[-b2*x]) + e
NIST_BEGIN(BoxBOD)
  b[0] * (1.0 - exp(-b[1] * x))
NIST_END

// Model: y = exp[-b1*x]/(b2+b3*x) + e
NIST_BEGIN(Chwirut)
  exp(-b[0] * x) / (b[1] + b[2] * x)
NIST_END

// Model: y = b1*x**b2 + e
NIST_BEGIN(DanWood)
  b[0] * pow(x, b[1])
NIST_END

// Model: y = b1*exp( -b2*x ) + b3*exp( -(x-b4)**2 / b5**2 )
//                            + b6*exp( -(x-b7)**2 / b8**2 ) + e
NIST_BEGIN(Gauss)
  b[0] * exp(-b[1] * x) +
  b[2] * exp(-pow((x - b[3]) / b[4], 2)) +
  b[5] * exp(-pow((x - b[6]) / b[7], 2))
NIST_END

// Model: y = b1*exp(-b2*x) + b3*exp(-b4*x) + b5*exp(-b6*x) + e
NIST_BEGIN(Lanczos)
  b[0] * exp(-b[1] * x) +
  b[2] * exp(-b[3] * x) +
  b[4] * exp(-b[5] * x)
NIST_END

// Model: y = (b1+b2*x+b3*x**2+b4*x**3) /
//            (1+b5*x+b6*x**2+b7*x**3) + e
NIST_BEGIN(Hahn1)
  (b[0] + b[1] * x + b[2] * x * x + b[3] * x * x * x) /
  (1.0 + b[4] * x + b[5] * x * x + b[6] * x * x * x)
NIST_END

// Model: y = (b1 + b2*x + b3*x**2) /
//            (1 + b4*x + b5*x**2) + e
NIST_BEGIN(Kirby2)
  (b[0] + b[1] * x + b[2] * x * x) /
  (1.0 + b[3] * x + b[4] * x * x)
NIST_END

// Model: y = b1*(x**2+x*b2) / (x**2+x*b3+b4) + e
NIST_BEGIN(MGH09)
  b[0] * (x * x + x * b[1]) / (x * x + x * b[2] + b[3])
NIST_END

// Model: y = b1 * exp[b2/(x+b3)] + e
NIST_BEGIN(MGH10)
  b[0] * exp(b[1] / (x + b[2]))
NIST_END

// Model: y = b1 + b2*exp[-x*b4] + b3*exp[-x*b5] + e
NIST_BEGIN(MGH17)
  b[0] + b[1] * exp(-x * b[3]) + b[2] * exp(-x * b[4])
NIST_END

// Model: y = b1*(1-exp[-b2*x]) + e
NIST_BEGIN(Misra1a)
  b[0] * (1.0 - exp(-b[1] * x))
NIST_END

// Model: y = b1 * (1-(1+b2*x/2)**(-2)) + e
NIST_BEGIN(Misra1b)
  b[0] * (1.0 - 1.0/ ((1.0 + b[1] * x / 2.0) * (1.0 + b[1] * x / 2.0)))  // NOLINT
NIST_END

// Model: y = b1 * (1-(1+2*b2*x)**(-.5)) + e
NIST_BEGIN(Misra1c)
  b[0] * (1.0 - pow(1.0 + 2.0 * b[1] * x, -0.5))
NIST_END

// Model: y = b1*b2*x*((1+b2*x)**(-1)) + e
NIST_BEGIN(Misra1d)
  b[0] * b[1] * x / (1.0 + b[1] * x)
NIST_END

// pi = 3.141592653589793238462643383279E0
constexpr double kPi = 3.141592653589793238462643383279;

// Model: y = b1 - b2*x - arctan[b3/(x-b4)]/pi + e
// The arctangent of the quotient is evaluated with the two argument atan2.
NIST_BEGIN(Roszman1)
  b[0] - b[1] * x - atan2(b[2], (x - b[3])) / kPi
NIST_END

// Model: y = b1 / (1+exp[b2-b3*x]) + e
NIST_BEGIN(Rat42)
  b[0] / (1.0 + exp(b[1] - b[2] * x))
NIST_END

// Model: y = b1 / ((1+exp[b2-b3*x])**(1/b4)) + e
NIST_BEGIN(Rat43)
  b[0] / pow(1.0 + exp(b[1] - b[2] * x), 1.0 / b[3])
NIST_END

// Model: y = (b1 + b2*x + b3*x**2 + b4*x**3) /
//            (1 + b5*x + b6*x**2 + b7*x**3) + e
NIST_BEGIN(Thurber)
  (b[0] + b[1] * x + b[2] * x * x + b[3] * x * x * x) /
  (1.0 + b[4] * x + b[5] * x * x + b[6] * x * x * x)
NIST_END

// Model: y = b1 + b2*cos( 2*pi*x/12 ) + b3*sin( 2*pi*x/12 )
//               + b5*cos( 2*pi*x/b4 ) + b6*sin( 2*pi*x/b4 )
//               + b8*cos( 2*pi*x/b7 ) + b9*sin( 2*pi*x/b7 ) + e
NIST_BEGIN(ENSO)
  b[0] + b[1] * cos(2.0 * kPi * x / 12.0) +
         b[2] * sin(2.0 * kPi * x / 12.0) +
         b[4] * cos(2.0 * kPi * x / b[3]) +
         b[5] * sin(2.0 * kPi * x / b[3]) +
         b[7] * cos(2.0 * kPi * x / b[6]) +
         b[8] * sin(2.0 * kPi * x / b[6])
NIST_END

// Model: y = (b1/b2) * exp[-0.5*((x-b3)/b2)**2] + e
NIST_BEGIN(Eckerle4)
  b[0] / b[1] * exp(-0.5 * pow((x - b[2]) / b[1], 2))
NIST_END
434
// Functor for the Nelson problem, which has two predictor values per
// observation and is fitted against the logarithm of the response:
//
//   log[y] = b1 - b2*x1 * exp[-b3*x2] + e
//
// `x` holds the predictors as consecutive (x1, x2) pairs, one pair per
// observation, `y` holds one response per observation and `n` is the number
// of observations.
struct Nelson {
 public:
  Nelson(const double* const x, const double* const y, const int n)
      : x_(x), y_(y), n_(n) {}

  // Writes one residual per observation into `residual`; always returns true.
  template <typename T>
  bool operator()(const T* const b, T* residual) const {
    for (int row = 0; row < n_; ++row) {
      const double x1 = x_[2 * row];
      const double x2 = x_[2 * row + 1];
      const T prediction = b[0] - b[1] * x1 * exp(-b[2] * x2);
      residual[row] = log(y_[row]) - prediction;
    }
    return true;
  }

 private:
  const double* x_;
  const double* y_;
  const int n_;
};
454
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800455// clang-format on
456
Austin Schuh70cc9552019-01-21 19:46:48 -0800457static void SetNumericDiffOptions(ceres::NumericDiffOptions* options) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700458 options->max_num_ridders_extrapolations =
459 CERES_GET_FLAG(FLAGS_ridders_extrapolations);
460 options->ridders_relative_initial_step_size =
461 CERES_GET_FLAG(FLAGS_ridders_step_size);
Austin Schuh70cc9552019-01-21 19:46:48 -0800462}
463
464void SetMinimizerOptions(ceres::Solver::Options* options) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700465 CHECK(ceres::StringToMinimizerType(CERES_GET_FLAG(FLAGS_minimizer),
466 &options->minimizer_type));
467 CHECK(ceres::StringToLinearSolverType(CERES_GET_FLAG(FLAGS_linear_solver),
Austin Schuh70cc9552019-01-21 19:46:48 -0800468 &options->linear_solver_type));
Austin Schuh3de38b02024-06-25 18:25:10 -0700469 CHECK(StringToDenseLinearAlgebraLibraryType(
470 CERES_GET_FLAG(FLAGS_dense_linear_algebra_library),
471 &options->dense_linear_algebra_library_type));
472 CHECK(ceres::StringToPreconditionerType(CERES_GET_FLAG(FLAGS_preconditioner),
Austin Schuh70cc9552019-01-21 19:46:48 -0800473 &options->preconditioner_type));
474 CHECK(ceres::StringToTrustRegionStrategyType(
Austin Schuh3de38b02024-06-25 18:25:10 -0700475 CERES_GET_FLAG(FLAGS_trust_region_strategy),
476 &options->trust_region_strategy_type));
477 CHECK(ceres::StringToDoglegType(CERES_GET_FLAG(FLAGS_dogleg),
478 &options->dogleg_type));
Austin Schuh70cc9552019-01-21 19:46:48 -0800479 CHECK(ceres::StringToLineSearchDirectionType(
Austin Schuh3de38b02024-06-25 18:25:10 -0700480 CERES_GET_FLAG(FLAGS_line_search_direction),
481 &options->line_search_direction_type));
482 CHECK(ceres::StringToLineSearchType(CERES_GET_FLAG(FLAGS_line_search),
Austin Schuh70cc9552019-01-21 19:46:48 -0800483 &options->line_search_type));
484 CHECK(ceres::StringToLineSearchInterpolationType(
Austin Schuh3de38b02024-06-25 18:25:10 -0700485 CERES_GET_FLAG(FLAGS_line_search_interpolation),
Austin Schuh70cc9552019-01-21 19:46:48 -0800486 &options->line_search_interpolation_type));
487
Austin Schuh3de38b02024-06-25 18:25:10 -0700488 options->max_num_iterations = CERES_GET_FLAG(FLAGS_num_iterations);
489 options->use_nonmonotonic_steps = CERES_GET_FLAG(FLAGS_nonmonotonic_steps);
490 options->initial_trust_region_radius =
491 CERES_GET_FLAG(FLAGS_initial_trust_region_radius);
492 options->max_lbfgs_rank = CERES_GET_FLAG(FLAGS_lbfgs_rank);
493 options->line_search_sufficient_function_decrease =
494 CERES_GET_FLAG(FLAGS_sufficient_decrease);
Austin Schuh70cc9552019-01-21 19:46:48 -0800495 options->line_search_sufficient_curvature_decrease =
Austin Schuh3de38b02024-06-25 18:25:10 -0700496 CERES_GET_FLAG(FLAGS_sufficient_curvature_decrease);
Austin Schuh70cc9552019-01-21 19:46:48 -0800497 options->max_num_line_search_step_size_iterations =
Austin Schuh3de38b02024-06-25 18:25:10 -0700498 CERES_GET_FLAG(FLAGS_max_line_search_iterations);
Austin Schuh70cc9552019-01-21 19:46:48 -0800499 options->max_num_line_search_direction_restarts =
Austin Schuh3de38b02024-06-25 18:25:10 -0700500 CERES_GET_FLAG(FLAGS_max_line_search_restarts);
Austin Schuh70cc9552019-01-21 19:46:48 -0800501 options->use_approximate_eigenvalue_bfgs_scaling =
Austin Schuh3de38b02024-06-25 18:25:10 -0700502 CERES_GET_FLAG(FLAGS_approximate_eigenvalue_bfgs_scaling);
Austin Schuh70cc9552019-01-21 19:46:48 -0800503 options->function_tolerance = std::numeric_limits<double>::epsilon();
504 options->gradient_tolerance = std::numeric_limits<double>::epsilon();
505 options->parameter_tolerance = std::numeric_limits<double>::epsilon();
506}
507
// Joins a directory name and a file name with the platform path separator.
//
// Returns `basename` unchanged when `dirname` is empty or when `basename`
// already starts with the separator. Otherwise returns `dirname` followed by
// `basename`, inserting a separator unless `dirname` already ends with one.
std::string JoinPath(const std::string& dirname, const std::string& basename) {
#ifdef _WIN32
  static const char separator = '\\';
#else
  static const char separator = '/';
#endif  // _WIN32

  const bool basename_is_rooted =
      !basename.empty() && basename.front() == separator;
  if (dirname.empty() || basename_is_rooted) {
    return basename;
  }

  std::string joined = dirname;
  if (joined.back() != separator) {
    joined.push_back(separator);
  }
  joined += basename;
  return joined;
}
523
524template <typename Model, int num_parameters>
525CostFunction* CreateCostFunction(const Matrix& predictor,
526 const Matrix& response,
527 const int num_observations) {
Austin Schuh3de38b02024-06-25 18:25:10 -0700528 auto* model = new Model(predictor.data(), response.data(), num_observations);
529 ceres::CostFunction* cost_function = nullptr;
530 if (CERES_GET_FLAG(FLAGS_use_numeric_diff)) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800531 ceres::NumericDiffOptions options;
532 SetNumericDiffOptions(&options);
Austin Schuh3de38b02024-06-25 18:25:10 -0700533 if (CERES_GET_FLAG(FLAGS_numeric_diff_method) == "central") {
Austin Schuh70cc9552019-01-21 19:46:48 -0800534 cost_function = new NumericDiffCostFunction<Model,
535 ceres::CENTRAL,
536 ceres::DYNAMIC,
537 num_parameters>(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800538 model, ceres::TAKE_OWNERSHIP, num_observations, options);
Austin Schuh3de38b02024-06-25 18:25:10 -0700539 } else if (CERES_GET_FLAG(FLAGS_numeric_diff_method) == "forward") {
Austin Schuh70cc9552019-01-21 19:46:48 -0800540 cost_function = new NumericDiffCostFunction<Model,
541 ceres::FORWARD,
542 ceres::DYNAMIC,
543 num_parameters>(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800544 model, ceres::TAKE_OWNERSHIP, num_observations, options);
Austin Schuh3de38b02024-06-25 18:25:10 -0700545 } else if (CERES_GET_FLAG(FLAGS_numeric_diff_method) == "ridders") {
Austin Schuh70cc9552019-01-21 19:46:48 -0800546 cost_function = new NumericDiffCostFunction<Model,
547 ceres::RIDDERS,
548 ceres::DYNAMIC,
549 num_parameters>(
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800550 model, ceres::TAKE_OWNERSHIP, num_observations, options);
Austin Schuh70cc9552019-01-21 19:46:48 -0800551 } else {
552 LOG(ERROR) << "Invalid numeric diff method specified";
Austin Schuh3de38b02024-06-25 18:25:10 -0700553 return nullptr;
Austin Schuh70cc9552019-01-21 19:46:48 -0800554 }
555 } else {
556 cost_function =
557 new ceres::AutoDiffCostFunction<Model, ceres::DYNAMIC, num_parameters>(
558 model, num_observations);
559 }
560 return cost_function;
561}
562
563double ComputeLRE(const Matrix& expected, const Matrix& actual) {
564 // Compute the LRE by comparing each component of the solution
565 // with the ground truth, and taking the minimum.
566 const double kMaxNumSignificantDigits = 11;
567 double log_relative_error = kMaxNumSignificantDigits + 1;
568 for (int i = 0; i < expected.cols(); ++i) {
569 const double tmp_lre = -std::log10(std::fabs(expected(i) - actual(i)) /
570 std::fabs(expected(i)));
571 // The maximum LRE is capped at 11 - the precision at which the
572 // ground truth is known.
573 //
574 // The minimum LRE is capped at 0 - no digits match between the
575 // computed solution and the ground truth.
576 log_relative_error =
577 std::min(log_relative_error,
578 std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
579 }
580 return log_relative_error;
581}
582
583template <typename Model, int num_parameters>
Austin Schuh3de38b02024-06-25 18:25:10 -0700584int RegressionDriver(const std::string& filename) {
585 NISTProblem nist_problem(
586 JoinPath(CERES_GET_FLAG(FLAGS_nist_data_dir), filename));
Austin Schuh70cc9552019-01-21 19:46:48 -0800587 CHECK_EQ(num_parameters, nist_problem.num_parameters());
588
589 Matrix predictor = nist_problem.predictor();
590 Matrix response = nist_problem.response();
591 Matrix final_parameters = nist_problem.final_parameters();
592
593 printf("%s\n", filename.c_str());
594
595 // Each NIST problem comes with multiple starting points, so we
596 // construct the problem from scratch for each case and solve it.
597 int num_success = 0;
598 for (int start = 0; start < nist_problem.num_starts(); ++start) {
599 Matrix initial_parameters = nist_problem.initial_parameters(start);
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800600 ceres::CostFunction* cost_function =
601 CreateCostFunction<Model, num_parameters>(
602 predictor, response, nist_problem.num_observations());
Austin Schuh70cc9552019-01-21 19:46:48 -0800603
604 double initial_cost;
605 double final_cost;
606
Austin Schuh3de38b02024-06-25 18:25:10 -0700607 if (!CERES_GET_FLAG(FLAGS_use_tiny_solver)) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800608 ceres::Problem problem;
Austin Schuh3de38b02024-06-25 18:25:10 -0700609 problem.AddResidualBlock(
610 cost_function, nullptr, initial_parameters.data());
Austin Schuh70cc9552019-01-21 19:46:48 -0800611 ceres::Solver::Summary summary;
612 ceres::Solver::Options options;
613 SetMinimizerOptions(&options);
614 Solve(options, &problem, &summary);
615 initial_cost = summary.initial_cost;
616 final_cost = summary.final_cost;
617 } else {
618 ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters> cfa(
619 *cost_function);
Austin Schuh3de38b02024-06-25 18:25:10 -0700620 using Solver = ceres::TinySolver<
621 ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters>>;
Austin Schuh70cc9552019-01-21 19:46:48 -0800622 Solver solver;
Austin Schuh3de38b02024-06-25 18:25:10 -0700623 solver.options.max_num_iterations = CERES_GET_FLAG(FLAGS_num_iterations);
Austin Schuh70cc9552019-01-21 19:46:48 -0800624 solver.options.gradient_tolerance =
625 std::numeric_limits<double>::epsilon();
626 solver.options.parameter_tolerance =
627 std::numeric_limits<double>::epsilon();
Austin Schuh3de38b02024-06-25 18:25:10 -0700628 solver.options.function_tolerance = 0.0;
Austin Schuh70cc9552019-01-21 19:46:48 -0800629
630 Eigen::Matrix<double, num_parameters, 1> x;
631 x = initial_parameters.transpose();
632 typename Solver::Summary summary = solver.Solve(cfa, &x);
633 initial_parameters = x;
634 initial_cost = summary.initial_cost;
635 final_cost = summary.final_cost;
636 delete cost_function;
637 }
638
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800639 const double log_relative_error =
640 ComputeLRE(nist_problem.final_parameters(), initial_parameters);
Austin Schuh70cc9552019-01-21 19:46:48 -0800641 const int kMinNumMatchingDigits = 4;
642 if (log_relative_error > kMinNumMatchingDigits) {
643 ++num_success;
644 }
645
646 printf(
647 "start: %d status: %s lre: %4.1f initial cost: %e final cost:%e "
648 "certified cost: %e\n",
649 start + 1,
650 log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
651 log_relative_error,
652 initial_cost,
653 final_cost,
654 nist_problem.certified_cost());
655 }
656 return num_success;
657}
658
Austin Schuh70cc9552019-01-21 19:46:48 -0800659void SolveNISTProblems() {
Austin Schuh3de38b02024-06-25 18:25:10 -0700660 if (CERES_GET_FLAG(FLAGS_nist_data_dir).empty()) {
Austin Schuh70cc9552019-01-21 19:46:48 -0800661 LOG(FATAL) << "Must specify the directory containing the NIST problems";
662 }
663
Austin Schuh3de38b02024-06-25 18:25:10 -0700664 std::cout << "Lower Difficulty\n";
Austin Schuh70cc9552019-01-21 19:46:48 -0800665 int easy_success = 0;
666 easy_success += RegressionDriver<Misra1a, 2>("Misra1a.dat");
667 easy_success += RegressionDriver<Chwirut, 3>("Chwirut1.dat");
668 easy_success += RegressionDriver<Chwirut, 3>("Chwirut2.dat");
669 easy_success += RegressionDriver<Lanczos, 6>("Lanczos3.dat");
670 easy_success += RegressionDriver<Gauss, 8>("Gauss1.dat");
671 easy_success += RegressionDriver<Gauss, 8>("Gauss2.dat");
672 easy_success += RegressionDriver<DanWood, 2>("DanWood.dat");
673 easy_success += RegressionDriver<Misra1b, 2>("Misra1b.dat");
674
Austin Schuh3de38b02024-06-25 18:25:10 -0700675 std::cout << "\nMedium Difficulty\n";
Austin Schuh70cc9552019-01-21 19:46:48 -0800676 int medium_success = 0;
677 medium_success += RegressionDriver<Kirby2, 5>("Kirby2.dat");
678 medium_success += RegressionDriver<Hahn1, 7>("Hahn1.dat");
679 medium_success += RegressionDriver<Nelson, 3>("Nelson.dat");
680 medium_success += RegressionDriver<MGH17, 5>("MGH17.dat");
681 medium_success += RegressionDriver<Lanczos, 6>("Lanczos1.dat");
682 medium_success += RegressionDriver<Lanczos, 6>("Lanczos2.dat");
683 medium_success += RegressionDriver<Gauss, 8>("Gauss3.dat");
684 medium_success += RegressionDriver<Misra1c, 2>("Misra1c.dat");
685 medium_success += RegressionDriver<Misra1d, 2>("Misra1d.dat");
686 medium_success += RegressionDriver<Roszman1, 4>("Roszman1.dat");
687 medium_success += RegressionDriver<ENSO, 9>("ENSO.dat");
688
Austin Schuh3de38b02024-06-25 18:25:10 -0700689 std::cout << "\nHigher Difficulty\n";
Austin Schuh70cc9552019-01-21 19:46:48 -0800690 int hard_success = 0;
691 hard_success += RegressionDriver<MGH09, 4>("MGH09.dat");
692 hard_success += RegressionDriver<Thurber, 7>("Thurber.dat");
693 hard_success += RegressionDriver<BoxBOD, 2>("BoxBOD.dat");
694 hard_success += RegressionDriver<Rat42, 3>("Rat42.dat");
695 hard_success += RegressionDriver<MGH10, 3>("MGH10.dat");
696 hard_success += RegressionDriver<Eckerle4, 3>("Eckerle4.dat");
697 hard_success += RegressionDriver<Rat43, 4>("Rat43.dat");
698 hard_success += RegressionDriver<Bennet5, 3>("Bennett5.dat");
699
Austin Schuh3de38b02024-06-25 18:25:10 -0700700 std::cout << "\n";
701 std::cout << "Easy : " << easy_success << "/16\n";
702 std::cout << "Medium : " << medium_success << "/22\n";
703 std::cout << "Hard : " << hard_success << "/16\n";
704 std::cout << "Total : " << easy_success + medium_success + hard_success
705 << "/54\n";
Austin Schuh70cc9552019-01-21 19:46:48 -0800706}
707
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800708} // namespace
Austin Schuh3de38b02024-06-25 18:25:10 -0700709} // namespace ceres::examples
Austin Schuh70cc9552019-01-21 19:46:48 -0800710
711int main(int argc, char** argv) {
Austin Schuh1d1e6ea2020-12-23 21:56:30 -0800712 GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true);
Austin Schuh70cc9552019-01-21 19:46:48 -0800713 google::InitGoogleLogging(argv[0]);
714 ceres::examples::SolveNISTProblems();
715 return 0;
716}