Add initial velocity MPC

We actually get to a speed!  Need to sort out why steering
overshoots/undershoots, and test more cases when changing direction.

Change-Id: Icd321c1a79b96281f6226886db840bbaeab85142
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index 1326fd7..2e3fe73 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -30,8 +30,6 @@
           "Path to write generated cc code to");
 ABSL_FLAG(std::string, h_output_path, "",
           "Path to write generated header code to");
-ABSL_FLAG(std::string, py_output_path, "",
-          "Path to write generated py code to");
 ABSL_FLAG(std::string, casadi_py_output_path, "",
           "Path to write casadi generated py code to");
 ABSL_FLAG(double, caster, 0.01, "Caster in meters for the module.");
@@ -68,19 +66,16 @@
 
 // State per module.
 struct Module {
+  DenseMatrix mounting_location;
+
   RCP<const Symbol> Is;
 
   RCP<const Symbol> Id;
 
   RCP<const Symbol> thetas;
   RCP<const Symbol> omegas;
-  RCP<const Symbol> alphas;
-  RCP<const Basic> alphas_eqn;
 
-  RCP<const Symbol> thetad;
   RCP<const Symbol> omegad;
-  RCP<const Symbol> alphad;
-  RCP<const Basic> alphad_eqn;
 
   DenseMatrix contact_patch_velocity;
   DenseMatrix wheel_ground_velocity;
@@ -89,16 +84,39 @@
   RCP<const Basic> slip_ratio;
 
   RCP<const Basic> Ms;
-  RCP<const Basic> Fwx;
   RCP<const Basic> Fwy;
-  DenseMatrix F;
-  DenseMatrix mounting_location;
 
-  // Acceleration contribution from this module.
-  DenseMatrix accel;
-  RCP<const Basic> angular_accel;
+  struct Full {
+    RCP<const Basic> Fwx;
+    DenseMatrix F;
+
+    RCP<const Basic> torque;
+
+    RCP<const Basic> alphas_eqn;
+    RCP<const Basic> alphad_eqn;
+  } full;
+
+  struct Direct {
+    RCP<const Basic> Fwx;
+    DenseMatrix F;
+
+    RCP<const Basic> torque;
+
+    RCP<const Basic> alphas_eqn;
+  } direct;
 };
 
+// Base case for the variadic overload below: one matrix sums to itself.
+DenseMatrix SumMatrices(const DenseMatrix &a) { return a; }
+
+// Returns a + args...; assumes every argument has the same dimensions as a.
+template <typename... Args>
+DenseMatrix SumMatrices(const DenseMatrix &a, const Args &...args) {
+  DenseMatrix result(a.nrows(), a.ncols());  // Sized from a, not fixed 2x1.
+  add_dense_dense(a, SumMatrices(args...), result);
+  return result;
+}
+
 class SwerveSimulation {
  public:
   SwerveSimulation() : drive_motor_(KrakenFOC()), steer_motor_(KrakenFOC()) {
@@ -193,6 +211,7 @@
 
     // Now, compute the accelerations due to the disturbance forces.
     DenseMatrix external_accel = DenseMatrix(2, 1, {div(fx, m_), div(fy, m_)});
+    DenseMatrix external_force = DenseMatrix(2, 1, {fx, fy});
 
     // And compute the physics contributions from each module.
     modules_[0] = ModulePhysics(
@@ -213,126 +232,34 @@
                                       div(robot_width_, integer(-2))}));
 
     // And convert them into the overall robot contribution.
-    DenseMatrix temp0 = DenseMatrix(2, 1);
-    DenseMatrix temp1 = DenseMatrix(2, 1);
-    DenseMatrix temp2 = DenseMatrix(2, 1);
-    accel_ = DenseMatrix(2, 1);
+    DenseMatrix net_full_force =
+        SumMatrices(modules_[0].full.F, modules_[1].full.F, modules_[2].full.F,
+                    modules_[3].full.F, external_force);
 
