blob: cc93d73c6ed0044b48c672417d05b213c51ea5da [file] [log] [blame]
Austin Schuh70cc9552019-01-21 19:46:48 -08001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2017 Google Inc. All rights reserved.
3// http://ceres-solver.org/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: mierle@gmail.com (Keir Mierle)
30//
31// WARNING WARNING WARNING
32// WARNING WARNING WARNING Tiny solver is experimental and will change.
33// WARNING WARNING WARNING
34
35#ifndef CERES_PUBLIC_TINY_SOLVER_AUTODIFF_FUNCTION_H_
36#define CERES_PUBLIC_TINY_SOLVER_AUTODIFF_FUNCTION_H_
37
38#include <memory>
39#include <type_traits>
40#include "Eigen/Core"
41
42#include "ceres/jet.h"
43#include "ceres/types.h" // For kImpossibleValue.
44
45namespace ceres {
46
47// An adapter around autodiff-style CostFunctors to enable easier use of
48// TinySolver. See the example below showing how to use it:
49//
50// // Example for cost functor with static residual size.
51// // Same as an autodiff cost functor, but taking only 1 parameter.
52// struct MyFunctor {
53// template<typename T>
54// bool operator()(const T* const parameters, T* residuals) const {
55// const T& x = parameters[0];
56// const T& y = parameters[1];
57// const T& z = parameters[2];
58// residuals[0] = x + 2.*y + 4.*z;
59// residuals[1] = y * z;
60// return true;
61// }
62// };
63//
64// typedef TinySolverAutoDiffFunction<MyFunctor, 2, 3>
65// AutoDiffFunction;
66//
67// MyFunctor my_functor;
68// AutoDiffFunction f(my_functor);
69//
70// Vec3 x = ...;
71// TinySolver<AutoDiffFunction> solver;
72// solver.Solve(f, &x);
73//
74// // Example for cost functor with dynamic residual size.
75// // NumResiduals() supplies dynamic size of residuals.
76// // Same functionality as in tiny_solver.h but with autodiff.
77// struct MyFunctorWithDynamicResiduals {
78// int NumResiduals() const {
79// return 2;
80// }
81//
82// template<typename T>
83// bool operator()(const T* const parameters, T* residuals) const {
84// const T& x = parameters[0];
85// const T& y = parameters[1];
86// const T& z = parameters[2];
87// residuals[0] = x + static_cast<T>(2.)*y + static_cast<T>(4.)*z;
88// residuals[1] = y * z;
89// return true;
90// }
91// };
92//
93// typedef TinySolverAutoDiffFunction<MyFunctorWithDynamicResiduals,
94// Eigen::Dynamic,
95// 3>
96// AutoDiffFunctionWithDynamicResiduals;
97//
98// MyFunctorWithDynamicResiduals my_functor_dyn;
99// AutoDiffFunctionWithDynamicResiduals f(my_functor_dyn);
100//
101// Vec3 x = ...;
102// TinySolver<AutoDiffFunctionWithDynamicResiduals> solver;
103// solver.Solve(f, &x);
104//
105// WARNING: The cost function adapter is not thread safe.
106template<typename CostFunctor,
107 int kNumResiduals,
108 int kNumParameters,
109 typename T = double>
110class TinySolverAutoDiffFunction {
111 public:
112 TinySolverAutoDiffFunction(const CostFunctor& cost_functor)
113 : cost_functor_(cost_functor) {
114 Initialize<kNumResiduals>(cost_functor);
115 }
116
117 typedef T Scalar;
118 enum {
119 NUM_PARAMETERS = kNumParameters,
120 NUM_RESIDUALS = kNumResiduals,
121 };
122
123 // This is similar to AutoDifferentiate(), but since there is only one
124 // parameter block it is easier to inline to avoid overhead.
125 bool operator()(const T* parameters,
126 T* residuals,
127 T* jacobian) const {
128 if (jacobian == NULL) {
129 // No jacobian requested, so just directly call the cost function with
130 // doubles, skipping jets and derivatives.
131 return cost_functor_(parameters, residuals);
132 }
133 // Initialize the input jets with passed parameters.
134 for (int i = 0; i < kNumParameters; ++i) {
135 jet_parameters_[i].a = parameters[i]; // Scalar part.
136 jet_parameters_[i].v.setZero(); // Derivative part.
137 jet_parameters_[i].v[i] = T(1.0);
138 }
139
140 // Initialize the output jets such that we can detect user errors.
141 for (int i = 0; i < num_residuals_; ++i) {
142 jet_residuals_[i].a = kImpossibleValue;
143 jet_residuals_[i].v.setConstant(kImpossibleValue);
144 }
145
146 // Execute the cost function, but with jets to find the derivative.
147 if (!cost_functor_(jet_parameters_, jet_residuals_.data())) {
148 return false;
149 }
150
151 // Copy the jacobian out of the derivative part of the residual jets.
152 Eigen::Map<Eigen::Matrix<T, kNumResiduals, kNumParameters>> jacobian_matrix(
153 jacobian,
154 num_residuals_,
155 kNumParameters);
156 for (int r = 0; r < num_residuals_; ++r) {
157 residuals[r] = jet_residuals_[r].a;
158 // Note that while this looks like a fast vectorized write, in practice it
159 // unfortunately thrashes the cache since the writes to the column-major
160 // jacobian are strided (e.g. rows are non-contiguous).
161 jacobian_matrix.row(r) = jet_residuals_[r].v;
162 }
163 return true;
164 }
165
166 int NumResiduals() const {
167 return num_residuals_; // Set by Initialize.
168 }
169
170 private:
171 const CostFunctor& cost_functor_;
172
173 // The number of residuals at runtime.
174 // This will be overriden if NUM_RESIDUALS == Eigen::Dynamic.
175 int num_residuals_ = kNumResiduals;
176
177 // To evaluate the cost function with jets, temporary storage is needed. These
178 // are the buffers that are used during evaluation; parameters for the input,
179 // and jet_residuals_ are where the final cost and derivatives end up.
180 //
181 // Since this buffer is used for evaluation, the adapter is not thread safe.
182 using JetType = Jet<T, kNumParameters>;
183 mutable JetType jet_parameters_[kNumParameters];
184 // Eigen::Matrix serves as static or dynamic container.
185 mutable Eigen::Matrix<JetType, kNumResiduals, 1> jet_residuals_;
186
187 // The number of residuals is dynamically sized and the number of
188 // parameters is statically sized.
189 template<int R>
190 typename std::enable_if<(R == Eigen::Dynamic), void>::type Initialize(
191 const CostFunctor& function) {
192 jet_residuals_.resize(function.NumResiduals());
193 num_residuals_ = function.NumResiduals();
194 }
195
196 // The number of parameters and residuals are statically sized.
197 template<int R>
198 typename std::enable_if<(R != Eigen::Dynamic), void>::type Initialize(
199 const CostFunctor& /* function */) {
200 num_residuals_ = kNumResiduals;
201 }
202};
203
204} // namespace ceres
205
206#endif // CERES_PUBLIC_TINY_SOLVER_AUTODIFF_FUNCTION_H_