// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2017 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal@google.com (Sameer Agarwal)
//
// The National Institute of Standards and Technology has released a
// set of problems to test non-linear least squares solvers.
//
// More information about the background on these problems and
// suggested evaluation methodology can be found at:
//
//   http://www.itl.nist.gov/div898/strd/nls/nls_info.shtml
//
// The problem data themselves can be found at
//
//   http://www.itl.nist.gov/div898/strd/nls/nls_main.shtml
//
// The problems are divided into three levels of difficulty, Easy,
// Medium and Hard. For each problem there are two starting guesses,
// the first one far away from the global minimum and the second
// closer to it.
//
// A problem is considered successfully solved if every component of
// the solution matches the globally optimal solution to at least 4
// digits.
//
// This dataset was used for an evaluation of Non-linear least squares
// solvers:
//
// P. F. Mondragon & B. Borchers, A Comparison of Nonlinear Regression
// Codes, Journal of Modern Applied Statistical Methods, 4(1):343-351,
// 2005.
//
// The results from Mondragon & Borchers can be summarized as
//                  Excel   Gnuplot  GaussFit       HBN   MinPack
// Average LRE        2.3       4.3       4.0       6.8       4.4
// Winner               1         5        12        29        12
//
// where the row Winner counts the number of problems for which the
// solver had the highest LRE.
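//
// Here LRE is the log relative error: for a computed parameter value q
// with certified value q*, LRE = -log10(|q - q*| / |q*|), i.e. roughly
// the number of significant decimal digits in which q agrees with q*. A
// problem's LRE is the minimum of this quantity over its parameters (see
// ComputeLRE below).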

// In this file, we implement the same evaluation methodology using
// Ceres. Currently using Levenberg-Marquardt with DENSE_QR, we get
//
//                  Excel   Gnuplot  GaussFit       HBN   MinPack     Ceres
// Average LRE        2.3       4.3       4.0       6.8       4.4       9.4
// Winner               0         0         5        11         2        41

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>

#include "Eigen/Core"
#include "ceres/ceres.h"
#include "ceres/tiny_solver.h"
#include "ceres/tiny_solver_cost_function_adapter.h"
#include "gflags/gflags.h"
#include "glog/logging.h"

DEFINE_bool(use_tiny_solver, false, "Use TinySolver instead of Ceres::Solver");
DEFINE_string(nist_data_dir, "", "Directory containing the NIST non-linear "
              "regression examples");
DEFINE_string(minimizer, "trust_region",
              "Minimizer type to use, choices are: line_search & trust_region");
DEFINE_string(trust_region_strategy, "levenberg_marquardt",
              "Options are: levenberg_marquardt, dogleg");
DEFINE_string(dogleg, "traditional_dogleg",
              "Options are: traditional_dogleg, subspace_dogleg");
DEFINE_string(linear_solver, "dense_qr", "Options are: "
              "sparse_cholesky, dense_qr, dense_normal_cholesky and "
              "cgnr");
DEFINE_string(preconditioner, "jacobi", "Options are: "
              "identity, jacobi");
DEFINE_string(line_search, "wolfe",
              "Line search algorithm to use, choices are: armijo and wolfe.");
DEFINE_string(line_search_direction, "lbfgs",
              "Line search direction algorithm to use, choices: lbfgs, bfgs");
DEFINE_int32(max_line_search_iterations, 20,
             "Maximum number of iterations for each line search.");
DEFINE_int32(max_line_search_restarts, 10,
             "Maximum number of restarts of line search direction algorithm.");
DEFINE_string(line_search_interpolation, "cubic",
              "Degree of polynomial approximation in line search, "
              "choices are: bisection, quadratic & cubic.");
DEFINE_int32(lbfgs_rank, 20,
             "Rank of L-BFGS inverse Hessian approximation in line search.");
DEFINE_bool(approximate_eigenvalue_bfgs_scaling, false,
            "Use approximate eigenvalue scaling in (L)BFGS line search.");
DEFINE_double(sufficient_decrease, 1.0e-4,
              "Line search Armijo sufficient (function) decrease factor.");
DEFINE_double(sufficient_curvature_decrease, 0.9,
              "Line search Wolfe sufficient curvature decrease factor.");
DEFINE_int32(num_iterations, 10000, "Number of iterations");
DEFINE_bool(nonmonotonic_steps, false, "Trust region algorithm can use"
            " nonmonotonic steps");
DEFINE_double(initial_trust_region_radius, 1e4, "Initial trust region radius");
DEFINE_bool(use_numeric_diff, false,
            "Use numeric differentiation instead of automatic "
            "differentiation.");
DEFINE_string(numeric_diff_method, "ridders", "When using numeric "
              "differentiation, selects algorithm. Options are: central, "
              "forward, ridders.");
DEFINE_double(ridders_step_size, 1e-9, "Initial step size for Ridders "
              "numeric differentiation.");
DEFINE_int32(ridders_extrapolations, 3, "Maximal number of Ridders "
             "extrapolations.");
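
// A typical invocation looks like the following; the binary path and the
// data directory are placeholders, and the NIST .dat files must first be
// downloaded from the URL listed at the top of this file:
//
//   ./bin/nist --nist_data_dir=/path/to/nist/data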

namespace ceres {
namespace examples {

using Eigen::Dynamic;
using Eigen::RowMajor;
typedef Eigen::Matrix<double, Dynamic, 1> Vector;
typedef Eigen::Matrix<double, Dynamic, Dynamic, RowMajor> Matrix;

using std::atof;
using std::atoi;
using std::cout;
using std::ifstream;
using std::string;
using std::vector;

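// Splits |full| on |delim| and appends the non-empty pieces to |result|.
// Runs of consecutive delimiters are treated as a single separator, so
// empty fields are dropped.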
void SplitStringUsingChar(const string& full,
                          const char delim,
                          vector<string>* result) {
  std::back_insert_iterator< vector<string> > it(*result);

  const char* p = full.data();
  const char* end = p + full.size();
  while (p != end) {
    if (*p == delim) {
      ++p;
    } else {
      const char* start = p;
      while (++p != end && *p != delim) {
        // Skip to the next occurrence of the delimiter.
      }
      *it++ = string(start, p - start);
    }
  }
}

bool GetAndSplitLine(ifstream& ifs, vector<string>* pieces) {
  pieces->clear();
  char buf[256];
  ifs.getline(buf, 256);
  SplitStringUsingChar(string(buf), ' ', pieces);
  return true;
}

void SkipLines(ifstream& ifs, int num_lines) {
  char buf[256];
  for (int i = 0; i < num_lines; ++i) {
    ifs.getline(buf, 256);
  }
}

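// Parses a single NIST STRD .dat file. The reader below assumes the
// standard layout of these files: a fixed-size header giving the number
// of responses, predictors, observations and parameters; one line per
// parameter b1, b2, ... with its starting values and certified value;
// the certified residual sum of squares; and finally the data block,
// with the response column(s) followed by the predictor column(s).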
class NISTProblem {
 public:
  explicit NISTProblem(const string& filename) {
    ifstream ifs(filename.c_str(), ifstream::in);
    CHECK(ifs) << "Unable to open : " << filename;

    vector<string> pieces;
    SkipLines(ifs, 24);
    GetAndSplitLine(ifs, &pieces);
    const int kNumResponses = atoi(pieces[1].c_str());

    GetAndSplitLine(ifs, &pieces);
    const int kNumPredictors = atoi(pieces[0].c_str());

    GetAndSplitLine(ifs, &pieces);
    const int kNumObservations = atoi(pieces[0].c_str());

    SkipLines(ifs, 4);
    GetAndSplitLine(ifs, &pieces);
    const int kNumParameters = atoi(pieces[0].c_str());
    SkipLines(ifs, 8);

    // Get the first line of initial and final parameter values to
    // determine the number of tries.
    GetAndSplitLine(ifs, &pieces);
    const int kNumTries = pieces.size() - 4;
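    // A parameter line is expected to look like
    //   bK = <start 1> ... <start kNumTries> <certified value> <std. dev.>
    // which is why the number of starting points is pieces.size() - 4.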

    predictor_.resize(kNumObservations, kNumPredictors);
    response_.resize(kNumObservations, kNumResponses);
    initial_parameters_.resize(kNumTries, kNumParameters);
    final_parameters_.resize(1, kNumParameters);

    // Parse the line for parameter b1.
    int parameter_id = 0;
    for (int i = 0; i < kNumTries; ++i) {
      initial_parameters_(i, parameter_id) = atof(pieces[i + 2].c_str());
    }
    final_parameters_(0, parameter_id) = atof(pieces[2 + kNumTries].c_str());

    // Parse the remaining parameter lines.
    for (int parameter_id = 1; parameter_id < kNumParameters; ++parameter_id) {
      GetAndSplitLine(ifs, &pieces);
      // b2, b3, ....
      for (int i = 0; i < kNumTries; ++i) {
        initial_parameters_(i, parameter_id) = atof(pieces[i + 2].c_str());
      }
      final_parameters_(0, parameter_id) = atof(pieces[2 + kNumTries].c_str());
    }

    // Certified cost
    SkipLines(ifs, 1);
    GetAndSplitLine(ifs, &pieces);
    certified_cost_ = atof(pieces[4].c_str()) / 2.0;

    // Read the observations.
    SkipLines(ifs, 18 - kNumParameters);
    for (int i = 0; i < kNumObservations; ++i) {
      GetAndSplitLine(ifs, &pieces);
      // Response.
      for (int j = 0; j < kNumResponses; ++j) {
        response_(i, j) = atof(pieces[j].c_str());
      }

      // Predictor variables.
      for (int j = 0; j < kNumPredictors; ++j) {
        predictor_(i, j) = atof(pieces[j + kNumResponses].c_str());
      }
    }
  }

  Matrix initial_parameters(int start) const { return initial_parameters_.row(start); }  // NOLINT
  Matrix final_parameters() const { return final_parameters_; }
  Matrix predictor() const { return predictor_; }
  Matrix response() const { return response_; }
  int predictor_size() const { return predictor_.cols(); }
  int num_observations() const { return predictor_.rows(); }
  int response_size() const { return response_.cols(); }
  int num_parameters() const { return initial_parameters_.cols(); }
  int num_starts() const { return initial_parameters_.rows(); }
  double certified_cost() const { return certified_cost_; }

 private:
  Matrix predictor_;
  Matrix response_;
  Matrix initial_parameters_;
  Matrix final_parameters_;
  double certified_cost_;
};

#define NIST_BEGIN(CostFunctionName) \
  struct CostFunctionName { \
    CostFunctionName(const double* const x, \
                     const double* const y, \
                     const int n) \
        : x_(x), y_(y), n_(n) {} \
    const double* x_; \
    const double* y_; \
    const int n_; \
    template <typename T> \
    bool operator()(const T* const b, T* residual) const { \
      for (int i = 0; i < n_; ++i) { \
        const T x(x_[i]); \
        residual[i] = y_[i] - (

#define NIST_END ); } return true; }};
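
// For reference, a NIST_BEGIN/NIST_END pair expands into a functor whose
// residuals are (data - model). For example, the Misra1a block further
// below expands to roughly:
//
//   struct Misra1a {
//     Misra1a(const double* const x, const double* const y, const int n)
//         : x_(x), y_(y), n_(n) {}
//     const double* x_;
//     const double* y_;
//     const int n_;
//     template <typename T>
//     bool operator()(const T* const b, T* residual) const {
//       for (int i = 0; i < n_; ++i) {
//         const T x(x_[i]);
//         residual[i] = y_[i] - (b[0] * (1.0 - exp(-b[1] * x)));
//       }
//       return true;
//     }
//   };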

// y = b1 * (b2+x)**(-1/b3) + e
NIST_BEGIN(Bennet5)
  b[0] * pow(b[1] + x, -1.0 / b[2])
NIST_END

// y = b1*(1-exp[-b2*x]) + e
NIST_BEGIN(BoxBOD)
  b[0] * (1.0 - exp(-b[1] * x))
NIST_END

// y = exp[-b1*x]/(b2+b3*x) + e
NIST_BEGIN(Chwirut)
  exp(-b[0] * x) / (b[1] + b[2] * x)
NIST_END

// y = b1*x**b2 + e
NIST_BEGIN(DanWood)
  b[0] * pow(x, b[1])
NIST_END

// y = b1*exp( -b2*x ) + b3*exp( -(x-b4)**2 / b5**2 )
//     + b6*exp( -(x-b7)**2 / b8**2 ) + e
NIST_BEGIN(Gauss)
  b[0] * exp(-b[1] * x) +
  b[2] * exp(-pow((x - b[3])/b[4], 2)) +
  b[5] * exp(-pow((x - b[6])/b[7], 2))
NIST_END

// y = b1*exp(-b2*x) + b3*exp(-b4*x) + b5*exp(-b6*x) + e
NIST_BEGIN(Lanczos)
  b[0] * exp(-b[1] * x) + b[2] * exp(-b[3] * x) + b[4] * exp(-b[5] * x)
NIST_END

// y = (b1+b2*x+b3*x**2+b4*x**3) /
//     (1+b5*x+b6*x**2+b7*x**3) + e
NIST_BEGIN(Hahn1)
  (b[0] + b[1] * x + b[2] * x * x + b[3] * x * x * x) /
  (1.0 + b[4] * x + b[5] * x * x + b[6] * x * x * x)
NIST_END

// y = (b1 + b2*x + b3*x**2) /
//     (1 + b4*x + b5*x**2) + e
NIST_BEGIN(Kirby2)
  (b[0] + b[1] * x + b[2] * x * x) /
  (1.0 + b[3] * x + b[4] * x * x)
NIST_END

// y = b1*(x**2+x*b2) / (x**2+x*b3+b4) + e
NIST_BEGIN(MGH09)
  b[0] * (x * x + x * b[1]) / (x * x + x * b[2] + b[3])
NIST_END

// y = b1 * exp[b2/(x+b3)] + e
NIST_BEGIN(MGH10)
  b[0] * exp(b[1] / (x + b[2]))
NIST_END

// y = b1 + b2*exp[-x*b4] + b3*exp[-x*b5] + e
NIST_BEGIN(MGH17)
  b[0] + b[1] * exp(-x * b[3]) + b[2] * exp(-x * b[4])
NIST_END

// y = b1*(1-exp[-b2*x]) + e
NIST_BEGIN(Misra1a)
  b[0] * (1.0 - exp(-b[1] * x))
NIST_END

// y = b1 * (1-(1+b2*x/2)**(-2)) + e
NIST_BEGIN(Misra1b)
  b[0] * (1.0 - 1.0/ ((1.0 + b[1] * x / 2.0) * (1.0 + b[1] * x / 2.0)))  // NOLINT
NIST_END

// y = b1 * (1-(1+2*b2*x)**(-.5)) + e
NIST_BEGIN(Misra1c)
  b[0] * (1.0 - pow(1.0 + 2.0 * b[1] * x, -0.5))
NIST_END

// y = b1*b2*x*((1+b2*x)**(-1)) + e
NIST_BEGIN(Misra1d)
  b[0] * b[1] * x / (1.0 + b[1] * x)
NIST_END

const double kPi = 3.141592653589793238462643383279;
// pi = 3.141592653589793238462643383279E0
// y = b1 - b2*x - arctan[b3/(x-b4)]/pi + e
NIST_BEGIN(Roszman1)
  b[0] - b[1] * x - atan2(b[2], (x - b[3])) / kPi
NIST_END

// y = b1 / (1+exp[b2-b3*x]) + e
NIST_BEGIN(Rat42)
  b[0] / (1.0 + exp(b[1] - b[2] * x))
NIST_END

// y = b1 / ((1+exp[b2-b3*x])**(1/b4)) + e
NIST_BEGIN(Rat43)
  b[0] / pow(1.0 + exp(b[1] - b[2] * x), 1.0 / b[3])
NIST_END

// y = (b1 + b2*x + b3*x**2 + b4*x**3) /
//     (1 + b5*x + b6*x**2 + b7*x**3) + e
NIST_BEGIN(Thurber)
  (b[0] + b[1] * x + b[2] * x * x + b[3] * x * x * x) /
  (1.0 + b[4] * x + b[5] * x * x + b[6] * x * x * x)
NIST_END

// y = b1 + b2*cos( 2*pi*x/12 ) + b3*sin( 2*pi*x/12 )
//     + b5*cos( 2*pi*x/b4 ) + b6*sin( 2*pi*x/b4 )
//     + b8*cos( 2*pi*x/b7 ) + b9*sin( 2*pi*x/b7 ) + e
NIST_BEGIN(ENSO)
  b[0] + b[1] * cos(2.0 * kPi * x / 12.0) +
  b[2] * sin(2.0 * kPi * x / 12.0) +
  b[4] * cos(2.0 * kPi * x / b[3]) +
  b[5] * sin(2.0 * kPi * x / b[3]) +
  b[7] * cos(2.0 * kPi * x / b[6]) +
  b[8] * sin(2.0 * kPi * x / b[6])
NIST_END

// y = (b1/b2) * exp[-0.5*((x-b3)/b2)**2] + e
NIST_BEGIN(Eckerle4)
  b[0] / b[1] * exp(-0.5 * pow((x - b[2])/b[1], 2))
NIST_END

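// Nelson is written out by hand rather than with NIST_BEGIN/NIST_END
// because it has two predictor variables per observation and a
// log-transformed response, so it does not fit the single-predictor
// pattern generated by the macros above.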
struct Nelson {
 public:
  Nelson(const double* const x, const double* const y, const int n)
      : x_(x), y_(y), n_(n) {}

  template <typename T>
  bool operator()(const T* const b, T* residual) const {
    // log[y] = b1 - b2*x1 * exp[-b3*x2] + e
    for (int i = 0; i < n_; ++i) {
      residual[i] = log(y_[i]) - (b[0] - b[1] * x_[2 * i] * exp(-b[2] * x_[2 * i + 1]));
    }
    return true;
  }

 private:
  const double* x_;
  const double* y_;
  const int n_;
};

static void SetNumericDiffOptions(ceres::NumericDiffOptions* options) {
  options->max_num_ridders_extrapolations = FLAGS_ridders_extrapolations;
  options->ridders_relative_initial_step_size = FLAGS_ridders_step_size;
}

void SetMinimizerOptions(ceres::Solver::Options* options) {
  CHECK(
      ceres::StringToMinimizerType(FLAGS_minimizer, &options->minimizer_type));
  CHECK(ceres::StringToLinearSolverType(FLAGS_linear_solver,
                                        &options->linear_solver_type));
  CHECK(ceres::StringToPreconditionerType(FLAGS_preconditioner,
                                          &options->preconditioner_type));
  CHECK(ceres::StringToTrustRegionStrategyType(
      FLAGS_trust_region_strategy, &options->trust_region_strategy_type));
  CHECK(ceres::StringToDoglegType(FLAGS_dogleg, &options->dogleg_type));
  CHECK(ceres::StringToLineSearchDirectionType(
      FLAGS_line_search_direction, &options->line_search_direction_type));
  CHECK(ceres::StringToLineSearchType(FLAGS_line_search,
                                      &options->line_search_type));
  CHECK(ceres::StringToLineSearchInterpolationType(
      FLAGS_line_search_interpolation,
      &options->line_search_interpolation_type));

  options->max_num_iterations = FLAGS_num_iterations;
  options->use_nonmonotonic_steps = FLAGS_nonmonotonic_steps;
  options->initial_trust_region_radius = FLAGS_initial_trust_region_radius;
  options->max_lbfgs_rank = FLAGS_lbfgs_rank;
  options->line_search_sufficient_function_decrease = FLAGS_sufficient_decrease;
  options->line_search_sufficient_curvature_decrease =
      FLAGS_sufficient_curvature_decrease;
  options->max_num_line_search_step_size_iterations =
      FLAGS_max_line_search_iterations;
  options->max_num_line_search_direction_restarts =
      FLAGS_max_line_search_restarts;
  options->use_approximate_eigenvalue_bfgs_scaling =
      FLAGS_approximate_eigenvalue_bfgs_scaling;
  options->function_tolerance = std::numeric_limits<double>::epsilon();
  options->gradient_tolerance = std::numeric_limits<double>::epsilon();
  options->parameter_tolerance = std::numeric_limits<double>::epsilon();
}

string JoinPath(const string& dirname, const string& basename) {
#ifdef _WIN32
  static const char separator = '\\';
#else
  static const char separator = '/';
#endif  // _WIN32

  if ((!basename.empty() && basename[0] == separator) || dirname.empty()) {
    return basename;
  } else if (dirname[dirname.size() - 1] == separator) {
    return dirname + basename;
  } else {
    return dirname + string(&separator, 1) + basename;
  }
}

template <typename Model, int num_parameters>
CostFunction* CreateCostFunction(const Matrix& predictor,
                                 const Matrix& response,
                                 const int num_observations) {
  Model* model =
      new Model(predictor.data(), response.data(), num_observations);
  ceres::CostFunction* cost_function = NULL;
  if (FLAGS_use_numeric_diff) {
    ceres::NumericDiffOptions options;
    SetNumericDiffOptions(&options);
    if (FLAGS_numeric_diff_method == "central") {
      cost_function = new NumericDiffCostFunction<Model,
                                                  ceres::CENTRAL,
                                                  ceres::DYNAMIC,
                                                  num_parameters>(
          model,
          ceres::TAKE_OWNERSHIP,
          num_observations,
          options);
    } else if (FLAGS_numeric_diff_method == "forward") {
      cost_function = new NumericDiffCostFunction<Model,
                                                  ceres::FORWARD,
                                                  ceres::DYNAMIC,
                                                  num_parameters>(
          model,
          ceres::TAKE_OWNERSHIP,
          num_observations,
          options);
    } else if (FLAGS_numeric_diff_method == "ridders") {
      cost_function = new NumericDiffCostFunction<Model,
                                                  ceres::RIDDERS,
                                                  ceres::DYNAMIC,
                                                  num_parameters>(
          model,
          ceres::TAKE_OWNERSHIP,
          num_observations,
          options);
    } else {
      LOG(ERROR) << "Invalid numeric diff method specified";
      return 0;
    }
  } else {
    cost_function =
        new ceres::AutoDiffCostFunction<Model, ceres::DYNAMIC, num_parameters>(
            model, num_observations);
  }
  return cost_function;
}

double ComputeLRE(const Matrix& expected, const Matrix& actual) {
  // Compute the LRE by comparing each component of the solution
  // with the ground truth, and taking the minimum.
  const double kMaxNumSignificantDigits = 11;
  double log_relative_error = kMaxNumSignificantDigits + 1;
  for (int i = 0; i < expected.cols(); ++i) {
    const double tmp_lre = -std::log10(std::fabs(expected(i) - actual(i)) /
                                       std::fabs(expected(i)));
    // The maximum LRE is capped at 11 - the precision at which the
    // ground truth is known.
    //
    // The minimum LRE is capped at 0 - no digits match between the
    // computed solution and the ground truth.
    log_relative_error =
        std::min(log_relative_error,
                 std::max(0.0, std::min(kMaxNumSignificantDigits, tmp_lre)));
  }
  return log_relative_error;
}

template <typename Model, int num_parameters>
int RegressionDriver(const string& filename) {
  NISTProblem nist_problem(JoinPath(FLAGS_nist_data_dir, filename));
  CHECK_EQ(num_parameters, nist_problem.num_parameters());

  Matrix predictor = nist_problem.predictor();
  Matrix response = nist_problem.response();
  Matrix final_parameters = nist_problem.final_parameters();

  printf("%s\n", filename.c_str());

  // Each NIST problem comes with multiple starting points, so we
  // construct the problem from scratch for each case and solve it.
  int num_success = 0;
  for (int start = 0; start < nist_problem.num_starts(); ++start) {
    Matrix initial_parameters = nist_problem.initial_parameters(start);
    ceres::CostFunction* cost_function = CreateCostFunction<Model, num_parameters>(
        predictor, response, nist_problem.num_observations());

    double initial_cost;
    double final_cost;

    if (!FLAGS_use_tiny_solver) {
      ceres::Problem problem;
      problem.AddResidualBlock(cost_function, NULL, initial_parameters.data());
      ceres::Solver::Summary summary;
      ceres::Solver::Options options;
      SetMinimizerOptions(&options);
      Solve(options, &problem, &summary);
      initial_cost = summary.initial_cost;
      final_cost = summary.final_cost;
    } else {
      ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters> cfa(
          *cost_function);
      typedef ceres::TinySolver<
          ceres::TinySolverCostFunctionAdapter<Eigen::Dynamic, num_parameters> >
          Solver;
      Solver solver;
      solver.options.max_num_iterations = FLAGS_num_iterations;
      solver.options.gradient_tolerance =
          std::numeric_limits<double>::epsilon();
      solver.options.parameter_tolerance =
          std::numeric_limits<double>::epsilon();

      Eigen::Matrix<double, num_parameters, 1> x;
      x = initial_parameters.transpose();
      typename Solver::Summary summary = solver.Solve(cfa, &x);
      initial_parameters = x;
      initial_cost = summary.initial_cost;
      final_cost = summary.final_cost;
      delete cost_function;
    }

    const double log_relative_error = ComputeLRE(nist_problem.final_parameters(),
                                                 initial_parameters);
    const int kMinNumMatchingDigits = 4;
    if (log_relative_error > kMinNumMatchingDigits) {
      ++num_success;
    }

    printf(
        "start: %d status: %s lre: %4.1f initial cost: %e final cost: %e "
        "certified cost: %e\n",
        start + 1,
        log_relative_error < kMinNumMatchingDigits ? "FAILURE" : "SUCCESS",
        log_relative_error,
        initial_cost,
        final_cost,
        nist_problem.certified_cost());
  }
  return num_success;
}


void SolveNISTProblems() {
  if (FLAGS_nist_data_dir.empty()) {
    LOG(FATAL) << "Must specify the directory containing the NIST problems";
  }

  cout << "Lower Difficulty\n";
  int easy_success = 0;
  easy_success += RegressionDriver<Misra1a, 2>("Misra1a.dat");
  easy_success += RegressionDriver<Chwirut, 3>("Chwirut1.dat");
  easy_success += RegressionDriver<Chwirut, 3>("Chwirut2.dat");
  easy_success += RegressionDriver<Lanczos, 6>("Lanczos3.dat");
  easy_success += RegressionDriver<Gauss, 8>("Gauss1.dat");
  easy_success += RegressionDriver<Gauss, 8>("Gauss2.dat");
  easy_success += RegressionDriver<DanWood, 2>("DanWood.dat");
  easy_success += RegressionDriver<Misra1b, 2>("Misra1b.dat");

  cout << "\nMedium Difficulty\n";
  int medium_success = 0;
  medium_success += RegressionDriver<Kirby2, 5>("Kirby2.dat");
  medium_success += RegressionDriver<Hahn1, 7>("Hahn1.dat");
  medium_success += RegressionDriver<Nelson, 3>("Nelson.dat");
  medium_success += RegressionDriver<MGH17, 5>("MGH17.dat");
  medium_success += RegressionDriver<Lanczos, 6>("Lanczos1.dat");
  medium_success += RegressionDriver<Lanczos, 6>("Lanczos2.dat");
  medium_success += RegressionDriver<Gauss, 8>("Gauss3.dat");
  medium_success += RegressionDriver<Misra1c, 2>("Misra1c.dat");
  medium_success += RegressionDriver<Misra1d, 2>("Misra1d.dat");
  medium_success += RegressionDriver<Roszman1, 4>("Roszman1.dat");
  medium_success += RegressionDriver<ENSO, 9>("ENSO.dat");

  cout << "\nHigher Difficulty\n";
  int hard_success = 0;
  hard_success += RegressionDriver<MGH09, 4>("MGH09.dat");
  hard_success += RegressionDriver<Thurber, 7>("Thurber.dat");
  hard_success += RegressionDriver<BoxBOD, 2>("BoxBOD.dat");
  hard_success += RegressionDriver<Rat42, 3>("Rat42.dat");
  hard_success += RegressionDriver<MGH10, 3>("MGH10.dat");
  hard_success += RegressionDriver<Eckerle4, 3>("Eckerle4.dat");
  hard_success += RegressionDriver<Rat43, 4>("Rat43.dat");
  hard_success += RegressionDriver<Bennet5, 3>("Bennett5.dat");

  cout << "\n";
  cout << "Easy    : " << easy_success << "/16\n";
  cout << "Medium  : " << medium_success << "/22\n";
  cout << "Hard    : " << hard_success << "/16\n";
  cout << "Total   : " << easy_success + medium_success + hard_success
       << "/54\n";
}

}  // namespace examples
}  // namespace ceres

int main(int argc, char** argv) {
  CERES_GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true);
  google::InitGoogleLogging(argv[0]);
  ceres::examples::SolveNISTProblems();
  return 0;
}