#ifndef FRC971_SOLVERS_CONVEX_H_
#define FRC971_SOLVERS_CONVEX_H_

#include <sys/types.h>
#include <unistd.h>

#include <cmath>
#include <iomanip>
#include <string>
#include <string_view>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include <Eigen/Dense>

namespace frc971::solvers {

// TODO(austin): Steal JET from Ceres to generate the derivatives easily and
// quickly?
//
// States is the number of inputs to the optimization problem.
// M is the number of inequality constraints.
// N is the number of equality constraints.
template <size_t States, size_t M, size_t N>
class ConvexProblem {
 public:
  // Returns the function to minimize and its derivatives.
  virtual double f0(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
  virtual Eigen::Matrix<double, States, 1> df0(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
  virtual Eigen::Matrix<double, States, States> ddf0(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;

  // Returns the inequality constraints f(X) < 0 and their derivative.
  virtual Eigen::Matrix<double, M, 1> f(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
  virtual Eigen::Matrix<double, M, States> df(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;

  // Returns the equality constraints of the form A x = b.
  virtual Eigen::Matrix<double, N, States> A() const = 0;
  virtual Eigen::Matrix<double, N, 1> b() const = 0;
};
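
// As a minimal sketch of how this interface might be implemented (all names
// below are hypothetical, not part of this file), a one-state problem
// minimizing (x - 2)^2 subject to x - 3 <= 0 and x = 1 could look like:
//
//   class ExampleProblem : public ConvexProblem<1, 1, 1> {
//    public:
//     double f0(
//         Eigen::Ref<const Eigen::Matrix<double, 1, 1>> X) const override {
//       return (X(0) - 2.0) * (X(0) - 2.0);
//     }
//     Eigen::Matrix<double, 1, 1> df0(
//         Eigen::Ref<const Eigen::Matrix<double, 1, 1>> X) const override {
//       return Eigen::Matrix<double, 1, 1>::Constant(2.0 * (X(0) - 2.0));
//     }
//     Eigen::Matrix<double, 1, 1> ddf0(
//         Eigen::Ref<const Eigen::Matrix<double, 1, 1>> X) const override {
//       return Eigen::Matrix<double, 1, 1>::Constant(2.0);
//     }
//     Eigen::Matrix<double, 1, 1> f(
//         Eigen::Ref<const Eigen::Matrix<double, 1, 1>> X) const override {
//       return Eigen::Matrix<double, 1, 1>::Constant(X(0) - 3.0);
//     }
//     Eigen::Matrix<double, 1, States> df(
//         Eigen::Ref<const Eigen::Matrix<double, 1, 1>> X) const override {
//       return Eigen::Matrix<double, 1, 1>::Constant(1.0);
//     }
//     Eigen::Matrix<double, 1, 1> A() const override {
//       return Eigen::Matrix<double, 1, 1>::Constant(1.0);
//     }
//     Eigen::Matrix<double, 1, 1> b() const override {
//       return Eigen::Matrix<double, 1, 1>::Constant(1.0);
//     }
//   };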

// Implements a Primal-Dual Interior point method convex solver.
// See 11.7 of https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf
//
// States is the number of inputs to the optimization problem.
// M is the number of inequality constraints.
// N is the number of equality constraints.
template <size_t States, size_t M, size_t N>
class Solver {
 public:
  // Ratio to require the cost to decrease when line searching.
  static constexpr double kAlpha = 0.05;
  // Line search step parameter.
  static constexpr double kBeta = 0.5;
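  // Scale factor used when turning the surrogate duality gap into the barrier
  // parameter: Solve() sets 1/t = nu / (kMu * M).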
  static constexpr double kMu = 2.0;
  // Terminal condition for the primal problem (equality constraints) and dual
  // (gradient + inequality constraints).
  static constexpr double kEpsilonF = 1e-6;
  // Terminal condition for nu, the surrogate duality gap.
  static constexpr double kEpsilon = 1e-6;

  // Solves the problem given a feasible initial solution.
  Eigen::Matrix<double, States, 1> Solve(
      const ConvexProblem<States, M, N> &problem,
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X_initial);

 private:
  // Class to hold all the derivatives and function evaluations.
  struct Derivatives {
    Eigen::Matrix<double, States, 1> gradient;
    Eigen::Matrix<double, States, States> hessian;

    // Inequality function f
    Eigen::Matrix<double, M, 1> f;
    // df
    Eigen::Matrix<double, M, States> df;

    // ddf is assumed to be 0 because, for the linear constraint distance
    // function we are using, it is actually 0. By assuming it is zero rather
    // than passing it through as 0 to the solver, we save enough CPU to make
    // it worth it.

    // A
    Eigen::Matrix<double, N, States> A;
    // Ax - b
    Eigen::Matrix<double, N, 1> Axmb;
  };

  // Computes all the values for the given problem at the given state.
  Derivatives ComputeDerivative(
      const ConvexProblem<States, M, N> &problem,
      const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y);

  // Computes Rt at the given state and with the given t_inverse. See 11.53 of
  // cvxbook.pdf.
  Eigen::Matrix<double, States + M + N, 1> Rt(
      const Derivatives &derivatives,
      Eigen::Matrix<double, States + M + N, 1> y, double t_inverse);

  // Prints out all the derivatives with VLOG at the provided verbosity.
  void PrintDerivatives(
      const Derivatives &derivatives,
      const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y,
      std::string_view prefix, int verbosity);
};
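
// A hypothetical call site, assuming the ExampleProblem sketch above (these
// names are illustrative, not defined in this file):
//
//   ExampleProblem problem;
//   Solver<1, 1, 1> solver;
//   // The initial state must satisfy the inequality constraints; Solve()
//   // CHECK-fails otherwise.
//   const Eigen::Matrix<double, 1, 1> x_optimal = solver.Solve(
//       problem, Eigen::Matrix<double, 1, 1>::Constant(1.0));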

template <size_t States, size_t M, size_t N>
Eigen::Matrix<double, States + M + N, 1> Solver<States, M, N>::Rt(
    const Derivatives &derivatives, Eigen::Matrix<double, States + M + N, 1> y,
    double t_inverse) {
  Eigen::Matrix<double, States + M + N, 1> result;

  Eigen::Ref<Eigen::Matrix<double, States, 1>> r_dual =
      result.template block<States, 1>(0, 0);
  Eigen::Ref<Eigen::Matrix<double, M, 1>> r_cent =
      result.template block<M, 1>(States, 0);
  Eigen::Ref<Eigen::Matrix<double, N, 1>> r_pri =
      result.template block<N, 1>(States + M, 0);

  Eigen::Ref<const Eigen::Matrix<double, M, 1>> lambda =
      y.template block<M, 1>(States, 0);
  Eigen::Ref<const Eigen::Matrix<double, N, 1>> v =
      y.template block<N, 1>(States + M, 0);

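  // These are the three blocks of the residual from 11.53 of cvxbook, with
  // the constraint Hessian term dropped (see the ddf note in Derivatives):
  //   r_dual = df0(x) + df(x)^T * lambda + A^T * v
  //   r_cent = -diag(lambda) * f(x) - (1 / t) * ones
  //   r_pri  = A * x - b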
  r_dual = derivatives.gradient + derivatives.df.transpose() * lambda +
           derivatives.A.transpose() * v;
  r_cent = -lambda.array() * derivatives.f.array() - t_inverse;
  r_pri = derivatives.Axmb;

  return result;
}

template <size_t States, size_t M, size_t N>
Eigen::Matrix<double, States, 1> Solver<States, M, N>::Solve(
    const ConvexProblem<States, M, N> &problem,
    Eigen::Ref<const Eigen::Matrix<double, States, 1>> X_initial) {
  const Eigen::IOFormat kHeavyFormat(Eigen::StreamPrecision, 0, ", ",
                                     ",\n                        "
                                     "                        ",
                                     "[", "]", "[", "]");

  Eigen::Matrix<double, States + M + N, 1> y =
      Eigen::Matrix<double, States + M + N, 1>::Constant(1.0);
  y.template block<States, 1>(0, 0) = X_initial;

  Derivatives derivatives = ComputeDerivative(problem, y);

  for (size_t i = 0; i < M; ++i) {
    CHECK_LE(derivatives.f(i, 0), 0.0)
        << ": Initial state " << X_initial.transpose().format(kHeavyFormat)
        << " not feasible";
  }

  PrintDerivatives(derivatives, y, "", 1);

  size_t iteration = 0;
  while (true) {
    // Solve for the primal-dual search direction by solving the Newton step.
    Eigen::Ref<const Eigen::Matrix<double, M, 1>> lambda =
        y.template block<M, 1>(States, 0);

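    // nu is the surrogate duality gap -f(x)^T lambda from 11.7 of cvxbook,
    // and t_inverse is the inverse of the barrier parameter t = kMu * M / nu
    // used to center the next step.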
    const double nu = -(derivatives.f.transpose() * lambda)(0, 0);
    const double t_inverse = nu / (kMu * lambda.rows());
    Eigen::Matrix<double, States + M + N, 1> rt_orig =
        Rt(derivatives, y, t_inverse);

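    // Assemble the primal-dual Newton system from 11.7 of cvxbook (the matrix
    // applied to (dx, dlambda, dv)); the sum of lambda_i * ddf_i(x) terms in
    // the top-left block is dropped because ddf is assumed to be zero above:
    //   [ ddf0(x)              df(x)^T      A^T ]
    //   [ -diag(lambda) df(x)  -diag(f(x))  0   ]
    //   [ A                    0            0   ]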
    Eigen::Matrix<double, States + M + N, States + M + N> m1;
    m1.setZero();
    m1.template block<States, States>(0, 0) = derivatives.hessian;
    m1.template block<States, M>(0, States) = derivatives.df.transpose();
    m1.template block<States, N>(0, States + M) = derivatives.A.transpose();
    m1.template block<M, States>(States, 0) =
        -(Eigen::DiagonalMatrix<double, M>(lambda) * derivatives.df);
    m1.template block<M, M>(States, States) =
        Eigen::DiagonalMatrix<double, M>(-derivatives.f);
    m1.template block<N, States>(States + M, 0) = derivatives.A;

    Eigen::Matrix<double, States + M + N, 1> dy =
        m1.colPivHouseholderQr().solve(-rt_orig);

    Eigen::Ref<Eigen::Matrix<double, M, 1>> dlambda =
        dy.template block<M, 1>(States, 0);

    double s = 1.0;

    // Now, time to do the line search.
    //
    // Start by keeping lambda positive. Make sure our step doesn't let
    // lambda cross 0.
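    // Concretely, s gets capped at the minimum over {i : dlambda(i) < 0} of
    // -lambda(i) / dlambda(i), the step length at which the first multiplier
    // would hit zero (the s_max rule from the line search in 11.7 of cvxbook,
    // here without a backoff factor).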
    for (int i = 0; i < dlambda.rows(); ++i) {
      if (lambda(i) + s * dlambda(i) < 0.0) {
        // Ignore tiny steps in lambda. They cause issues when we get really
        // close to having our constraints met but haven't converged the rest
        // of the problem and start to run into rounding issues in the matrix
        // solve portion.
        if (dlambda(i) < 0.0 && dlambda(i) > -1e-12) {
          VLOG(1) << " lambda(" << i << ") " << lambda(i) << " + " << s
                  << " * " << dlambda(i) << " -> s would be now "
                  << -lambda(i) / dlambda(i);
          dlambda(i) = 0.0;
          VLOG(1) << " dy -> " << std::setprecision(12) << std::fixed
                  << std::setfill(' ') << dy.transpose().format(kHeavyFormat);
          continue;
        }
        VLOG(1) << " lambda(" << i << ") " << lambda(i) << " + " << s << " * "
                << dlambda(i) << " -> s now " << -lambda(i) / dlambda(i);
        s = -lambda(i) / dlambda(i);
      }
    }

    VLOG(1) << " After lambda line search, s is " << s;

    VLOG(3) << " Initial step " << iteration << " -> " << std::setprecision(12)
            << std::fixed << std::setfill(' ')
            << dy.transpose().format(kHeavyFormat);
    VLOG(3) << " rt -> " << std::setprecision(12) << std::fixed
            << std::setfill(' ') << rt_orig.transpose().format(kHeavyFormat);

    const double rt_orig_squared_norm = rt_orig.squaredNorm();

    Eigen::Matrix<double, States + M + N, 1> next_y;
    Eigen::Matrix<double, States + M + N, 1> rt;
    Derivatives next_derivatives;
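    // Backtracking line search: shrink s by kBeta until the next iterate is
    // strictly feasible in the inequality constraints and the residual
    // satisfies |Rt(y + s * dy)| <= (1 - kAlpha * s) * |Rt(y)| (compared in
    // squared form below). The norm test is only enforced once the equality
    // constraints are nearly satisfied (|Ax - b|^2 < 0.1).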
    while (true) {
      next_y = y + s * dy;
      next_derivatives = ComputeDerivative(problem, next_y);
      rt = Rt(next_derivatives, next_y, t_inverse);

      const Eigen::Ref<const Eigen::VectorXd> next_x =
          next_y.block(0, 0, next_derivatives.hessian.rows(), 1);
      const Eigen::Ref<const Eigen::VectorXd> next_lambda =
          next_y.block(next_x.rows(), 0, next_derivatives.f.rows(), 1);

      const Eigen::Ref<const Eigen::VectorXd> next_v = next_y.block(
          next_x.rows() + next_lambda.rows(), 0, next_derivatives.A.rows(), 1);

      VLOG(1) << " next_rt(" << iteration << ") is " << rt.norm() << " -> "
              << std::setprecision(12) << std::fixed << std::setfill(' ')
              << rt.transpose().format(kHeavyFormat);

      PrintDerivatives(next_derivatives, next_y, "next_", 3);

      if (next_derivatives.f.maxCoeff() > 0.0) {
        VLOG(1) << " f_next > 0.0 -> " << next_derivatives.f.maxCoeff()
                << ", continuing line search.";
        s *= kBeta;
      } else if (next_derivatives.Axmb.squaredNorm() < 0.1 &&
                 rt.squaredNorm() >
                     std::pow(1.0 - kAlpha * s, 2.0) * rt_orig_squared_norm) {
        VLOG(1) << " |Rt| > |Rt+1| " << rt.norm() << " > " << rt_orig.norm()
                << ", drt -> " << std::setprecision(12) << std::fixed
                << std::setfill(' ')
                << (rt_orig - rt).transpose().format(kHeavyFormat);
        s *= kBeta;
      } else {
        break;
      }
    }

    VLOG(1) << " Terminated line search with s " << s << ", " << rt.norm()
            << "(|Rt+1|) < " << rt_orig.norm() << "(|Rt|)";
    y = next_y;

    const Eigen::Ref<const Eigen::VectorXd> next_lambda =
        y.template block<M, 1>(States, 0);

    // See if we hit our convergence criteria.
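    // This is the stopping test of the primal-dual method in 11.7 of cvxbook:
    // the primal residual |Ax - b| and the dual residual must both fall below
    // kEpsilonF, and the surrogate duality gap nu must fall below kEpsilon.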
    const double r_primal_squared_norm =
        rt.template block<N, 1>(States + M, 0).squaredNorm();
    VLOG(1) << " rt_next(" << iteration << ") is " << rt.norm() << " -> "
            << std::setprecision(12) << std::fixed << std::setfill(' ')
            << rt.transpose().format(kHeavyFormat);
    if (r_primal_squared_norm < kEpsilonF * kEpsilonF) {
      const double r_dual_squared_norm =
          rt.template block<States, 1>(0, 0).squaredNorm();
      if (r_dual_squared_norm < kEpsilonF * kEpsilonF) {
        const double next_nu =
            -(next_derivatives.f.transpose() * next_lambda)(0, 0);
        if (next_nu < kEpsilon) {
          VLOG(1) << " r_primal(" << iteration << ") -> "
                  << std::sqrt(r_primal_squared_norm) << " < " << kEpsilonF
                  << ", r_dual(" << iteration << ") -> "
                  << std::sqrt(r_dual_squared_norm) << " < " << kEpsilonF
                  << ", nu(" << iteration << ") -> " << next_nu << " < "
                  << kEpsilon;
          break;
        } else {
          VLOG(1) << " nu(" << iteration << ") -> " << next_nu << " < "
                  << kEpsilon << ", not done yet";
        }
      } else {
        VLOG(1) << " r_dual(" << iteration << ") -> "
                << std::sqrt(r_dual_squared_norm) << " < " << kEpsilonF
                << ", not done yet";
      }
    } else {
      VLOG(1) << " r_primal(" << iteration << ") -> "
              << std::sqrt(r_primal_squared_norm) << " < " << kEpsilonF
              << ", not done yet";
    }
    VLOG(1) << " step(" << iteration << ") " << std::setprecision(12)
            << (s * dy).transpose().format(kHeavyFormat);
    VLOG(1) << " y(" << iteration << ") is now " << std::setprecision(12)
            << y.transpose().format(kHeavyFormat);

    // Very important: use the last set of derivatives we computed for our new
    // y for the next iteration. This avoids re-computing them.
    derivatives = std::move(next_derivatives);

    ++iteration;
    if (iteration > 100) {
      LOG(FATAL) << "Too many iterations";
    }
  }

  return y.template block<States, 1>(0, 0);
}

template <size_t States, size_t M, size_t N>
typename Solver<States, M, N>::Derivatives
Solver<States, M, N>::ComputeDerivative(
    const ConvexProblem<States, M, N> &problem,
    const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y) {
  const Eigen::Ref<const Eigen::Matrix<double, States, 1>> x =
      y.template block<States, 1>(0, 0);

  Derivatives derivatives;
  derivatives.gradient = problem.df0(x);
  derivatives.hessian = problem.ddf0(x);
  derivatives.f = problem.f(x);
  derivatives.df = problem.df(x);
  derivatives.A = problem.A();
  derivatives.Axmb = derivatives.A * x - problem.b();
  return derivatives;
}

template <size_t States, size_t M, size_t N>
void Solver<States, M, N>::PrintDerivatives(
    const Derivatives &derivatives,
    const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y,
    std::string_view prefix, int verbosity) {
  const Eigen::Ref<const Eigen::VectorXd> x =
      y.block(0, 0, derivatives.hessian.rows(), 1);
  const Eigen::Ref<const Eigen::VectorXd> lambda =
      y.block(x.rows(), 0, derivatives.f.rows(), 1);

  if (VLOG_IS_ON(verbosity)) {
    Eigen::IOFormat heavy(Eigen::StreamPrecision, 0, ", ",
                          ",\n                        "
                          "                        ",
                          "[", "]", "[", "]");
    heavy.rowSeparator =
        heavy.rowSeparator +
        std::string(absl::StrCat(getpid()).size() + prefix.size(), ' ');

    const Eigen::Ref<const Eigen::VectorXd> v =
        y.block(x.rows() + lambda.rows(), 0, derivatives.A.rows(), 1);
    VLOG(verbosity) << " " << prefix << "x: " << x.transpose().format(heavy);
    VLOG(verbosity) << " " << prefix
                    << "lambda: " << lambda.transpose().format(heavy);
    VLOG(verbosity) << " " << prefix << "v: " << v.transpose().format(heavy);
    VLOG(verbosity) << " " << prefix
                    << "hessian: " << derivatives.hessian.format(heavy);
    VLOG(verbosity) << " " << prefix
                    << "gradient: " << derivatives.gradient.format(heavy);
    VLOG(verbosity) << " " << prefix << "A: " << derivatives.A.format(heavy);
    VLOG(verbosity) << " " << prefix
                    << "Ax-b: " << derivatives.Axmb.format(heavy);
    VLOG(verbosity) << " " << prefix << "f: " << derivatives.f.format(heavy);
    VLOG(verbosity) << " " << prefix << "df: " << derivatives.df.format(heavy);
  }
}

}  // namespace frc971::solvers

#endif  // FRC971_SOLVERS_CONVEX_H_