Add casadi generator to symengine code
This lets us use casadi for nonlinear optimization instead of
reinventing our own through symengine.
Change-Id: I56af353ae3235ce775cb743c522424762dafbdaf
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index d8c1321..6f23c47 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -28,6 +28,8 @@
DEFINE_string(cc_output_path, "", "Path to write generated header code to");
DEFINE_string(h_output_path, "", "Path to write generated cc code to");
DEFINE_string(py_output_path, "", "Path to write generated py code to");
+DEFINE_string(casadi_py_output_path, "",
+ "Path to write casadi generated py code to");
DEFINE_bool(symbolic, false, "If true, write everything out symbolically.");
@@ -221,18 +223,18 @@
result_py.emplace_back("#!/usr/bin/python3");
result_py.emplace_back("");
result_py.emplace_back("import numpy");
- result_py.emplace_back("import math");
- result_py.emplace_back("from math import sin, cos, fabs");
result_py.emplace_back("");
- result_py.emplace_back("def atan2(y, x):");
- result_py.emplace_back(" if x < 0:");
- result_py.emplace_back(" return -math.atan2(y, x)");
- result_py.emplace_back(" else:");
- result_py.emplace_back(" return math.atan2(y, x)");
-
result_py.emplace_back("def swerve_physics(t, X, U_func):");
- // result_py.emplace_back(" print(X)");
+ result_py.emplace_back(" def atan2(y, x):");
+ result_py.emplace_back(" if x < 0:");
+ result_py.emplace_back(" return -numpy.atan2(y, x)");
+ result_py.emplace_back(" else:");
+ result_py.emplace_back(" return numpy.atan2(y, x)");
+ result_py.emplace_back(" sin = numpy.sin");
+ result_py.emplace_back(" cos = numpy.cos");
+ result_py.emplace_back(" fabs = numpy.fabs");
+
result_py.emplace_back(" result = numpy.empty([25, 1])");
result_py.emplace_back(" X = X.reshape(25, 1)");
result_py.emplace_back(" U = U_func(X)");
@@ -321,7 +323,7 @@
std::vector<std::string> result_cc;
std::vector<std::string> result_h;
- std::string_view include_guard_stripped = FLAGS_h_output_path;
+ std::string_view include_guard_stripped = h_path;
CHECK(absl::ConsumePrefix(&include_guard_stripped, FLAGS_output_base));
std::string include_guard =
absl::StrReplaceAll(absl::AsciiStrToUpper(include_guard_stripped),
@@ -450,6 +452,109 @@
aos::util::WriteStringToFileOrDie(h_path, absl::StrJoin(result_h, "\n"));
}
+  // Writes the physics out as a casadi-based Python module to py_path. The
+  // generated swerve_physics(X, U) binds each state/input symbol from X and U
+  // and returns the dynamics as casadi.Function('xdot', [X, U], [result]).
+  void WriteCasadi(std::string_view py_path) {
+    // Index of the first chassis state (x) in X; module states come first.
+    constexpr size_t kChassisOffset = kNumModules * 4;
+    // Chassis state names, in the order they are stored in X.
+    static constexpr const char *kChassisStates[] = {
+        "x", "y", "theta", "vx", "vy", "omega", "fx", "fy", "moment"};
+    constexpr size_t kNumStates = kChassisOffset + std::size(kChassisStates);
+
+    std::vector<std::string> result_py;
+
+    // Write out the header.
+    result_py.emplace_back("#!/usr/bin/python3");
+    result_py.emplace_back("");
+    result_py.emplace_back("import casadi");
+    result_py.emplace_back("");
+    result_py.emplace_back("# Returns the derivative of our state vector");
+    result_py.emplace_back("# [thetas0, thetad0, omegas0, omegad0,");
+    result_py.emplace_back("# thetas1, thetad1, omegas1, omegad1,");
+    result_py.emplace_back("# thetas2, thetad2, omegas2, omegad2,");
+    result_py.emplace_back("# thetas3, thetad3, omegas3, omegad3,");
+    result_py.emplace_back("# x, y, theta, vx, vy, omega,");
+    result_py.emplace_back("# Fx, Fy, Moment]");
+    result_py.emplace_back(
+        "# wrapped in a casadi.Function; X and U must be purely symbolic.");
+    result_py.emplace_back("def swerve_physics(X, U):");
+    result_py.emplace_back(" sin = casadi.sin");
+    result_py.emplace_back(" cos = casadi.cos");
+    result_py.emplace_back(" atan2 = casadi.atan2");
+    // ccode() output can contain fabs() (WritePy aliases it as well).
+    result_py.emplace_back(" fabs = casadi.fabs");
+
+    // Start by writing out variables matching each of the symbol names we use
+    // so we don't have to modify the computed equations too much.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute(" thetas$0 = X[$1, 0]", m, m * 4));
+      result_py.emplace_back(
+          absl::Substitute(" thetad$0 = X[$1, 0]", m, m * 4 + 1));
+      result_py.emplace_back(
+          absl::Substitute(" omegas$0 = X[$1, 0]", m, m * 4 + 2));
+      result_py.emplace_back(
+          absl::Substitute(" omegad$0 = X[$1, 0]", m, m * 4 + 3));
+    }
+    for (size_t i = 0; i < std::size(kChassisStates); ++i) {
+      result_py.emplace_back(absl::Substitute(
+          " $0 = X[$1, 0]", kChassisStates[i], kChassisOffset + i));
+    }
+
+    // Now do the same for the inputs.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(absl::Substitute(" Is$0 = U[$1, 0]", m, m * 2));
+      result_py.emplace_back(
+          absl::Substitute(" Id$0 = U[$1, 0]", m, m * 2 + 1));
+    }
+
+    // Start from zeros so an unassigned entry can never be a free symbol.
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        absl::Substitute(" result = casadi.SX.zeros($0, 1)", kNumStates));
+    result_py.emplace_back("");
+
+    // And then write out the derivative of each module's states.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute(" result[$0, 0] = omegas$1", m * 4, m));
+      result_py.emplace_back(
+          absl::Substitute(" result[$0, 0] = omegad$1", m * 4 + 1, m));
+      result_py.emplace_back(absl::Substitute(
+          " result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
+      result_py.emplace_back(absl::Substitute(
+          " result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
+    }
+
+    // d/dt [x, y, theta] = [vx, vy, omega], matching the layout read above.
+    result_py.emplace_back(
+        absl::Substitute(" result[$0, 0] = vx", kChassisOffset));
+    result_py.emplace_back(
+        absl::Substitute(" result[$0, 0] = vy", kChassisOffset + 1));
+    result_py.emplace_back(
+        absl::Substitute(" result[$0, 0] = omega", kChassisOffset + 2));
+    // d/dt [vx, vy, omega] = [accel_(0), accel_(1), angular_accel_].
+    result_py.emplace_back(absl::Substitute(
+        " result[$0, 0] = $1", kChassisOffset + 3, ccode(*accel_.get(0, 0))));
+    result_py.emplace_back(absl::Substitute(
+        " result[$0, 0] = $1", kChassisOffset + 4, ccode(*accel_.get(1, 0))));
+    result_py.emplace_back(absl::Substitute(
+        " result[$0, 0] = $1", kChassisOffset + 5, ccode(*angular_accel_)));
+
+    // The applied force/moment states have zero derivative.
+    for (size_t i = kChassisOffset + 6; i < kNumStates; ++i) {
+      result_py.emplace_back(absl::Substitute(" result[$0, 0] = 0.0", i));
+    }
+
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        " return casadi.Function('xdot', [X, U], [result])");
+
+    aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
+  }
+
private:
static constexpr uint8_t kNumModules = 4;
@@ -671,11 +776,15 @@
frc971::control_loops::swerve::SwerveSimulation sim;
- if (!FLAGS_cc_output_path.empty() && !FLAGS_h_output_path.empty() &&
- !FLAGS_py_output_path.empty()) {
+ if (!FLAGS_cc_output_path.empty() && !FLAGS_h_output_path.empty()) {
sim.Write(FLAGS_cc_output_path, FLAGS_h_output_path);
+ }
+ if (!FLAGS_py_output_path.empty()) {
sim.WritePy(FLAGS_py_output_path);
}
+ if (!FLAGS_casadi_py_output_path.empty()) {
+ sim.WriteCasadi(FLAGS_casadi_py_output_path);
+ }
return 0;
}