-    add_dense_dense(modules_[0].accel, external_accel, temp0);
-    add_dense_dense(temp0, modules_[1].accel, temp1);
-    add_dense_dense(temp1, modules_[2].accel, temp2);
-    add_dense_dense(temp2, modules_[3].accel, accel_);
+    DenseMatrix net_direct_force =
+        SumMatrices(modules_[0].direct.F, modules_[1].direct.F,
+                    modules_[2].direct.F, modules_[3].direct.F, external_force);
 
-    angular_accel_ =
-        add(div(moment, J_),
-            add(add(modules_[0].angular_accel, modules_[1].angular_accel),
-                add(modules_[2].angular_accel, modules_[3].angular_accel)));
+    full_accel_ = DenseMatrix(2, 1);
+    mul_dense_scalar(net_full_force, pow(m_, minus_one), full_accel_);
 
-    VLOG(1) << "accel(0, 0) = " << ccode(*accel_.get(0, 0));
-    VLOG(1) << "accel(1, 0) = " << ccode(*accel_.get(1, 0));
-    VLOG(1) << "angular_accel = " << ccode(*angular_accel_);
-  }
+    full_angular_accel_ = div(
+        add(moment, add(add(modules_[0].full.torque, modules_[1].full.torque),
+                        add(modules_[2].full.torque, modules_[3].full.torque))),
+        J_);
 
-  // Writes the physics out to the provided .py path.
-  void WritePy(std::string_view py_path) {
-    std::vector<std::string> result_py;
+    direct_accel_ = DenseMatrix(2, 1);
+    mul_dense_scalar(net_direct_force, pow(m_, minus_one), direct_accel_);
 
-    result_py.emplace_back("#!/usr/bin/python3");
-    result_py.emplace_back("");
-    result_py.emplace_back("import numpy");
-    result_py.emplace_back("");
+    direct_angular_accel_ =
+        div(add(moment,
+                add(add(modules_[0].direct.torque, modules_[1].direct.torque),
+                    add(modules_[2].direct.torque, modules_[3].direct.torque))),
+            J_);
 
-    result_py.emplace_back("def swerve_physics(t, X, U_func):");
-    result_py.emplace_back("    def atan2(y, x):");
-    result_py.emplace_back("        if x < 0:");
-    result_py.emplace_back("            return -numpy.atan2(y, x)");
-    result_py.emplace_back("        else:");
-    result_py.emplace_back("            return numpy.atan2(y, x)");
-    result_py.emplace_back("    sin = numpy.sin");
-    result_py.emplace_back("    cos = numpy.cos");
-    result_py.emplace_back("    fabs = numpy.fabs");
-
-    result_py.emplace_back("    result = numpy.empty([25, 1])");
-    result_py.emplace_back("    X = X.reshape(25, 1)");
-    result_py.emplace_back("    U = U_func(X)");
-    result_py.emplace_back("");
-
-    // Start by writing out variables matching each of the symbol names we use
-    // so we don't have to modify the computed equations too much.
-    for (size_t m = 0; m < kNumModules; ++m) {
-      result_py.emplace_back(
-          absl::Substitute("    thetas$0 = X[$1, 0]", m, m * 4));
-      result_py.emplace_back(
-          absl::Substitute("    omegas$0 = X[$1, 0]", m, m * 4 + 2));
-      result_py.emplace_back(
-          absl::Substitute("    omegad$0 = X[$1, 0]", m, m * 4 + 3));
-    }
-
-    result_py.emplace_back(
-        absl::Substitute("    theta = X[$0, 0]", kNumModules * 4 + 2));
-    result_py.emplace_back(
-        absl::Substitute("    vx = X[$0, 0]", kNumModules * 4 + 3));
-    result_py.emplace_back(
-        absl::Substitute("    vy = X[$0, 0]", kNumModules * 4 + 4));
-    result_py.emplace_back(
-        absl::Substitute("    omega = X[$0, 0]", kNumModules * 4 + 5));
-
-    result_py.emplace_back(
-        absl::Substitute("    fx = X[$0, 0]", kNumModules * 4 + 6));
-    result_py.emplace_back(
-        absl::Substitute("    fy = X[$0, 0]", kNumModules * 4 + 7));
-    result_py.emplace_back(
-        absl::Substitute("    moment = X[$0, 0]", kNumModules * 4 + 8));
-
-    // Now do the same for the inputs.
-    for (size_t m = 0; m < kNumModules; ++m) {
-      result_py.emplace_back(absl::Substitute("    Is$0 = U[$1, 0]", m, m * 2));
-      result_py.emplace_back(
-          absl::Substitute("    Id$0 = U[$1, 0]", m, m * 2 + 1));
-    }
-
-    result_py.emplace_back("");
-
-    // And then write out the derivative of each state.
-    for (size_t m = 0; m < kNumModules; ++m) {
-      result_py.emplace_back(
-          absl::Substitute("    result[$0, 0] = omegas$1", m * 4, m));
-      result_py.emplace_back(
-          absl::Substitute("    result[$0, 0] = omegad$1", m * 4 + 1, m));
-
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
-    }
-
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = vx", kNumModules * 4));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = vy", kNumModules * 4 + 1));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = omega", kNumModules * 4 + 2));
-
-    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
-                                            kNumModules * 4 + 3,
-                                            ccode(*accel_.get(0, 0))));
-    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
-                                            kNumModules * 4 + 4,
-                                            ccode(*accel_.get(1, 0))));
-    result_py.emplace_back(absl::Substitute(
-        "    result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
-
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 6));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 7));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 8));
-
-    result_py.emplace_back("");
-    result_py.emplace_back("    return result.reshape(25,)\n");
-
-    aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
+    VLOG(1) << "accel(0, 0) = " << ccode(*full_accel_.get(0, 0));
+    VLOG(1) << "accel(1, 0) = " << ccode(*full_accel_.get(1, 0));
+    VLOG(1) << "angular_accel = " << ccode(*full_angular_accel_);
   }
 
   // Writes the physics out to the provided .cc and .h path.
