Add casadi generator to symengine code

This lets us use casadi for nonlinear optimization instead of
reinventing our own through symengine.

Change-Id: I56af353ae3235ce775cb743c522424762dafbdaf
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index d8c1321..6f23c47 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -28,6 +28,8 @@
 DEFINE_string(cc_output_path, "", "Path to write generated header code to");
 DEFINE_string(h_output_path, "", "Path to write generated cc code to");
 DEFINE_string(py_output_path, "", "Path to write generated py code to");
+DEFINE_string(
+    casadi_py_output_path, "", "Path to write casadi generated py code to");
 
 DEFINE_bool(symbolic, false, "If true, write everything out symbolically.");
 
@@ -221,18 +223,18 @@
     result_py.emplace_back("#!/usr/bin/python3");
     result_py.emplace_back("");
     result_py.emplace_back("import numpy");
-    result_py.emplace_back("import math");
-    result_py.emplace_back("from math import sin, cos, fabs");
     result_py.emplace_back("");
 
-    result_py.emplace_back("def atan2(y, x):");
-    result_py.emplace_back("    if x < 0:");
-    result_py.emplace_back("        return -math.atan2(y, x)");
-    result_py.emplace_back("    else:");
-    result_py.emplace_back("        return math.atan2(y, x)");
-
     result_py.emplace_back("def swerve_physics(t, X, U_func):");
-    // result_py.emplace_back("    print(X)");
+    // numpy.atan2 only exists in NumPy >= 2.0; numpy.arctan2 works on 1.x too.
+    result_py.emplace_back("    def atan2(y, x):");
+    result_py.emplace_back("        if x < 0:");
+    result_py.emplace_back("            return -numpy.arctan2(y, x)");
+    result_py.emplace_back("        else:");
+    result_py.emplace_back("            return numpy.arctan2(y, x)");
+    result_py.emplace_back("    sin = numpy.sin");
+    result_py.emplace_back("    cos = numpy.cos");
+    result_py.emplace_back("    fabs = numpy.fabs");
     result_py.emplace_back("    result = numpy.empty([25, 1])");
     result_py.emplace_back("    X = X.reshape(25, 1)");
     result_py.emplace_back("    U = U_func(X)");
@@ -321,7 +323,7 @@
     std::vector<std::string> result_cc;
     std::vector<std::string> result_h;
 
-    std::string_view include_guard_stripped = FLAGS_h_output_path;
+    std::string_view include_guard_stripped{h_path};
     CHECK(absl::ConsumePrefix(&include_guard_stripped, FLAGS_output_base));
     std::string include_guard =
         absl::StrReplaceAll(absl::AsciiStrToUpper(include_guard_stripped),
@@ -450,6 +452,109 @@
     aos::util::WriteStringToFileOrDie(h_path, absl::StrJoin(result_h, "\n"));
   }
 
