#ifndef FRC971_SOLVERS_CONVEX_H_
#define FRC971_SOLVERS_CONVEX_H_

#include <sys/types.h>
#include <unistd.h>

#include <cmath>
#include <iomanip>
#include <string>
#include <string_view>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "glog/logging.h"
#include <Eigen/Dense>

namespace frc971::solvers {

// TODO(austin): Steal JET from Ceres to generate the derivatives easily and
// quickly?
//
// States is the number of inputs to the optimization problem.
// M is the number of inequality constraints.
// N is the number of equality constraints.
template <size_t States, size_t M, size_t N>
class ConvexProblem {
 public:
  // Returns the function to minimize and its derivatives.
  virtual double f0(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
  virtual Eigen::Matrix<double, States, 1> df0(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
  virtual Eigen::Matrix<double, States, States> ddf0(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;

  // Returns the inequality constraints f(X) < 0 and their derivatives.
  virtual Eigen::Matrix<double, M, 1> f(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;
  virtual Eigen::Matrix<double, M, States> df(
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X) const = 0;

  // Returns the equality constraints of the form A x = b.
  virtual Eigen::Matrix<double, N, States> A() const = 0;
  virtual Eigen::Matrix<double, N, 1> b() const = 0;
};
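
// A minimal sketch of a concrete problem definition (illustrative only; the
// class name, dimensions, and target point below are made up, and nothing in
// this header depends on them): minimize |X - c|^2 subject to -X <= 0 and
// sum(X) = 1.
//
//   class ExampleProblem : public ConvexProblem<2, 2, 1> {
//    public:
//     double f0(
//         Eigen::Ref<const Eigen::Matrix<double, 2, 1>> X) const override {
//       return (X - c_).squaredNorm();
//     }
//     Eigen::Matrix<double, 2, 1> df0(
//         Eigen::Ref<const Eigen::Matrix<double, 2, 1>> X) const override {
//       return 2.0 * (X - c_);
//     }
//     Eigen::Matrix<double, 2, 2> ddf0(
//         Eigen::Ref<const Eigen::Matrix<double, 2, 1>> /*X*/) const override {
//       return 2.0 * Eigen::Matrix<double, 2, 2>::Identity();
//     }
//     // Inequality constraints: -X <= 0, i.e. X >= 0 elementwise.
//     Eigen::Matrix<double, 2, 1> f(
//         Eigen::Ref<const Eigen::Matrix<double, 2, 1>> X) const override {
//       return -X;
//     }
//     Eigen::Matrix<double, 2, 2> df(
//         Eigen::Ref<const Eigen::Matrix<double, 2, 1>> /*X*/) const override {
//       return -Eigen::Matrix<double, 2, 2>::Identity();
//     }
//     // Equality constraint: sum(X) = 1.
//     Eigen::Matrix<double, 1, 2> A() const override {
//       return Eigen::Matrix<double, 1, 2>::Ones();
//     }
//     Eigen::Matrix<double, 1, 1> b() const override {
//       return Eigen::Matrix<double, 1, 1>::Ones();
//     }
//
//    private:
//     const Eigen::Matrix<double, 2, 1> c_{0.2, 0.8};
//   };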

// Implements a primal-dual interior-point method convex solver.
// See 11.7 of https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf
//
// States is the number of inputs to the optimization problem.
// M is the number of inequality constraints.
// N is the number of equality constraints.
template <size_t States, size_t M, size_t N>
class Solver {
 public:
  // Required fractional decrease of the residual norm when line searching:
  // accept a step when |Rt+1| <= (1 - kAlpha * s) * |Rt|.
  static constexpr double kAlpha = 0.05;
  // Line search step parameter.
  static constexpr double kBeta = 0.5;
  // Scale factor between the surrogate duality gap nu and the barrier
  // parameter t, i.e. t = kMu * M / nu.
  static constexpr double kMu = 2.0;
  // Termination condition for the primal problem (equality constraints) and
  // dual (gradient + inequality constraints).
  static constexpr double kEpsilonF = 1e-6;
  // Termination condition for nu, the surrogate duality gap.
  static constexpr double kEpsilon = 1e-6;

  // Solves the problem given a feasible initial solution.
  Eigen::Matrix<double, States, 1> Solve(
      const ConvexProblem<States, M, N> &problem,
      Eigen::Ref<const Eigen::Matrix<double, States, 1>> X_initial);

 private:
  // Struct to hold all the derivatives and function evaluations.
  struct Derivatives {
    Eigen::Matrix<double, States, 1> gradient;
    Eigen::Matrix<double, States, States> hessian;

    // Inequality function f.
    Eigen::Matrix<double, M, 1> f;
    // df
    Eigen::Matrix<double, M, States> df;

    // ddf is assumed to be 0 because, for the linear constraint distance
    // function we are using, it actually is 0.  Assuming it is zero rather
    // than passing a zero matrix through to the solver saves enough CPU to
    // be worth it.

    // A
    Eigen::Matrix<double, N, States> A;
    // Ax - b
    Eigen::Matrix<double, N, 1> Axmb;
  };

  // Computes all the values for the given problem at the given state.
  Derivatives ComputeDerivative(
      const ConvexProblem<States, M, N> &problem,
      const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y);

  // Computes the residual Rt at the given state y = (x, lambda, v) and with
  // the given t_inverse.  See 11.53 of cvxbook.pdf.
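  // Concretely, with "ones" the all-ones vector:
  //   r_dual = df0(x) + df(x)^T lambda + A^T v
  //   r_cent = -diag(lambda) f(x) - t_inverse * ones
  //   r_pri  = A x - b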
  Eigen::Matrix<double, States + M + N, 1> Rt(
      const Derivatives &derivatives,
      Eigen::Matrix<double, States + M + N, 1> y, double t_inverse);

  // Prints out all the derivatives with VLOG at the provided verbosity.
  void PrintDerivatives(
      const Derivatives &derivatives,
      const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y,
      std::string_view prefix, int verbosity);
};
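
// Example usage, assuming the hypothetical ExampleProblem sketched above.
// The initial point must be feasible (f(X) <= 0 componentwise); Solve()
// CHECK-fails otherwise.
//
//   ExampleProblem problem;
//   Solver<2, 2, 1> solver;
//   const Eigen::Matrix<double, 2, 1> result =
//       solver.Solve(problem, Eigen::Matrix<double, 2, 1>(0.5, 0.5));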

template <size_t States, size_t M, size_t N>
Eigen::Matrix<double, States + M + N, 1> Solver<States, M, N>::Rt(
    const Derivatives &derivatives, Eigen::Matrix<double, States + M + N, 1> y,
    double t_inverse) {
  Eigen::Matrix<double, States + M + N, 1> result;

  Eigen::Ref<Eigen::Matrix<double, States, 1>> r_dual =
      result.template block<States, 1>(0, 0);
  Eigen::Ref<Eigen::Matrix<double, M, 1>> r_cent =
      result.template block<M, 1>(States, 0);
  Eigen::Ref<Eigen::Matrix<double, N, 1>> r_pri =
      result.template block<N, 1>(States + M, 0);

  Eigen::Ref<const Eigen::Matrix<double, M, 1>> lambda =
      y.template block<M, 1>(States, 0);
  Eigen::Ref<const Eigen::Matrix<double, N, 1>> v =
      y.template block<N, 1>(States + M, 0);

  r_dual = derivatives.gradient + derivatives.df.transpose() * lambda +
           derivatives.A.transpose() * v;
  r_cent = -lambda.array() * derivatives.f.array() - t_inverse;
  r_pri = derivatives.Axmb;

  return result;
}

template <size_t States, size_t M, size_t N>
Eigen::Matrix<double, States, 1> Solver<States, M, N>::Solve(
    const ConvexProblem<States, M, N> &problem,
    Eigen::Ref<const Eigen::Matrix<double, States, 1>> X_initial) {
  const Eigen::IOFormat kHeavyFormat(Eigen::StreamPrecision, 0, ", ",
                                     ",\n                        ",
                                     "[", "]", "[", "]");

  Eigen::Matrix<double, States + M + N, 1> y =
      Eigen::Matrix<double, States + M + N, 1>::Constant(1.0);
  y.template block<States, 1>(0, 0) = X_initial;

  Derivatives derivatives = ComputeDerivative(problem, y);

  for (size_t i = 0; i < M; ++i) {
    CHECK_LE(derivatives.f(i, 0), 0.0)
        << ": Initial state " << X_initial.transpose().format(kHeavyFormat)
        << " not feasible";
  }

  PrintDerivatives(derivatives, y, "", 1);

  size_t iteration = 0;
  while (true) {
    // Solve for the primal-dual search direction by solving the Newton step.
    Eigen::Ref<const Eigen::Matrix<double, M, 1>> lambda =
        y.template block<M, 1>(States, 0);

    // nu is the surrogate duality gap, which sets the barrier parameter
    // through t = kMu * M / nu.
    const double nu = -(derivatives.f.transpose() * lambda)(0, 0);
    const double t_inverse = nu / (kMu * lambda.rows());
    Eigen::Matrix<double, States + M + N, 1> rt_orig =
        Rt(derivatives, y, t_inverse);

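    // Assemble the KKT system for the Newton step (11.54 of cvxbook.pdf; the
    // constraints' second derivatives are taken to be zero, as noted in
    // Derivatives above):
    //   [ ddf0(x)              df(x)^T      A^T ] [ dx      ]
    //   [ -diag(lambda) df(x)  -diag(f(x))  0   ] [ dlambda ] = -Rt(y)
    //   [ A                    0            0   ] [ dv      ]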
    Eigen::Matrix<double, States + M + N, States + M + N> m1;
    m1.setZero();
    m1.template block<States, States>(0, 0) = derivatives.hessian;
    m1.template block<States, M>(0, States) = derivatives.df.transpose();
    m1.template block<States, N>(0, States + M) = derivatives.A.transpose();
    m1.template block<M, States>(States, 0) =
        -(Eigen::DiagonalMatrix<double, M>(lambda) * derivatives.df);
    m1.template block<M, M>(States, States) =
        Eigen::DiagonalMatrix<double, M>(-derivatives.f);
    m1.template block<N, States>(States + M, 0) = derivatives.A;

    Eigen::Matrix<double, States + M + N, 1> dy =
        m1.colPivHouseholderQr().solve(-rt_orig);

    Eigen::Ref<Eigen::Matrix<double, M, 1>> dlambda =
        dy.template block<M, 1>(States, 0);

    double s = 1.0;

    // Now, time to do the line search.
    //
    // Start by keeping lambda positive: make sure our step doesn't let
    // lambda cross 0.
    for (int i = 0; i < dlambda.rows(); ++i) {
      if (lambda(i) + s * dlambda(i) < 0.0) {
        // Ignore tiny steps in lambda.  They cause issues when we are really
        // close to having our constraints met but haven't converged the rest
        // of the problem, and start to run into rounding issues in the
        // matrix solve portion.
        if (dlambda(i) < 0.0 && dlambda(i) > -1e-12) {
          VLOG(1) << "  lambda(" << i << ") " << lambda(i) << " + " << s
                  << " * " << dlambda(i) << " -> s would now be "
                  << -lambda(i) / dlambda(i);
          dlambda(i) = 0.0;
          VLOG(1) << "  dy -> " << std::setprecision(12) << std::fixed
                  << std::setfill(' ') << dy.transpose().format(kHeavyFormat);
          continue;
        }
        VLOG(1) << "  lambda(" << i << ") " << lambda(i) << " + " << s << " * "
                << dlambda(i) << " -> s now " << -lambda(i) / dlambda(i);
        s = -lambda(i) / dlambda(i);
      }
    }

    VLOG(1) << "  After lambda line search, s is " << s;

    VLOG(3) << "  Initial step " << iteration << " -> " << std::setprecision(12)
            << std::fixed << std::setfill(' ')
            << dy.transpose().format(kHeavyFormat);
    VLOG(3) << "  rt -> " << std::setprecision(12) << std::fixed
            << std::setfill(' ') << rt_orig.transpose().format(kHeavyFormat);

    const double rt_orig_squared_norm = rt_orig.squaredNorm();

    Eigen::Matrix<double, States + M + N, 1> next_y;
    Eigen::Matrix<double, States + M + N, 1> rt;
    Derivatives next_derivatives;
    while (true) {
      next_y = y + s * dy;
      next_derivatives = ComputeDerivative(problem, next_y);
      rt = Rt(next_derivatives, next_y, t_inverse);

      const Eigen::Ref<const Eigen::VectorXd> next_x =
          next_y.block(0, 0, next_derivatives.hessian.rows(), 1);
      const Eigen::Ref<const Eigen::VectorXd> next_lambda =
          next_y.block(next_x.rows(), 0, next_derivatives.f.rows(), 1);

      const Eigen::Ref<const Eigen::VectorXd> next_v = next_y.block(
          next_x.rows() + next_lambda.rows(), 0, next_derivatives.A.rows(), 1);

      VLOG(1) << "    next_rt(" << iteration << ") is " << rt.norm() << " -> "
              << std::setprecision(12) << std::fixed << std::setfill(' ')
              << rt.transpose().format(kHeavyFormat);

      PrintDerivatives(next_derivatives, next_y, "next_", 3);

      if (next_derivatives.f.maxCoeff() > 0.0) {
        VLOG(1) << "   f_next > 0.0 -> " << next_derivatives.f.maxCoeff()
                << ", continuing line search.";
        s *= kBeta;
      } else if (next_derivatives.Axmb.squaredNorm() < 0.1 &&
                 rt.squaredNorm() >
                     std::pow(1.0 - kAlpha * s, 2.0) * rt_orig_squared_norm) {
        VLOG(1) << "   |Rt+1| > (1 - alpha * s) * |Rt| -> " << rt.norm()
                << " > " << rt_orig.norm() << ", drt -> "
                << std::setprecision(12) << std::fixed << std::setfill(' ')
                << (rt_orig - rt).transpose().format(kHeavyFormat);
        s *= kBeta;
      } else {
        break;
      }
    }

    VLOG(1) << "  Terminated line search with s " << s << ", " << rt.norm()
            << "(|Rt+1|) < " << rt_orig.norm() << "(|Rt|)";
    y = next_y;

    const Eigen::Ref<const Eigen::VectorXd> next_lambda =
        y.template block<M, 1>(States, 0);

    // See if we hit our convergence criteria.
    const double r_primal_squared_norm =
        rt.template block<N, 1>(States + M, 0).squaredNorm();
    VLOG(1) << "  rt_next(" << iteration << ") is " << rt.norm() << " -> "
            << std::setprecision(12) << std::fixed << std::setfill(' ')
            << rt.transpose().format(kHeavyFormat);
    if (r_primal_squared_norm < kEpsilonF * kEpsilonF) {
      const double r_dual_squared_norm =
          rt.template block<States, 1>(0, 0).squaredNorm();
      if (r_dual_squared_norm < kEpsilonF * kEpsilonF) {
        const double next_nu =
            -(next_derivatives.f.transpose() * next_lambda)(0, 0);
        if (next_nu < kEpsilon) {
          VLOG(1) << "  r_primal(" << iteration << ") -> "
                  << std::sqrt(r_primal_squared_norm) << " < " << kEpsilonF
                  << ", r_dual(" << iteration << ") -> "
                  << std::sqrt(r_dual_squared_norm) << " < " << kEpsilonF
                  << ", nu(" << iteration << ") -> " << next_nu << " < "
                  << kEpsilon;
          break;
        } else {
          VLOG(1) << "  nu(" << iteration << ") -> " << next_nu
                  << " >= " << kEpsilon << ", not done yet";
        }
      } else {
        VLOG(1) << "  r_dual(" << iteration << ") -> "
                << std::sqrt(r_dual_squared_norm) << " >= " << kEpsilonF
                << ", not done yet";
      }
    } else {
      VLOG(1) << "  r_primal(" << iteration << ") -> "
              << std::sqrt(r_primal_squared_norm) << " >= " << kEpsilonF
              << ", not done yet";
    }
304 VLOG(1) << " step(" << iteration << ") " << std::setprecision(12)
305 << (s * dy).transpose().format(kHeavyFormat);
306 VLOG(1) << " y(" << iteration << ") is now " << std::setprecision(12)
307 << y.transpose().format(kHeavyFormat);
308
309 // Very import, use the last set of derivatives we picked for our new y
310 // for the next iteration. This avoids re-computing it.
311 derivatives = std::move(next_derivatives);
312
313 ++iteration;
314 if (iteration > 100) {
315 LOG(FATAL) << "Too many iterations";
316 }
317 }
318
319 return y.template block<States, 1>(0, 0);
320}

template <size_t States, size_t M, size_t N>
typename Solver<States, M, N>::Derivatives
Solver<States, M, N>::ComputeDerivative(
    const ConvexProblem<States, M, N> &problem,
    const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y) {
  const Eigen::Ref<const Eigen::Matrix<double, States, 1>> x =
      y.template block<States, 1>(0, 0);

  Derivatives derivatives;
  derivatives.gradient = problem.df0(x);
  derivatives.hessian = problem.ddf0(x);
  derivatives.f = problem.f(x);
  derivatives.df = problem.df(x);
  derivatives.A = problem.A();
  derivatives.Axmb =
      derivatives.A * y.template block<States, 1>(0, 0) - problem.b();
  return derivatives;
}

template <size_t States, size_t M, size_t N>
void Solver<States, M, N>::PrintDerivatives(
    const Derivatives &derivatives,
    const Eigen::Ref<const Eigen::Matrix<double, States + M + N, 1>> y,
    std::string_view prefix, int verbosity) {
  const Eigen::Ref<const Eigen::VectorXd> x =
      y.block(0, 0, derivatives.hessian.rows(), 1);
  const Eigen::Ref<const Eigen::VectorXd> lambda =
      y.block(x.rows(), 0, derivatives.f.rows(), 1);

  if (VLOG_IS_ON(verbosity)) {
    Eigen::IOFormat heavy(Eigen::StreamPrecision, 0, ", ",
                          ",\n                        ",
                          "[", "]", "[", "]");
    heavy.rowSeparator =
        heavy.rowSeparator +
        std::string(absl::StrCat(getpid()).size() + prefix.size(), ' ');

    const Eigen::Ref<const Eigen::VectorXd> v =
        y.block(x.rows() + lambda.rows(), 0, derivatives.A.rows(), 1);
    VLOG(verbosity) << "  " << prefix << "x: " << x.transpose().format(heavy);
    VLOG(verbosity) << "  " << prefix
                    << "lambda: " << lambda.transpose().format(heavy);
    VLOG(verbosity) << "  " << prefix << "v: " << v.transpose().format(heavy);
    VLOG(verbosity) << "  " << prefix
                    << "hessian: " << derivatives.hessian.format(heavy);
    VLOG(verbosity) << "  " << prefix
                    << "gradient: " << derivatives.gradient.format(heavy);
    VLOG(verbosity) << "  " << prefix << "A: " << derivatives.A.format(heavy);
    VLOG(verbosity) << "  " << prefix
                    << "Ax-b: " << derivatives.Axmb.format(heavy);
    VLOG(verbosity) << "  " << prefix << "f: " << derivatives.f.format(heavy);
    VLOG(verbosity) << "  " << prefix << "df: " << derivatives.df.format(heavy);
  }
}
380
Stephan Pleinesd99b1ee2024-02-02 20:56:44 -0800381} // namespace frc971::solvers
Austin Schuh29441032023-05-31 19:32:24 -0700382
383#endif // FRC971_SOLVERS_CONVEX_H_