@@ -439,10 +366,12 @@
       result_cc.emplace_back(
           absl::Substitute("  result($0, 0) = omegad$1;", m * 4 + 1, m));
 
-      result_cc.emplace_back(absl::Substitute(
-          "  result($0, 0) = $1;", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
-      result_cc.emplace_back(absl::Substitute(
-          "  result($0, 0) = $1;", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
+      result_cc.emplace_back(
+          absl::Substitute("  result($0, 0) = $1;", m * 4 + 2,
+                           ccode(*modules_[m].full.alphas_eqn)));
+      result_cc.emplace_back(
+          absl::Substitute("  result($0, 0) = $1;", m * 4 + 3,
+                           ccode(*modules_[m].full.alphad_eqn)));
     }
 
     result_cc.emplace_back(
@@ -454,12 +383,13 @@
 
     result_cc.emplace_back(absl::Substitute("  result($0, 0) = $1;",
                                             kNumModules * 4 + 3,
-                                            ccode(*accel_.get(0, 0))));
+                                            ccode(*full_accel_.get(0, 0))));
     result_cc.emplace_back(absl::Substitute("  result($0, 0) = $1;",
                                             kNumModules * 4 + 4,
-                                            ccode(*accel_.get(1, 0))));
-    result_cc.emplace_back(absl::Substitute(
-        "  result($0, 0) = $1;", kNumModules * 4 + 5, ccode(*angular_accel_)));
+                                            ccode(*full_accel_.get(1, 0))));
+    result_cc.emplace_back(absl::Substitute("  result($0, 0) = $1;",
+                                            kNumModules * 4 + 5,
+                                            ccode(*full_angular_accel_)));
 
     result_cc.emplace_back(
         absl::Substitute("  result($0, 0) = 0.0;", kNumModules * 4 + 6));
@@ -522,6 +452,52 @@
     }
   }
 
