// Primal-dual interior point convex solver (see Boyd & Vandenberghe,
// "Convex Optimization", section 11.7).
#include <sys/types.h>
#include <unistd.h>

#include <cmath>
#include <cstddef>
#include <iomanip>
#include <string>
#include <string_view>
#include <utility>

#include <Eigen/Dense>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "glog/logging.h"
10namespace frc971 {
11namespace solvers {
12
13// TODO(austin): Steal JET from Ceres to generate the derivatives easily and
14// quickly?
15//
16// States is the number of inputs to the optimization problem.
17// M is the number of inequality constraints.
18// N is the number of equality constraints.
19template <size_t States, size_t M, size_t N>
20class ConvexProblem {
21 public:
22 // Returns the function to minimize and it's derivatives.
23 virtual double f0(
24 Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
25 virtual Eigen::Matrix<double, States, 1> df0(
26 Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
27 virtual Eigen::Matrix<double, States, States> ddf0(
28 Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
29
30 // Returns the constraints f(X) < 0, and their derivative.
31 virtual Eigen::Matrix<double, M, 1> f(
32 Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
33 virtual Eigen::Matrix<double, M, States> df(
34 Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
35
36 // Returns the equality constraints of the form A x = b
37 virtual Eigen::Matrix<double, N, States> A() const = 0;
38 virtual Eigen::Matrix<double, N, 1> b() const = 0;
39};
40
41// Implements a Primal-Dual Interior point method convex solver.
42// See 11.7 of https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf
43//
44// States is the number of inputs to the optimization problem.
45// M is the number of inequality constraints.
46// N is the number of equality constraints.
47template <size_t States, size_t M, size_t N>
48class Solver {
49 public:
50 // Ratio to require the cost to decrease when line searching.
51 static constexpr double kAlpha = 0.05;
52 // Line search step parameter.
53 static constexpr double kBeta = 0.5;
54 static constexpr double kMu = 2.0;
55 // Terminal condition for the primal problem (equality constraints) and dual
56 // (gradient + inequality constraints).
57 static constexpr double kEpsilonF = 1e-6;
58 // Terminal condition for nu, the surrogate duality gap.
59 static constexpr double kEpsilon = 1e-6;
60
61 // Solves the problem given a feasible initial solution.
62 Eigen::Matrix<double, States, 1> Solve(
63 const ConvexProblem<States, M, N> &problem,
64 Eigen::Ref<const Eigen::Matrix<double, States, 1>> X_initial);
65
66 private:
67 // Class to hold all the derivataves and function evaluations.
68 struct Derivatives {
69 Eigen::Matrix<double, States, 1> gradient;
70 Eigen::Matrix<double, States, States> hessian;
71
72 // Inequality function f
73 Eigen::Matrix<double, M, 1> f;
74 // df
75 Eigen::Matrix<double, M, States> df;
76
77 // ddf is assumed to be 0 because for the linear constraint distance
78 // function we are using, it is actually 0, and by assuming it is zero
79 // rather than passing it through as 0 to the solver, we can save enough CPU
80 // to make it worth it.
81
82 // A
83 Eigen::Matrix<double, N, States> A;
84 // Ax - b
85 Eigen::Matrix<double, N, 1> Axmb;
86 };
87
88 // Computes all the values for the given problem at the given state.
89 Derivatives ComputeDerivative(
90 const ConvexProblem<States, M, N> &problem,
91 const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y);
92
93 // Computes Rt at the given state and with the given t_inverse. See 11.53 of
94 // cvxbook.pdf.
95 Eigen::Matrix<double, States + M + N, 1> Rt(
96 const Derivatives &derivatives,
97 Eigen::Matrix<double, States + M + N, 1> y, double t_inverse);
98
99 // Prints out all the derivatives with VLOG at the provided verbosity.
100 void PrintDerivatives(
101 const Derivatives &derivatives,
102 const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y,
103 std::string_view prefix, int verbosity);
104};
105
106template <size_t States, size_t M, size_t N>
107Eigen::Matrix<double, States + M + N, 1> Solver<States, M, N>::Rt(
108 const Derivatives &derivatives, Eigen::Matrix<double, States + M + N, 1> y,
109 double t_inverse) {
110 Eigen::Matrix<double, States + M + N, 1> result;
111
112 Eigen::Ref<Eigen::Matrix<double, States, 1>> r_dual =
113 result.template block<States, 1>(0, 0);
114 Eigen::Ref<Eigen::Matrix<double, M, 1>> r_cent =
115 result.template block<M, 1>(States, 0);
116 Eigen::Ref<Eigen::Matrix<double, N, 1>> r_pri =
117 result.template block<N, 1>(States + M, 0);
118
119 Eigen::Ref<const Eigen::Matrix<double, M, 1>> lambda =
120 y.template block<M, 1>(States, 0);
121 Eigen::Ref<const Eigen::Matrix<double, N, 1>> v =
122 y.template block<N, 1>(States + M, 0);
123
124 r_dual = derivatives.gradient + derivatives.df.transpose() * lambda +
125 derivatives.A.transpose() * v;
126 r_cent = -(Eigen::DiagonalMatrix<double, M>(lambda) * derivatives.f +
127 t_inverse * Eigen::Matrix<double, M, 1>::Ones());
128 r_pri = derivatives.Axmb;
129
130 return result;
131}
132
133template <size_t States, size_t M, size_t N>
134Eigen::Matrix<double, States, 1> Solver<States, M, N>::Solve(
135 const ConvexProblem<States, M, N> &problem,
136 Eigen::Ref<const Eigen::Matrix<double, States, 1>> X_initial) {
137 const Eigen::IOFormat kHeavyFormat(Eigen::StreamPrecision, 0, ", ",
138 ",\n "
139 " ",
140 "[", "]", "[", "]");
141
142 Eigen::Matrix<double, States + M + N, 1> y =
143 Eigen::Matrix<double, States + M + N, 1>::Constant(1.0);
144 y.template block<States, 1>(0, 0) = X_initial;
145
146 Derivatives derivatives = ComputeDerivative(problem, y);
147
148 for (size_t i = 0; i < M; ++i) {
149 CHECK_LE(derivatives.f(i, 0), 0.0)
150 << ": Initial state " << X_initial.transpose().format(kHeavyFormat)
151 << " not feasible";
152 }
153
154 PrintDerivatives(derivatives, y, "", 1);
155
156 size_t iteration = 0;
157 while (true) {
158 // Solve for the primal-dual search direction by solving the newton step.
159 Eigen::Ref<const Eigen::Matrix<double, M, 1>> lambda =
160 y.template block<M, 1>(States, 0);
161
162 const double nu = -(derivatives.f.transpose() * lambda)(0, 0);
163 const double t_inverse = nu / (kMu * lambda.rows());
164 Eigen::Matrix<double, States + M + N, 1> rt_orig =
165 Rt(derivatives, y, t_inverse);
166
167 Eigen::Matrix<double, States + M + N, States + M + N> m1;
168 m1.setZero();
169 m1.template block<States, States>(0, 0) = derivatives.hessian;
170 m1.template block<States, M>(0, States) = derivatives.df.transpose();
171 m1.template block<States, N>(0, States + M) = derivatives.A.transpose();
172 m1.template block<M, States>(States, 0) =
173 -(Eigen::DiagonalMatrix<double, M>(lambda) * derivatives.df);
174 m1.template block<M, M>(States, States) -=
175 Eigen::DiagonalMatrix<double, M>(derivatives.f);
176 m1.template block<N, States>(States + M, 0) = derivatives.A;
177
178 Eigen::Matrix<double, States + M + N, 1> dy =
179 m1.colPivHouseholderQr().solve(-rt_orig);
180
181 Eigen::Ref<Eigen::Matrix<double, M, 1>> dlambda =
182 dy.template block<M, 1>(States, 0);
183
184 double s = 1.0;
185
186 // Now, time to do line search.
187 //
188 // Start by keeping lambda positive. Make sure our step doesn't let
189 // lambda cross 0.
190 for (int i = 0; i < dlambda.rows(); ++i) {
191 if (lambda(i) + s * dlambda(i) < 0.0) {
192 // Ignore tiny steps in lambda. They cause issues when we get really
193 // close to having our constraints met but haven't converged the rest
194 // of the problem and start to run into rounding issues in the matrix
195 // solve portion.
196 if (dlambda(i) < 0.0 && dlambda(i) > -1e-12) {
197 VLOG(1) << " lambda(" << i << ") " << lambda(i) << " + " << s
198 << " * " << dlambda(i) << " -> s would be now "
199 << -lambda(i) / dlambda(i);
200 dlambda(i) = 0.0;
201 VLOG(1) << " dy -> " << std::setprecision(12) << std::fixed
202 << std::setfill(' ') << dy.transpose().format(kHeavyFormat);
203 continue;
204 }
205 VLOG(1) << " lambda(" << i << ") " << lambda(i) << " + " << s << " * "
206 << dlambda(i) << " -> s now " << -lambda(i) / dlambda(i);
207 s = -lambda(i) / dlambda(i);
208 }
209 }
210
211 VLOG(1) << " After lambda line search, s is " << s;
212
213 VLOG(3) << " Initial step " << iteration << " -> " << std::setprecision(12)
214 << std::fixed << std::setfill(' ')
215 << dy.transpose().format(kHeavyFormat);
216 VLOG(3) << " rt -> "
217 << std::setprecision(12) << std::fixed << std::setfill(' ')
218 << rt_orig.transpose().format(kHeavyFormat);
219
220 const double rt_orig_squared_norm = rt_orig.squaredNorm();
221
222 Eigen::Matrix<double, States + M + N, 1> next_y;
223 Eigen::Matrix<double, States + M + N, 1> rt;
224 Derivatives next_derivatives;
225 while (true) {
226 next_y = y + s * dy;
227 next_derivatives = ComputeDerivative(problem, next_y);
228 rt = Rt(next_derivatives, next_y, t_inverse);
229
230 const Eigen::Ref<const Eigen::VectorXd> next_x =
231 next_y.block(0, 0, next_derivatives.hessian.rows(), 1);
232 const Eigen::Ref<const Eigen::VectorXd> next_lambda =
233 next_y.block(next_x.rows(), 0, next_derivatives.f.rows(), 1);
234
235 const Eigen::Ref<const Eigen::VectorXd> next_v = next_y.block(
236 next_x.rows() + next_lambda.rows(), 0, next_derivatives.A.rows(), 1);
237
238 VLOG(1) << " next_rt(" << iteration << ") is " << rt.norm() << " -> "
239 << std::setprecision(12) << std::fixed << std::setfill(' ')
240 << rt.transpose().format(kHeavyFormat);
241
242 PrintDerivatives(next_derivatives, next_y, "next_", 3);
243
244 if (next_derivatives.f.maxCoeff() > 0.0) {
245 VLOG(1) << " f_next > 0.0 -> " << next_derivatives.f.maxCoeff()
246 << ", continuing line search.";
247 s *= kBeta;
248 } else if (next_derivatives.Axmb.squaredNorm() < 0.1 &&
249 rt.squaredNorm() >
250 std::pow(1.0 - kAlpha * s, 2.0) * rt_orig_squared_norm) {
251 VLOG(1) << " |Rt| > |Rt+1| " << rt.norm() << " > " << rt_orig.norm()
252 << ", drt -> " << std::setprecision(12) << std::fixed
253 << std::setfill(' ')
254 << (rt_orig - rt).transpose().format(kHeavyFormat);
255 s *= kBeta;
256 } else {
257 break;
258 }
259 }
260
261 VLOG(1) << " Terminated line search with s " << s << ", " << rt.norm()
262 << "(|Rt+1|) < " << rt_orig.norm() << "(|Rt|)";
263 y = next_y;
264
265 const Eigen::Ref<const Eigen::VectorXd> next_lambda =
266 y.template block<M, 1>(States, 0);
267
268 // See if we hit our convergence criteria.
269 const double r_primal_squared_norm =
270 rt.template block<N, 1>(States + M, 0).squaredNorm();
271 VLOG(1) << " rt_next(" << iteration << ") is " << rt.norm() << " -> "
272 << std::setprecision(12) << std::fixed << std::setfill(' ')
273 << rt.transpose().format(kHeavyFormat);
274 if (r_primal_squared_norm < kEpsilonF * kEpsilonF) {
275 const double r_dual_squared_norm =
276 rt.template block<States, 1>(0, 0).squaredNorm();
277 if (r_dual_squared_norm < kEpsilonF * kEpsilonF) {
278 const double next_nu =
279 -(next_derivatives.f.transpose() * next_lambda)(0, 0);
280 if (next_nu < kEpsilon) {
281 VLOG(1) << " r_primal(" << iteration << ") -> "
282 << std::sqrt(r_primal_squared_norm) << " < " << kEpsilonF
283 << ", r_dual(" << iteration << ") -> "
284 << std::sqrt(r_dual_squared_norm) << " < " << kEpsilonF
285 << ", nu(" << iteration << ") -> " << next_nu << " < "
286 << kEpsilon;
287 break;
288 } else {
289 VLOG(1) << " nu(" << iteration << ") -> " << next_nu << " < "
290 << kEpsilon << ", not done yet";
291 }
292
293 } else {
294 VLOG(1) << " r_dual(" << iteration << ") -> "
295 << std::sqrt(r_dual_squared_norm) << " < " << kEpsilonF
296 << ", not done yet";
297 }
298 } else {
299 VLOG(1) << " r_primal(" << iteration << ") -> "
300 << std::sqrt(r_primal_squared_norm) << " < " << kEpsilonF
301 << ", not done yet";
302 }
303 VLOG(1) << " step(" << iteration << ") " << std::setprecision(12)
304 << (s * dy).transpose().format(kHeavyFormat);
305 VLOG(1) << " y(" << iteration << ") is now " << std::setprecision(12)
306 << y.transpose().format(kHeavyFormat);
307
308 // Very import, use the last set of derivatives we picked for our new y
309 // for the next iteration. This avoids re-computing it.
310 derivatives = std::move(next_derivatives);
311
312 ++iteration;
313 if (iteration > 100) {
314 LOG(FATAL) << "Too many iterations";
315 }
316 }
317
318 return y.template block<States, 1>(0, 0);
319}
320
321template <size_t States, size_t M, size_t N>
322typename Solver<States, M, N>::Derivatives
323Solver<States, M, N>::ComputeDerivative(
324 const ConvexProblem<States, M, N> &problem,
325 const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y) {
326 const Eigen::Ref<const Eigen::Matrix<double, States, 1>> x =
327 y.template block<States, 1>(0, 0);
328
329 Derivatives derivatives;
330 derivatives.gradient = problem.df0(x);
331 derivatives.hessian = problem.ddf0(x);
332 derivatives.f = problem.f(x);
333 derivatives.df = problem.df(x);
334 derivatives.A = problem.A();
335 derivatives.Axmb =
336 derivatives.A * y.template block<States, 1>(0, 0) - problem.b();
337 return derivatives;
338}
339
340template <size_t States, size_t M, size_t N>
341void Solver<States, M, N>::PrintDerivatives(
342 const Derivatives &derivatives,
343 const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y,
344 std::string_view prefix, int verbosity) {
345 const Eigen::Ref<const Eigen::VectorXd> x =
346 y.block(0, 0, derivatives.hessian.rows(), 1);
347 const Eigen::Ref<const Eigen::VectorXd> lambda =
348 y.block(x.rows(), 0, derivatives.f.rows(), 1);
349
350 if (VLOG_IS_ON(verbosity)) {
351 Eigen::IOFormat heavy(Eigen::StreamPrecision, 0, ", ",
352 ",\n "
353 " ",
354 "[", "]", "[", "]");
355 heavy.rowSeparator =
356 heavy.rowSeparator +
357 std::string(absl::StrCat(getpid()).size() + prefix.size(), ' ');
358
359 const Eigen::Ref<const Eigen::VectorXd> v =
360 y.block(x.rows() + lambda.rows(), 0, derivatives.A.rows(), 1);
361 VLOG(verbosity) << " " << prefix << "x: " << x.transpose().format(heavy);
362 VLOG(verbosity) << " " << prefix
363 << "lambda: " << lambda.transpose().format(heavy);
364 VLOG(verbosity) << " " << prefix << "v: " << v.transpose().format(heavy);
365 VLOG(verbosity) << " " << prefix
366 << "hessian: " << derivatives.hessian.format(heavy);
367 VLOG(verbosity) << " " << prefix
368 << "gradient: " << derivatives.gradient.format(heavy);
369 VLOG(verbosity) << " " << prefix
370 << "A: " << derivatives.A.format(heavy);
371 VLOG(verbosity) << " " << prefix
372 << "Ax-b: " << derivatives.Axmb.format(heavy);
373 VLOG(verbosity) << " " << prefix
374 << "f: " << derivatives.f.format(heavy);
375 VLOG(verbosity) << " " << prefix
376 << "df: " << derivatives.df.format(heavy);
377 }
378}
379
380}; // namespace solvers
381}; // namespace frc971