+  // Writes a casadi based python version of the physics out to py_path.
+  void WriteCasadi(std::string_view py_path) {
+    std::vector<std::string> result_py;
+
+    // Write out the header.
+    result_py.emplace_back("#!/usr/bin/python3");
+    result_py.emplace_back("");
+    result_py.emplace_back("import casadi");
+    result_py.emplace_back("");
+    result_py.emplace_back("# Returns a casadi Function xdot(X, U) which");
+    result_py.emplace_back("# computes the derivative of our state vector");
+    result_py.emplace_back("# [thetas0, thetad0, omegas0, omegad0,");
+    result_py.emplace_back("#  thetas1, thetad1, omegas1, omegad1,");
+    result_py.emplace_back("#  thetas2, thetad2, omegas2, omegad2,");
+    result_py.emplace_back("#  thetas3, thetad3, omegas3, omegad3,");
+    result_py.emplace_back("#  x, y, theta, vx, vy, omega,");
+    result_py.emplace_back("#  Fx, Fy, Moment]");
+    result_py.emplace_back("def swerve_physics(X, U):");
+    result_py.emplace_back("    sin = casadi.sin");
+    result_py.emplace_back("    cos = casadi.cos");
+    result_py.emplace_back("    atan2 = casadi.atan2");
+    result_py.emplace_back("    fabs = casadi.fabs");
+
+    // Start by writing out variables matching each of the symbol names we use
+    // so we don't have to modify the computed equations too much.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute("    thetas$0 = X[$1, 0]", m, m * 4));
+      result_py.emplace_back(
+          absl::Substitute("    omegas$0 = X[$1, 0]", m, m * 4 + 2));
+      result_py.emplace_back(
+          absl::Substitute("    omegad$0 = X[$1, 0]", m, m * 4 + 3));
+    }
+
+    result_py.emplace_back(
+        absl::Substitute("    theta = X[$0, 0]", kNumModules * 4 + 2));
+    result_py.emplace_back(
+        absl::Substitute("    vx = X[$0, 0]", kNumModules * 4 + 3));
+    result_py.emplace_back(
+        absl::Substitute("    vy = X[$0, 0]", kNumModules * 4 + 4));
+    result_py.emplace_back(
+        absl::Substitute("    omega = X[$0, 0]", kNumModules * 4 + 5));
+
+    result_py.emplace_back(
+        absl::Substitute("    fx = X[$0, 0]", kNumModules * 4 + 6));
+    result_py.emplace_back(
+        absl::Substitute("    fy = X[$0, 0]", kNumModules * 4 + 7));
+    result_py.emplace_back(
+        absl::Substitute("    moment = X[$0, 0]", kNumModules * 4 + 8));
+
+    // Now do the same for the inputs.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(absl::Substitute("    Is$0 = U[$1, 0]", m, m * 2));
+      result_py.emplace_back(
+          absl::Substitute("    Id$0 = U[$1, 0]", m, m * 2 + 1));
+    }
+
+    result_py.emplace_back("");
+    result_py.emplace_back("    result = casadi.SX.sym('result', 25, 1)");
+    result_py.emplace_back("");
+
+    // And then write out the derivative of each state.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = omegas$1", m * 4, m));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = omegad$1", m * 4 + 1, m));
+      result_py.emplace_back(absl::Substitute(
+          "    result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
+      result_py.emplace_back(absl::Substitute(
+          "    result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
+    }
+    // d/dt [x, y, theta] = [vx, vy, omega], matching the X layout read above.
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = vx", kNumModules * 4));
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = vy", kNumModules * 4 + 1));
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = omega", kNumModules * 4 + 2));
+    // d/dt [vx, vy, omega] = [accel x, accel y, angular accel].
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 4 + 3,
+                                            ccode(*accel_.get(0, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 4 + 4,
+                                            ccode(*accel_.get(1, 0))));
+    result_py.emplace_back(absl::Substitute(
+        "    result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
+    // Fx, Fy and Moment are held constant by this model (zero derivative).
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 6));
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 7));
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 8));
+
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        "    return casadi.Function('xdot', [X, U], [result])");
+
+    aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
+  }
+
  private:
   static constexpr uint8_t kNumModules = 4;
 
@@ -671,11 +776,15 @@
 
   frc971::control_loops::swerve::SwerveSimulation sim;
 
-  if (!FLAGS_cc_output_path.empty() && !FLAGS_h_output_path.empty() &&
-      !FLAGS_py_output_path.empty()) {
+  if (!FLAGS_cc_output_path.empty() && !FLAGS_h_output_path.empty()) {
     sim.Write(FLAGS_cc_output_path, FLAGS_h_output_path);
+  }
+  if (!FLAGS_py_output_path.empty()) {  // Independent of the C++ output.
     sim.WritePy(FLAGS_py_output_path);
   }
+  if (!FLAGS_casadi_py_output_path.empty()) {  // Independent as well.
+    sim.WriteCasadi(FLAGS_casadi_py_output_path);
+  }
 
   return 0;
 }