+  // Appends the Python preamble of the velocity model to *result_py: casadi
+  // math aliases, then one local variable per symbol that the generated
+  // equations reference, so that the ccode() output needs little rewriting.
+  // Every emitted line is indented four spaces to sit inside a Python def.
+  void WriteCasadiVelocityVariables(std::vector<std::string> *result_py) {
+    const auto emit = [result_py](const std::string &line) {
+      result_py->emplace_back(line);
+    };
+
+    // Bind the bare math function names to their casadi implementations.
+    for (const char *fn : {"sin", "sign", "cos", "atan2", "fmax", "fabs"}) {
+      emit(absl::Substitute("    $0 = casadi.$0", fn));
+    }
+
+    // Module states come first in X, packed as (thetas, omegas) pairs, so
+    // module m occupies rows 2 * m and 2 * m + 1.
+    for (size_t mod = 0; mod < kNumModules; ++mod) {
+      const size_t row = mod * 2;
+      emit(absl::Substitute("    thetas$0 = X[$1, 0]", mod, row));
+      emit(absl::Substitute("    omegas$0 = X[$1, 0]", mod, row + 1));
+    }
+
+    // The robot states follow directly after the last module, in the order
+    // listed below.
+    size_t robot_row = kNumModules * 2;
+    for (const char *state : {"theta", "vx", "vy", "omega"}) {
+      emit(absl::Substitute("    $0 = X[$1, 0]", state, robot_row));
+      ++robot_row;
+    }
+
+    // fx, fy and moment are not read from X in this model.  They are pinned
+    // to 0 instead, which keeps the names defined for any equation that
+    // mentions them.
+    for (const char *disturbance : {"fx", "fy", "moment"}) {
+      emit(absl::Substitute("    $0 = 0", disturbance));
+    }
+
+    // Inputs are packed in U as (Is, Id) pairs, so module m reads rows
+    // 2 * m and 2 * m + 1.
+    for (size_t mod = 0; mod < kNumModules; ++mod) {
+      const size_t row = mod * 2;
+      emit(absl::Substitute("    Is$0 = U[$1, 0]", mod, row));
+      emit(absl::Substitute("    Id$0 = U[$1, 0]", mod, row + 1));
+    }
+  }
+
   // Writes the physics out to the provided .cc and .h path.
   void WriteCasadi(std::string_view py_path) {
     std::vector<std::string> result_py;
@@ -529,12 +505,74 @@
     // Write out the header.
     result_py.emplace_back("#!/usr/bin/python3");
     result_py.emplace_back("");
-    result_py.emplace_back("import casadi");
+    result_py.emplace_back("import casadi, numpy");
     result_py.emplace_back("");
     result_py.emplace_back(absl::Substitute("WHEEL_RADIUS = $0", ccode(*rw_)));
     result_py.emplace_back(
         absl::Substitute("ROBOT_WIDTH = $0", ccode(*robot_width_)));
     result_py.emplace_back(absl::Substitute("CASTER = $0", ccode(*caster_)));
+    result_py.emplace_back("STATE_THETAS0 = 0");
+    result_py.emplace_back("STATE_THETAD0 = 1");
+    result_py.emplace_back("STATE_OMEGAS0 = 2");
+    result_py.emplace_back("STATE_OMEGAD0 = 3");
+    result_py.emplace_back("STATE_THETAS1 = 4");
+    result_py.emplace_back("STATE_THETAD1 = 5");
+    result_py.emplace_back("STATE_OMEGAS1 = 6");
+    result_py.emplace_back("STATE_OMEGAD1 = 7");
+    result_py.emplace_back("STATE_THETAS2 = 8");
+    result_py.emplace_back("STATE_THETAD2 = 9");
+    result_py.emplace_back("STATE_OMEGAS2 = 10");
+    result_py.emplace_back("STATE_OMEGAD2 = 11");
+    result_py.emplace_back("STATE_THETAS3 = 12");
+    result_py.emplace_back("STATE_THETAD3 = 13");
+    result_py.emplace_back("STATE_OMEGAS3 = 14");
+    result_py.emplace_back("STATE_OMEGAD3 = 15");
+    result_py.emplace_back("STATE_X = 16");
+    result_py.emplace_back("STATE_Y = 17");
+    result_py.emplace_back("STATE_THETA = 18");
+    result_py.emplace_back("STATE_VX = 19");
+    result_py.emplace_back("STATE_VY = 20");
+    result_py.emplace_back("STATE_OMEGA = 21");
+    result_py.emplace_back("STATE_FX = 22");
+    result_py.emplace_back("STATE_FY = 23");
+    result_py.emplace_back("STATE_MOMENT = 24");
+    result_py.emplace_back("NUM_STATES = 25");
+    result_py.emplace_back("");
+    result_py.emplace_back("VELOCITY_STATE_THETAS0 = 0");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS0 = 1");
+    result_py.emplace_back("VELOCITY_STATE_THETAS1 = 2");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS1 = 3");
+    result_py.emplace_back("VELOCITY_STATE_THETAS2 = 4");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS2 = 5");
+    result_py.emplace_back("VELOCITY_STATE_THETAS3 = 6");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS3 = 7");
+    result_py.emplace_back("VELOCITY_STATE_THETA = 8");
+    result_py.emplace_back("VELOCITY_STATE_VX = 9");
+    result_py.emplace_back("VELOCITY_STATE_VY = 10");
+    result_py.emplace_back("VELOCITY_STATE_OMEGA = 11");
+    // result_py.emplace_back("VELOCITY_STATE_FX = 16");
+    // result_py.emplace_back("VELOCITY_STATE_FY = 17");
+    // result_py.emplace_back("VELOCITY_STATE_MOMENT = 18");
+    result_py.emplace_back("NUM_VELOCITY_STATES = 12");
+    result_py.emplace_back("");
+    result_py.emplace_back("def to_velocity_state(X):");
+    result_py.emplace_back("    return numpy.array([");
+    result_py.emplace_back("        [X[STATE_THETAS0, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS0, 0]],");
+    result_py.emplace_back("        [X[STATE_THETAS1, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS1, 0]],");
+    result_py.emplace_back("        [X[STATE_THETAS2, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS2, 0]],");
+    result_py.emplace_back("        [X[STATE_THETAS3, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS3, 0]],");
+    result_py.emplace_back("        [X[STATE_THETA, 0]],");
+    result_py.emplace_back("        [X[STATE_VX, 0]],");
+    result_py.emplace_back("        [X[STATE_VY, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGA, 0]],");
+    // result_py.emplace_back("        [X[STATE_FX, 0]],");
+    // result_py.emplace_back("        [X[STATE_FY, 0]],");
+    // result_py.emplace_back("        [X[STATE_MOMENT, 0]],");
+    result_py.emplace_back("    ])");
     result_py.emplace_back("");
 
     result_py.emplace_back("# Returns the derivative of our state vector");
@@ -544,7 +582,7 @@
     result_py.emplace_back("#  thetas3, thetad3, omegas3, omegad3,");
     result_py.emplace_back("#  x, y, theta, vx, vy, omega,");
     result_py.emplace_back("#  Fx, Fy, Moment]");
-    result_py.emplace_back("def swerve_physics(X, U):");
+    result_py.emplace_back("def swerve_full_dynamics(X, U):");
     WriteCasadiVariables(&result_py);
 
     result_py.emplace_back("");
@@ -558,10 +596,12 @@
       result_py.emplace_back(
           absl::Substitute("    result[$0, 0] = omegad$1", m * 4 + 1, m));
 
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = $1", m * 4 + 2,
+                           ccode(*modules_[m].full.alphas_eqn)));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = $1", m * 4 + 3,
+                           ccode(*modules_[m].full.alphad_eqn)));
     }
 
     result_py.emplace_back(
@@ -573,12 +613,13 @@
 
     result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
                                             kNumModules * 4 + 3,
-                                            ccode(*accel_.get(0, 0))));
+                                            ccode(*full_accel_.get(0, 0))));
     result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
                                             kNumModules * 4 + 4,
-                                            ccode(*accel_.get(1, 0))));
-    result_py.emplace_back(absl::Substitute(
-        "    result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
+                                            ccode(*full_accel_.get(1, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 4 + 5,
+                                            ccode(*full_angular_accel_)));
 
     result_py.emplace_back(
         absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 6));
@@ -591,6 +632,54 @@
     result_py.emplace_back(
         "    return casadi.Function('xdot', [X, U], [result])");
 
+    result_py.emplace_back("");
+
+    result_py.emplace_back("# Returns the derivative of our state vector");
+    result_py.emplace_back("# [thetas0, omegas0,");
+    result_py.emplace_back("#  thetas1, omegas1,");
+    result_py.emplace_back("#  thetas2, omegas2,");
+    result_py.emplace_back("#  thetas3, omegas3,");
+    result_py.emplace_back("#  theta, vx, vy, omega]");
+    result_py.emplace_back("def velocity_swerve_physics(X, U):");
+    WriteCasadiVelocityVariables(&result_py);
+
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        "    result = casadi.SX.sym('result', NUM_VELOCITY_STATES, 1)");
+    result_py.emplace_back("");
+
+    // And then write out the derivative of each state.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = omegas$1", m * 2 + 0, m));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = $1", m * 2 + 1,
+                           ccode(*modules_[m].direct.alphas_eqn)));
+    }
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = omega", kNumModules * 2 + 0));
+
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 2 + 1,
+                                            ccode(*direct_accel_.get(0, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 2 + 2,
+                                            ccode(*direct_accel_.get(1, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 2 + 3,
+                                            ccode(*direct_angular_accel_)));
+
+    // result_py.emplace_back(
+    // absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 3 + 4));
+    // result_py.emplace_back(
+    // absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 3 + 5));
+    // result_py.emplace_back(
+    // absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 3 + 6));
+
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        "    return casadi.Function('xdot', [X, U], [result])");
+
     DefineVector2dFunction(
         "contact_patch_velocity",
         "# Returns the velocity of the wheel in global coordinates.",
@@ -624,20 +713,22 @@
     DefineScalarFunction(
         "module_angular_accel",
         "Returns the angular acceleration of the robot due to the ith wheel",
-        [](const Module &m) { return ccode(*m.angular_accel); }, &result_py);
+        [this](const Module &m) { return ccode(*div(m.full.torque, J_)); },
+        &result_py);
 
     DefineVector2dFunction(
         "wheel_force",
         "Returns the force on the wheel in steer module coordinates",
         [](const Module &m, int dimension) {
-          return ccode(*std::vector<RCP<const Basic>>{m.Fwx, m.Fwy}[dimension]);
+          return ccode(
+              *std::vector<RCP<const Basic>>{m.full.Fwx, m.Fwy}[dimension]);
         },
         &result_py);
 
     DefineVector2dFunction(
         "F", "Returns the force on the wheel in absolute coordinates",
         [](const Module &m, int dimension) {
-          return ccode(*m.F.get(dimension, 0));
+          return ccode(*m.full.F.get(dimension, 0));
         },
         &result_py);
 
@@ -711,6 +802,16 @@
  private:
   static constexpr uint8_t kNumModules = 4;
 
+  // Steer accel: (Ms + Kts * Is / Gs + wheel term) / (Jsm + Js / Gs^2).
+  RCP<const Basic> SteerAccel(RCP<const Basic> Fwx, RCP<const Basic> Ms,
+                              RCP<const Basic> Is) {
+    const RCP<const Basic> wheel =
+        mul(add(neg(wb_), mul(add(rs_, rp_), sub(integer(1), div(rb1_, rp_)))),
+            mul(div(rw_, rb2_), neg(Fwx)));
+    const RCP<const Basic> denom = add(Jsm_, div(div(Js_, Gs_), Gs_));
+    return simplify(div(add(add(Ms, div(mul(Kts_, Is), Gs_)), wheel), denom));
+  }
+
   Module ModulePhysics(const int m, DenseMatrix mounting_location) {
     VLOG(1) << "Solving module " << m;
 
@@ -726,11 +827,10 @@
 
     result.thetas = symbol(absl::StrFormat("thetas%u", m));
     result.omegas = symbol(absl::StrFormat("omegas%u", m));
-    result.alphas = symbol(absl::StrFormat("alphas%u", m));
+    RCP<const Symbol> alphas = symbol(absl::StrFormat("alphas%u", m));
 
-    result.thetad = symbol(absl::StrFormat("thetad%u", m));
     result.omegad = symbol(absl::StrFormat("omegad%u", m));
-    result.alphad = symbol(absl::StrFormat("alphad%u", m));
+    RCP<const Symbol> alphad = symbol(absl::StrFormat("alphad%u", m));
 
     // Velocity of the module in field coordinates
     DenseMatrix robot_velocity = DenseMatrix(2, 1, {vx_, vy_});
@@ -790,7 +890,7 @@
     VLOG(1);
     VLOG(1) << "Slip ratio " << result.slip_ratio->__str__();
 
-    result.Fwx = simplify(mul(Cx_, result.slip_ratio));
+    result.full.Fwx = simplify(mul(Cx_, result.slip_ratio));
     result.Fwy = simplify(mul(Cy_, result.slip_angle));
 
     // The self-aligning moment needs to flip when the module flips direction.
@@ -802,35 +902,34 @@
     VLOG(1);
     VLOG(1) << "Ms " << result.Ms->__str__();
     VLOG(1);
-    VLOG(1) << "Fwx " << result.Fwx->__str__();
+    VLOG(1) << "full.Fwx " << result.full.Fwx->__str__();
     VLOG(1);
     VLOG(1) << "Fwy " << result.Fwy->__str__();
 
-    // alphas = ...
-    RCP<const Basic> lhms =
-        mul(add(neg(wb_), mul(add(rs_, rp_), sub(integer(1), div(rb1_, rp_)))),
-            mul(div(rw_, rb2_), neg(result.Fwx)));
-    RCP<const Basic> lhs =
-        add(add(result.Ms, div(mul(Kts_, result.Is), Gs_)), lhms);
-    RCP<const Basic> rhs = add(Jsm_, div(div(Js_, Gs_), Gs_));
-    RCP<const Basic> accel_steer_eqn = simplify(div(lhs, rhs));
+    // -K_td * Id / Gd + Fwx * rw = 0
+    // Fwx = K_td * Id / Gd / rw
+    result.direct.Fwx = mul(Ktd_, div(result.Id, mul(Gd_, rw_)));
+
+    result.direct.alphas_eqn =
+        SteerAccel(result.direct.Fwx, result.Ms, result.Is);
+
+    // d/dt omegas = ...
+    result.full.alphas_eqn = SteerAccel(result.full.Fwx, result.Ms, result.Is);
 
     VLOG(1);
-    VLOG(1) << result.alphas->__str__() << " = " << accel_steer_eqn->__str__();
+    VLOG(1) << alphas->__str__() << " = " << result.full.alphas_eqn->__str__();
 
-    lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.omegas),
-              mul(Gd1_, mul(Gd2_, omegamd)));
-    RCP<const Basic> dplanitary_eqn = sub(mul(Gd3_, lhs), result.omegad);
+    RCP<const Basic> lhs =
+        sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), alphas),
+            mul(Gd1_, mul(Gd2_, alphamd)));
+    RCP<const Basic> ddplanitary_eqn = sub(mul(Gd3_, lhs), alphad);
 
-    lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.alphas),
-              mul(Gd1_, mul(Gd2_, alphamd)));
-    RCP<const Basic> ddplanitary_eqn = sub(mul(Gd3_, lhs), result.alphad);
+    RCP<const Basic> full_drive_eqn =
+        sub(add(mul(neg(Jdm_), div(alphamd, Gd_)),
+                mul(Ktd_, div(neg(result.Id), Gd_))),
+            mul(neg(result.full.Fwx), rw_));
 
-    RCP<const Basic> drive_eqn = sub(add(mul(neg(Jdm_), div(alphamd, Gd_)),
-                                         mul(Ktd_, div(neg(result.Id), Gd_))),
-                                     mul(neg(result.Fwx), rw_));
-
-    VLOG(1) << "drive_eqn: " << drive_eqn->__str__();
+    VLOG(1) << "full_drive_eqn: " << full_drive_eqn->__str__();
 
     // Substitute in ddplanitary_eqn so we get rid of alphamd
     map_basic_basic map;
@@ -838,39 +937,38 @@
     RCP<const Set> solve_solution = solve(ddplanitary_eqn, alphamd, reals);
     map[alphamd] = solve_solution->get_args()[1]->get_args()[0];
     VLOG(1) << "temp: " << solve_solution->__str__();
-    RCP<const Basic> drive_eqn_subs = drive_eqn->subs(map);
+    RCP<const Basic> drive_eqn_subs = full_drive_eqn->subs(map);
 
     map.clear();
-    map[result.alphas] = accel_steer_eqn;
+    map[alphas] = result.full.alphas_eqn;
     RCP<const Basic> drive_eqn_subs2 = drive_eqn_subs->subs(map);
     RCP<const Basic> drive_eqn_subs3 = simplify(drive_eqn_subs2);
-    VLOG(1) << "drive_eqn simplified: " << drive_eqn_subs3->__str__();
+    VLOG(1) << "full_drive_eqn simplified: " << drive_eqn_subs3->__str__();
 
-    solve_solution = solve(drive_eqn_subs3, result.alphad, reals);
+    solve_solution = solve(drive_eqn_subs3, alphad, reals);
 
-    RCP<const Basic> drive_accel =
+    result.full.alphad_eqn =
         simplify(solve_solution->get_args()[1]->get_args()[0]);
-    VLOG(1) << "drive_accel: " << drive_accel->__str__();
+    VLOG(1) << "drive_accel: " << result.full.alphad_eqn->__str__();
 
     // Compute the resulting force from the module.
-    result.F = DenseMatrix(2, 1);
+    result.full.F = DenseMatrix(2, 1);
     mul_dense_dense(R(add(theta_, result.thetas)),
-                    DenseMatrix(2, 1, {result.Fwx, result.Fwy}), result.F);
+                    DenseMatrix(2, 1, {result.full.Fwx, result.Fwy}),
+                    result.full.F);
+    result.full.torque = force_cross(result.mounting_location, result.full.F);
 
-    RCP<const Basic> torque = force_cross(result.mounting_location, result.F);
-    result.accel = DenseMatrix(2, 1);
-    mul_dense_scalar(result.F, pow(m_, minus_one), result.accel);
-    result.angular_accel = div(torque, J_);
-    VLOG(1);
-    VLOG(1) << "angular_accel = " << result.angular_accel->__str__();
+    result.direct.F = DenseMatrix(2, 1);
+    mul_dense_dense(R(add(theta_, result.thetas)),
+                    DenseMatrix(2, 1, {result.direct.Fwx, result.Fwy}),
+                    result.direct.F);
+    result.direct.torque =
+        force_cross(result.mounting_location, result.direct.F);
 
     VLOG(1);
-    VLOG(1) << "accel(0, 0) = " << result.accel.get(0, 0)->__str__();
-    VLOG(1);
-    VLOG(1) << "accel(1, 0) = " << result.accel.get(1, 0)->__str__();
+    VLOG(1) << "full torque = " << result.full.torque->__str__();
+    VLOG(1) << "direct torque = " << result.direct.torque->__str__();
 
-    result.alphad_eqn = drive_accel;
-    result.alphas_eqn = accel_steer_eqn;
     return result;
   }
 
@@ -938,8 +1036,10 @@
 
   std::array<Module, kNumModules> modules_;
 
-  DenseMatrix accel_;
-  RCP<const Basic> angular_accel_;
+  DenseMatrix full_accel_;
+  RCP<const Basic> full_angular_accel_;
+  DenseMatrix direct_accel_;
+  RCP<const Basic> direct_angular_accel_;
 };
 
 }  // namespace frc971::control_loops::swerve
@@ -954,9 +1054,6 @@
     sim.Write(absl::GetFlag(FLAGS_cc_output_path),
               absl::GetFlag(FLAGS_h_output_path));
   }
-  if (!absl::GetFlag(FLAGS_py_output_path).empty()) {
-    sim.WritePy(absl::GetFlag(FLAGS_py_output_path));
-  }
   if (!absl::GetFlag(FLAGS_casadi_py_output_path).empty()) {
     sim.WriteCasadi(absl::GetFlag(FLAGS_casadi_py_output_path));
   }