Add initial velocity MPC

The MPC now actually reaches the goal speed!  Still need to sort out why
steering overshoots/undershoots, and test more cases when changing direction.

Change-Id: Icd321c1a79b96281f6226886db840bbaeab85142
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 3755ab4..e6a3798 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -131,13 +131,11 @@
         "dynamics.cc",
         "dynamics.h",
         "dynamics.py",
-        "numpy_dynamics.py",
     ],
     args = [
         "--output_base=$(BINDIR)/",
         "--cc_output_path=$(location :dynamics.cc)",
         "--h_output_path=$(location :dynamics.h)",
-        "--py_output_path=$(location :numpy_dynamics.py)",
         "--casadi_py_output_path=$(location :dynamics.py)",
     ],
     tool = ":generate_physics",
@@ -178,18 +176,15 @@
     ],
 )
 
-py_binary(
-    name = "dynamics_sim",
+py_library(
+    name = "dynamics",
     srcs = [
-        "dynamics_sim.py",
-        "numpy_dynamics.py",
+        "bigcaster_dynamics.py",
+        "dynamics.py",
+        "nocaster_dynamics.py",
     ],
     deps = [
-        "//frc971/control_loops/python:controls",
-        "@pip//matplotlib",
-        "@pip//numpy",
-        "@pip//pygobject",
-        "@pip//scipy",
+        "@pip//casadi",
     ],
 )
 
@@ -200,21 +195,25 @@
     ],
     target_compatible_with = ["@platforms//cpu:x86_64"],
     deps = [
+        ":dynamics",
         ":physics_test_utils",
         "@pip//casadi",
         "@pip//numpy",
     ],
 )
 
-py_library(
-    name = "dynamics",
+py_binary(
+    name = "casadi_velocity_mpc",
     srcs = [
-        "bigcaster_dynamics.py",
-        "dynamics.py",
-        "nocaster_dynamics.py",
+        "casadi_velocity_mpc.py",
     ],
     deps = [
+        ":dynamics",
         "@pip//casadi",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//scipy",
     ],
 )
 
diff --git a/frc971/control_loops/swerve/Makefile b/frc971/control_loops/swerve/Makefile
index c9a61be..611ab02 100644
--- a/frc971/control_loops/swerve/Makefile
+++ b/frc971/control_loops/swerve/Makefile
@@ -1,5 +1,5 @@
 all:
-	pdflatex swerve_notes.tex
+	pdflatex -halt-on-error swerve_notes.tex
 	bibtex swerve_notes
-	pdflatex swerve_notes.tex
-	pdflatex swerve_notes.tex
\ No newline at end of file
+	pdflatex -halt-on-error swerve_notes.tex
+	pdflatex -halt-on-error swerve_notes.tex
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc.py b/frc971/control_loops/swerve/casadi_velocity_mpc.py
new file mode 100644
index 0000000..5335a10
--- /dev/null
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc.py
@@ -0,0 +1,294 @@
+#!/usr/bin/python3
+
+from frc971.control_loops.swerve import dynamics
+import matplotlib.pyplot as pyplot
+from matplotlib import pylab
+import numpy
+import time
+# Import the submodule explicitly: a bare `import scipy` is not guaranteed
+# to load scipy.integrate, which the simulation loop below relies on.
+import scipy.integrate
+import casadi
+import os, sys
+
+
+class MPC(object):
+    """Velocity-tracking MPC for the swerve drive, posed as a CasADi NLP.
+
+    The decision vector is [U_0, X_1, U_1, X_2, ..., U_{N-1}, X_N]: for each
+    of the N intervals, the 8 module inputs (steer and drive for each of the
+    4 modules) followed by the velocity state at the end of that interval.
+    The NLP parameter vector is the current velocity state stacked on the
+    goal [vx, vy, omega].
+    """
+
+    def __init__(self):
+        self.fdot = dynamics.swerve_full_dynamics(
+            casadi.SX.sym("X", dynamics.NUM_STATES, 1),
+            casadi.SX.sym("U", 8, 1))
+        self.velocity_fdot = dynamics.velocity_swerve_physics(
+            casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
+            casadi.SX.sym("U", 8, 1))
+
+        # numpy-valued evaluation of the full-state dynamics, for simulation.
+        self.wrapped_swerve_physics = lambda X, U: numpy.array(self.fdot(X, U))
+
+        # Length of one MPC interval, in seconds.
+        self.dt = 0.005
+
+        # TODO(austin): Do we need a disturbance torque per module to account for friction?
+        # Do it only in the observer/post?
+
+        self.next_X = self.make_physics()
+        self.force = [
+            dynamics.F(i, casadi.SX.sym("X", dynamics.NUM_STATES, 1),
+                       casadi.SX.sym("U", 8, 1)) for i in range(4)
+        ]
+        self.slip_angle = [
+            dynamics.slip_angle(i, casadi.SX.sym("X", dynamics.NUM_STATES, 1),
+                                casadi.SX.sym("U", 8, 1)) for i in range(4)
+        ]
+
+        # Number of intervals in the prediction horizon.
+        self.N = 2
+
+        # TODO(austin): Can we approximate sin/cos/atan for the initial operating point to solve faster if we need it?
+
+        # Start with an empty nonlinear program.
+        self.w = []
+        self.w0 = []
+        self.lbw = []
+        self.ubw = []
+        J = 0
+        self.g = []
+        self.lbg = []
+        self.ubg = []
+
+        self.X0 = casadi.MX.sym('X0', dynamics.NUM_VELOCITY_STATES)
+
+        # We care about the linear and angular velocities only.
+        self.R = casadi.MX.sym('R', 3)
+
+        Xk = self.X0
+        # Instead of an equality constraint on the goal, what about:
+        # self.w += [Xk]
+        # lbw += [0, 1]
+        # ubw += [0, 1]
+        # w0 += [0, 1]
+
+        # (Velocity-state index of steer angle, steer rate) for each module.
+        steer_states = [
+            (dynamics.VELOCITY_STATE_THETAS0, dynamics.VELOCITY_STATE_OMEGAS0),
+            (dynamics.VELOCITY_STATE_THETAS1, dynamics.VELOCITY_STATE_OMEGAS1),
+            (dynamics.VELOCITY_STATE_THETAS2, dynamics.VELOCITY_STATE_OMEGAS2),
+            (dynamics.VELOCITY_STATE_THETAS3, dynamics.VELOCITY_STATE_OMEGAS3),
+        ]
+
+        for k in range(self.N):
+            Uk = casadi.MX.sym('U_' + str(k), 8)
+            # TODO(austin): Add a line to g here for a 12 volt battery and the breakers?
+            self.w += [Uk]
+            self.lbw += [-100] * 8
+            self.ubw += [100] * 8
+            self.w0 += [0] * 8
+
+            # Integrate till the end of the interval
+            Fk = self.next_X(x0=Xk, u=Uk)
+            Xk_end = Fk['xf']
+
+            # TODO(austin): tune cost fn?
+            # Do we want to penalize slipping tires?
+            # Do we want to penalize powers unevenly applied across the motors?
+            # Need to do some simulations to see what works well.
+
+            # Track the goal linear and angular velocities.
+            J += 1000 * (self.R[0] - Xk_end[dynamics.VELOCITY_STATE_VX])**2.0
+            J += 1000 * (self.R[1] - Xk_end[dynamics.VELOCITY_STATE_VY])**2.0
+            J += 1000 * (self.R[2] -
+                         Xk_end[dynamics.VELOCITY_STATE_OMEGA])**2.0
+
+            # Penalize steer angle and steer rate on every module.
+            kSteerPositionGain = 0
+            kSteerVelocityGain = 0.05
+            for thetas, omegas in steer_states:
+                J += kSteerPositionGain * (Xk_end[thetas])**2.0
+                J += kSteerVelocityGain * (Xk_end[omegas])**2.0
+
+            #for i in range(4):
+            #    sa = self.slip_angle[i](Xk_end, Uk)
+            #    J += 10 * sa * sa
+
+            # Small quadratic effort penalty on each module's inputs.
+            for i in range(4):
+                # Steer
+                J += Uk[2 * i + 0] * Uk[2 * i + 0] / 100000.0
+                # Drive
+                J += Uk[2 * i + 1] * Uk[2 * i + 1] / 1000.0
+
+            # New NLP variable for state at end of interval
+            Xk = casadi.MX.sym('X_' + str(k + 1), dynamics.NUM_VELOCITY_STATES)
+            self.w += [Xk]
+            self.ubw += [casadi.inf] * dynamics.NUM_VELOCITY_STATES
+            self.lbw += [-casadi.inf] * dynamics.NUM_VELOCITY_STATES
+            # Keep self.w0 the same length as self.w.  (solve() builds its own
+            # initial guess, so this is only for consistency.)
+            self.w0 += [0] * dynamics.NUM_VELOCITY_STATES
+
+            # Multiple-shooting continuity: lbg <= g <= ubg with both bounds
+            # zero forces the integrated state to equal the next X variable.
+            self.g += [Xk_end - Xk]
+            self.lbg += [0] * dynamics.NUM_VELOCITY_STATES
+            self.ubg += [0] * dynamics.NUM_VELOCITY_STATES
+
+        prob = {
+            'f': J,
+            # lbx <= x <= ubx
+            'x': casadi.vertcat(*self.w),
+            # lbg <= g(x, p) <= ubg
+            'g': casadi.vertcat(*self.g),
+            # Input parameters (initial position + goal)
+            'p': casadi.vertcat(self.X0, self.R),
+        }
+        print('J', J)
+        self.solver = casadi.nlpsol('solver', 'ipopt', prob)
+
+    def make_physics(self):
+        """Build a CasADi Function advancing the velocity state by self.dt.
+
+        Returns next_X(x0, u) -> xf, which integrates velocity_fdot over one
+        interval using M fixed-step RK4 substeps with the input held constant.
+        """
+        X0 = casadi.MX.sym('X0', dynamics.NUM_VELOCITY_STATES)
+        U = casadi.MX.sym('U', 8)
+
+        X = X0
+        M = 4  # RK4 steps per interval
+        DT = self.dt / M
+
+        for j in range(M):
+            k1 = self.velocity_fdot(X, U)
+            k2 = self.velocity_fdot(X + DT / 2 * k1, U)
+            k3 = self.velocity_fdot(X + DT / 2 * k2, U)
+            k4 = self.velocity_fdot(X + DT * k3, U)
+            # DT is already the substep length, so the update must not divide
+            # by M again; doing so only advanced the state by about dt / M.
+            X = X + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
+
+        return casadi.Function("next_X", [X0, U], [X], ['x0', 'u'], ['xf'])
+
+    def solve(self, p, seed=None):
+        """Solve the NLP for parameter column p = [velocity state; goal].
+
+        The initial guess uses zero inputs and holds every predicted state at
+        the current state.  seed is currently unused.  Returns the result dict
+        from the CasADi solver.
+        """
+        w0 = []
+        for k in range(self.N):
+            w0 += [0] * 8
+            w0 += list(p[:dynamics.NUM_VELOCITY_STATES, 0])
+
+        return self.solver(x0=w0,
+                           lbx=self.lbw,
+                           ubx=self.ubw,
+                           lbg=self.lbg,
+                           ubg=self.ubg,
+                           p=casadi.DM(p))
+
+    def unpack_u(self, sol, i):
+        """Return the 8 inputs of interval i (0-based) from a solution."""
+        stride = 8 + dynamics.NUM_VELOCITY_STATES
+        return sol['x'].full().flatten()[stride * i:stride * i + 8]
+
+    def unpack_x(self, sol, i):
+        """Return the predicted velocity state X_i (1-based) from a solution."""
+        stride = 8 + dynamics.NUM_VELOCITY_STATES
+        start = stride * (i - 1) + 8
+        return sol['x'].full().flatten()[start:start +
+                                         dynamics.NUM_VELOCITY_STATES]
+
+
+mpc = MPC()
+
+R_goal = numpy.zeros((3, 1))  # Goal [vx, vy, omega] tracked by the MPC.
+R_goal[0, 0] = 1.0
+R_goal[1, 0] = 1.0
+
+X_initial = numpy.zeros((dynamics.NUM_STATES, 1))
+# All the drive wheels start at rest.
+X_initial[dynamics.STATE_OMEGAD0, 0] = 0.0
+X_initial[dynamics.STATE_OMEGAD1, 0] = 0.0
+X_initial[dynamics.STATE_OMEGAD2, 0] = 0.0
+X_initial[dynamics.STATE_OMEGAD3, 0] = 0.0
+
+# Robot starts almost at rest: vx = 0.01 m/s, vy = 0 m/s.
+X_initial[dynamics.STATE_VX, 0] = 0.01
+X_initial[dynamics.STATE_VY, 0] = 0.0
+# No angular velocity
+X_initial[dynamics.STATE_OMEGA, 0] = 0.0
+
+iterations = 200
+
+X_plot = numpy.zeros((dynamics.NUM_STATES, iterations))
+U_plot = numpy.zeros((8, iterations))
+t = []
+
+X = X_initial.copy()
+
+pyplot.ion()
+
+fig0, axs0 = pylab.subplots(2)
+fig1, axs1 = pylab.subplots(2)
+last_time = time.time()
+
+for i in range(iterations):
+    t.append(i * mpc.dt)
+    print("Current X at", i * mpc.dt, X.transpose())
+    print("Goal R at", i * mpc.dt, R_goal)
+    sol = mpc.solve(
+        # TODO(austin): Is this better or worse than constraints on the initial state for convergence?
+        p=numpy.vstack((dynamics.to_velocity_state(X), R_goal)))
+    X_plot[:, i] = X[:, 0]
+
+    U = mpc.unpack_u(sol, 0)
+    U_plot[:, i] = U
+
+    print('x(0):', X.transpose())
+    for j in range(mpc.N):
+        print(f'u({j}): ', mpc.unpack_u(sol, j))
+        print(f'x({j+1}): ', mpc.unpack_x(sol, j + 1))
+    # Simulate the full-state plant for one interval with U held constant.
+    result = scipy.integrate.solve_ivp(
+        lambda _t, x: mpc.wrapped_swerve_physics(x, U).flatten(), [0, mpc.dt],
+        X.flatten())
+    X[:, 0] = result.y[:, -1]
+
+    if time.time() > last_time + 2 or i == iterations - 1:
+        axs0[0].clear()
+        axs0[1].clear()
+
+        axs0[0].plot(t, X_plot[dynamics.STATE_VX, 0:i + 1], label="vx")
+        axs0[0].plot(t, X_plot[dynamics.STATE_VY, 0:i + 1], label="vy")
+        axs0[0].legend()
+
+        axs0[1].plot(t, U_plot[0, 0:i + 1], label="Is0")
+        axs0[1].plot(t, U_plot[1, 0:i + 1], label="Id0")
+        axs0[1].legend()
+
+        axs1[0].clear()
+        axs1[1].clear()
+
+        axs1[0].plot(t, X_plot[dynamics.STATE_THETAS0, 0:i + 1], label='steer0')
+        axs1[0].plot(t, X_plot[dynamics.STATE_THETAS1, 0:i + 1], label='steer1')
+        axs1[0].plot(t, X_plot[dynamics.STATE_THETAS2, 0:i + 1], label='steer2')
+        axs1[0].plot(t, X_plot[dynamics.STATE_THETAS3, 0:i + 1], label='steer3')
+        axs1[0].legend()
+        axs1[1].plot(t, X_plot[dynamics.STATE_OMEGAS0, 0:i + 1], label='steer_velocity0')
+        axs1[1].plot(t, X_plot[dynamics.STATE_OMEGAS1, 0:i + 1], label='steer_velocity1')
+        axs1[1].plot(t, X_plot[dynamics.STATE_OMEGAS2, 0:i + 1], label='steer_velocity2')
+        axs1[1].plot(t, X_plot[dynamics.STATE_OMEGAS3, 0:i + 1], label='steer_velocity3')
+        axs1[1].legend()
+        pyplot.pause(0.0001)
+        last_time = time.time()
+
+pyplot.pause(-1)  # Non-positive interval: event loop never times out.
diff --git a/frc971/control_loops/swerve/dynamics_sim.py b/frc971/control_loops/swerve/dynamics_sim.py
deleted file mode 100644
index a50b62b..0000000
--- a/frc971/control_loops/swerve/dynamics_sim.py
+++ /dev/null
@@ -1,83 +0,0 @@
-#!/usr/bin/python3
-
-import numpy
-import math
-import scipy.integrate
-
-from matplotlib import pylab
-import sys
-import gflags
-import glog
-
-from frc971.control_loops.swerve.numpy_dynamics import swerve_physics
-
-FLAGS = gflags.FLAGS
-
-
-def u_func(X):
-    result = numpy.zeros([8, 1])
-    # result[1, 0] = 80.0
-    # result[3, 0] = 80.0
-    # result[5, 0] = -80.0
-    # result[7, 0] = -80.0
-    return result
-
-
-def main(argv):
-    x_initial = numpy.zeros([25, 1])
-    # x_initial[0] = -math.pi / 2.0
-    x_initial[3] = 1.0 / (2.0 * 0.0254)
-    # x_initial[4] = math.pi / 2.0
-    x_initial[7] = 1.0 / (2.0 * 0.0254)
-    # x_initial[8] = -math.pi / 2.0
-    x_initial[11] = 1.0 / (2.0 * 0.0254)
-    # x_initial[12] = math.pi / 2.0
-    x_initial[15] = 1.0 / (2.0 * 0.0254)
-    x_initial[19] = 2.0
-    result = scipy.integrate.solve_ivp(swerve_physics, (0, 2.0),
-                                       x_initial.reshape(25, ),
-                                       max_step=0.01,
-                                       args=(u_func, ))
-
-    cm = pylab.get_cmap('gist_rainbow')
-    fig = pylab.figure()
-    ax = fig.add_subplot(111)
-    ax.set_prop_cycle(color=[cm(1. * i / 25) for i in range(25)])
-    ax.plot(result.t, result.y[0, :], label="thetas0", linewidth=7.0)
-    ax.plot(result.t, result.y[1, :], label="thetad0", linewidth=7.0)
-    ax.plot(result.t, result.y[2, :], label="omegas0", linewidth=7.0)
-    ax.plot(result.t, result.y[3, :], label="omegad0", linewidth=7.0)
-    ax.plot(result.t, result.y[4, :], label="thetas1", linewidth=7.0)
-    ax.plot(result.t, result.y[5, :], label="thetad1", linewidth=7.0)
-    ax.plot(result.t, result.y[6, :], label="omegas1", linewidth=7.0)
-    ax.plot(result.t, result.y[7, :], label="omegad1", linewidth=7.0)
-    ax.plot(result.t, result.y[8, :], label="thetas2", linewidth=7.0)
-    ax.plot(result.t, result.y[9, :], label="thetad2", linewidth=7.0)
-    ax.plot(result.t, result.y[10, :], label="omegas2", linewidth=7.0)
-    ax.plot(result.t, result.y[11, :], label="omegad2", linewidth=7.0)
-    ax.plot(result.t, result.y[12, :], label="thetas3", linewidth=7.0)
-    ax.plot(result.t, result.y[13, :], label="thetad3", linewidth=7.0)
-    ax.plot(result.t, result.y[14, :], label="omegas3", linewidth=7.0)
-    ax.plot(result.t, result.y[15, :], label="omegad3", linewidth=7.0)
-    ax.plot(result.t, result.y[16, :], label="x", linewidth=7.0)
-    ax.plot(result.t, result.y[17, :], label="y", linewidth=7.0)
-    ax.plot(result.t, result.y[18, :], label="theta", linewidth=7.0)
-    ax.plot(result.t, result.y[19, :], label="vx", linewidth=7.0)
-    ax.plot(result.t, result.y[20, :], label="vy", linewidth=7.0)
-    ax.plot(result.t, result.y[21, :], label="omega", linewidth=7.0)
-    ax.plot(result.t, result.y[22, :], label="Fx", linewidth=7.0)
-    ax.plot(result.t, result.y[23, :], label="Fy", linewidth=7.0)
-    ax.plot(result.t, result.y[24, :], label="Moment", linewidth=7.0)
-    numpy.set_printoptions(threshold=numpy.inf)
-    print(result.t)
-    print(result.y)
-    pylab.legend()
-    pylab.show()
-
-    return 0
-
-
-if __name__ == '__main__':
-    argv = FLAGS(sys.argv)
-    glog.init()
-    sys.exit(main(argv))
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index 1326fd7..2e3fe73 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -30,8 +30,6 @@
           "Path to write generated cc code to");
 ABSL_FLAG(std::string, h_output_path, "",
           "Path to write generated header code to");
-ABSL_FLAG(std::string, py_output_path, "",
-          "Path to write generated py code to");
 ABSL_FLAG(std::string, casadi_py_output_path, "",
           "Path to write casadi generated py code to");
 ABSL_FLAG(double, caster, 0.01, "Caster in meters for the module.");
@@ -68,19 +66,16 @@
 
 // State per module.
 struct Module {
+  DenseMatrix mounting_location;
+
   RCP<const Symbol> Is;
 
   RCP<const Symbol> Id;
 
   RCP<const Symbol> thetas;
   RCP<const Symbol> omegas;
-  RCP<const Symbol> alphas;
-  RCP<const Basic> alphas_eqn;
 
-  RCP<const Symbol> thetad;
   RCP<const Symbol> omegad;
-  RCP<const Symbol> alphad;
-  RCP<const Basic> alphad_eqn;
 
   DenseMatrix contact_patch_velocity;
   DenseMatrix wheel_ground_velocity;
@@ -89,16 +84,39 @@
   RCP<const Basic> slip_ratio;
 
   RCP<const Basic> Ms;
-  RCP<const Basic> Fwx;
   RCP<const Basic> Fwy;
-  DenseMatrix F;
-  DenseMatrix mounting_location;
 
-  // Acceleration contribution from this module.
-  DenseMatrix accel;
-  RCP<const Basic> angular_accel;
+  struct Full {
+    RCP<const Basic> Fwx;
+    DenseMatrix F;
+
+    RCP<const Basic> torque;
+
+    RCP<const Basic> alphas_eqn;
+    RCP<const Basic> alphad_eqn;
+  } full;
+
+  struct Direct {
+    RCP<const Basic> Fwx;
+    DenseMatrix F;
+
+    RCP<const Basic> torque;
+
+    RCP<const Basic> alphas_eqn;
+  } direct;
 };
 
+// Returns the elementwise sum of one or more matrices.  All operands must
+// share a shape; the result is sized from the first operand.
+DenseMatrix SumMatrices(const DenseMatrix &a) { return a; }
+
+template <typename... Args>
+DenseMatrix SumMatrices(const DenseMatrix &a, const Args &...args) {
+  DenseMatrix result(a.nrows(), a.ncols());
+  add_dense_dense(a, SumMatrices(args...), result);
+  return result;
+}
+
 class SwerveSimulation {
  public:
   SwerveSimulation() : drive_motor_(KrakenFOC()), steer_motor_(KrakenFOC()) {
@@ -193,6 +211,7 @@
 
     // Now, compute the accelerations due to the disturbance forces.
     DenseMatrix external_accel = DenseMatrix(2, 1, {div(fx, m_), div(fy, m_)});
+    DenseMatrix external_force = DenseMatrix(2, 1, {fx, fy});
 
     // And compute the physics contributions from each module.
     modules_[0] = ModulePhysics(
@@ -213,126 +232,34 @@
                                       div(robot_width_, integer(-2))}));
 
     // And convert them into the overall robot contribution.
-    DenseMatrix temp0 = DenseMatrix(2, 1);
-    DenseMatrix temp1 = DenseMatrix(2, 1);
-    DenseMatrix temp2 = DenseMatrix(2, 1);
-    accel_ = DenseMatrix(2, 1);
+    DenseMatrix net_full_force =
+        SumMatrices(modules_[0].full.F, modules_[1].full.F, modules_[2].full.F,
+                    modules_[3].full.F, external_force);
 
-    add_dense_dense(modules_[0].accel, external_accel, temp0);
-    add_dense_dense(temp0, modules_[1].accel, temp1);
-    add_dense_dense(temp1, modules_[2].accel, temp2);
-    add_dense_dense(temp2, modules_[3].accel, accel_);
+    DenseMatrix net_direct_force =
+        SumMatrices(modules_[0].direct.F, modules_[1].direct.F,
+                    modules_[2].direct.F, modules_[3].direct.F, external_force);
 
-    angular_accel_ =
-        add(div(moment, J_),
-            add(add(modules_[0].angular_accel, modules_[1].angular_accel),
-                add(modules_[2].angular_accel, modules_[3].angular_accel)));
+    full_accel_ = DenseMatrix(2, 1);
+    mul_dense_scalar(net_full_force, pow(m_, minus_one), full_accel_);
 
-    VLOG(1) << "accel(0, 0) = " << ccode(*accel_.get(0, 0));
-    VLOG(1) << "accel(1, 0) = " << ccode(*accel_.get(1, 0));
-    VLOG(1) << "angular_accel = " << ccode(*angular_accel_);
-  }
+    full_angular_accel_ = div(
+        add(moment, add(add(modules_[0].full.torque, modules_[1].full.torque),
+                        add(modules_[2].full.torque, modules_[3].full.torque))),
+        J_);
 
-  // Writes the physics out to the provided .py path.
-  void WritePy(std::string_view py_path) {
-    std::vector<std::string> result_py;
+    direct_accel_ = DenseMatrix(2, 1);
+    mul_dense_scalar(net_direct_force, pow(m_, minus_one), direct_accel_);
 
-    result_py.emplace_back("#!/usr/bin/python3");
-    result_py.emplace_back("");
-    result_py.emplace_back("import numpy");
-    result_py.emplace_back("");
+    direct_angular_accel_ =
+        div(add(moment,
+                add(add(modules_[0].direct.torque, modules_[1].direct.torque),
+                    add(modules_[2].direct.torque, modules_[3].direct.torque))),
+            J_);
 
-    result_py.emplace_back("def swerve_physics(t, X, U_func):");
-    result_py.emplace_back("    def atan2(y, x):");
-    result_py.emplace_back("        if x < 0:");
-    result_py.emplace_back("            return -numpy.atan2(y, x)");
-    result_py.emplace_back("        else:");
-    result_py.emplace_back("            return numpy.atan2(y, x)");
-    result_py.emplace_back("    sin = numpy.sin");
-    result_py.emplace_back("    cos = numpy.cos");
-    result_py.emplace_back("    fabs = numpy.fabs");
-
-    result_py.emplace_back("    result = numpy.empty([25, 1])");
-    result_py.emplace_back("    X = X.reshape(25, 1)");
-    result_py.emplace_back("    U = U_func(X)");
-    result_py.emplace_back("");
-
-    // Start by writing out variables matching each of the symbol names we use
-    // so we don't have to modify the computed equations too much.
-    for (size_t m = 0; m < kNumModules; ++m) {
-      result_py.emplace_back(
-          absl::Substitute("    thetas$0 = X[$1, 0]", m, m * 4));
-      result_py.emplace_back(
-          absl::Substitute("    omegas$0 = X[$1, 0]", m, m * 4 + 2));
-      result_py.emplace_back(
-          absl::Substitute("    omegad$0 = X[$1, 0]", m, m * 4 + 3));
-    }
-
-    result_py.emplace_back(
-        absl::Substitute("    theta = X[$0, 0]", kNumModules * 4 + 2));
-    result_py.emplace_back(
-        absl::Substitute("    vx = X[$0, 0]", kNumModules * 4 + 3));
-    result_py.emplace_back(
-        absl::Substitute("    vy = X[$0, 0]", kNumModules * 4 + 4));
-    result_py.emplace_back(
-        absl::Substitute("    omega = X[$0, 0]", kNumModules * 4 + 5));
-
-    result_py.emplace_back(
-        absl::Substitute("    fx = X[$0, 0]", kNumModules * 4 + 6));
-    result_py.emplace_back(
-        absl::Substitute("    fy = X[$0, 0]", kNumModules * 4 + 7));
-    result_py.emplace_back(
-        absl::Substitute("    moment = X[$0, 0]", kNumModules * 4 + 8));
-
-    // Now do the same for the inputs.
-    for (size_t m = 0; m < kNumModules; ++m) {
-      result_py.emplace_back(absl::Substitute("    Is$0 = U[$1, 0]", m, m * 2));
-      result_py.emplace_back(
-          absl::Substitute("    Id$0 = U[$1, 0]", m, m * 2 + 1));
-    }
-
-    result_py.emplace_back("");
-
-    // And then write out the derivative of each state.
-    for (size_t m = 0; m < kNumModules; ++m) {
-      result_py.emplace_back(
-          absl::Substitute("    result[$0, 0] = omegas$1", m * 4, m));
-      result_py.emplace_back(
-          absl::Substitute("    result[$0, 0] = omegad$1", m * 4 + 1, m));
-
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
-    }
-
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = vx", kNumModules * 4));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = vy", kNumModules * 4 + 1));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = omega", kNumModules * 4 + 2));
-
-    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
-                                            kNumModules * 4 + 3,
-                                            ccode(*accel_.get(0, 0))));
-    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
-                                            kNumModules * 4 + 4,
-                                            ccode(*accel_.get(1, 0))));
-    result_py.emplace_back(absl::Substitute(
-        "    result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
-
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 6));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 7));
-    result_py.emplace_back(
-        absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 8));
-
-    result_py.emplace_back("");
-    result_py.emplace_back("    return result.reshape(25,)\n");
-
-    aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
+    VLOG(1) << "accel(0, 0) = " << ccode(*full_accel_.get(0, 0));
+    VLOG(1) << "accel(1, 0) = " << ccode(*full_accel_.get(1, 0));
+    VLOG(1) << "angular_accel = " << ccode(*full_angular_accel_);
   }
 
   // Writes the physics out to the provided .cc and .h path.
@@ -439,10 +366,12 @@
       result_cc.emplace_back(
           absl::Substitute("  result($0, 0) = omegad$1;", m * 4 + 1, m));
 
-      result_cc.emplace_back(absl::Substitute(
-          "  result($0, 0) = $1;", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
-      result_cc.emplace_back(absl::Substitute(
-          "  result($0, 0) = $1;", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
+      result_cc.emplace_back(
+          absl::Substitute("  result($0, 0) = $1;", m * 4 + 2,
+                           ccode(*modules_[m].full.alphas_eqn)));
+      result_cc.emplace_back(
+          absl::Substitute("  result($0, 0) = $1;", m * 4 + 3,
+                           ccode(*modules_[m].full.alphad_eqn)));
     }
 
     result_cc.emplace_back(
@@ -454,12 +383,13 @@
 
     result_cc.emplace_back(absl::Substitute("  result($0, 0) = $1;",
                                             kNumModules * 4 + 3,
-                                            ccode(*accel_.get(0, 0))));
+                                            ccode(*full_accel_.get(0, 0))));
     result_cc.emplace_back(absl::Substitute("  result($0, 0) = $1;",
                                             kNumModules * 4 + 4,
-                                            ccode(*accel_.get(1, 0))));
-    result_cc.emplace_back(absl::Substitute(
-        "  result($0, 0) = $1;", kNumModules * 4 + 5, ccode(*angular_accel_)));
+                                            ccode(*full_accel_.get(1, 0))));
+    result_cc.emplace_back(absl::Substitute("  result($0, 0) = $1;",
+                                            kNumModules * 4 + 5,
+                                            ccode(*full_angular_accel_)));
 
     result_cc.emplace_back(
         absl::Substitute("  result($0, 0) = 0.0;", kNumModules * 4 + 6));
@@ -522,6 +452,52 @@
     }
   }
 
+  void WriteCasadiVelocityVariables(std::vector<std::string> *result_py) {
+    result_py->emplace_back("    sin = casadi.sin");
+    result_py->emplace_back("    sign = casadi.sign");
+    result_py->emplace_back("    cos = casadi.cos");
+    result_py->emplace_back("    atan2 = casadi.atan2");
+    result_py->emplace_back("    fmax = casadi.fmax");
+    result_py->emplace_back("    fabs = casadi.fabs");
+
+    // Start by writing out variables matching each of the symbol names we use
+    // so we don't have to modify the computed equations too much.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py->emplace_back(
+          absl::Substitute("    thetas$0 = X[$1, 0]", m, m * 2 + 0));
+      result_py->emplace_back(
+          absl::Substitute("    omegas$0 = X[$1, 0]", m, m * 2 + 1));
+    }
+
+    result_py->emplace_back(
+        absl::Substitute("    theta = X[$0, 0]", kNumModules * 2 + 0));
+    result_py->emplace_back(
+        absl::Substitute("    vx = X[$0, 0]", kNumModules * 2 + 1));
+    result_py->emplace_back(
+        absl::Substitute("    vy = X[$0, 0]", kNumModules * 2 + 2));
+    result_py->emplace_back(
+        absl::Substitute("    omega = X[$0, 0]", kNumModules * 2 + 3));
+
+    // NOTE(review): the velocity state currently ends at omega
+    // (X[kNumModules * 2 + 3]) and has no disturbance states, so fx, fy and
+    // moment are zeroed below.  The disabled reads that used to be here look
+    // like they assume 3 states per module and would need new indices:
+    //   fx = X[kNumModules * 3 + 4]
+    //   fy = X[kNumModules * 3 + 5]
+    //   moment = X[kNumModules * 3 + 6]
+    result_py->emplace_back("    fx = 0");
+    result_py->emplace_back("    fy = 0");
+    result_py->emplace_back("    moment = 0");
+
+    // Now do the same for the inputs.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py->emplace_back(
+          absl::Substitute("    Is$0 = U[$1, 0]", m, m * 2));
+      result_py->emplace_back(
+          absl::Substitute("    Id$0 = U[$1, 0]", m, m * 2 + 1));
+    }
+  }
+
   // Writes the physics out to the provided .cc and .h path.
   void WriteCasadi(std::string_view py_path) {
     std::vector<std::string> result_py;
@@ -529,12 +505,74 @@
     // Write out the header.
     result_py.emplace_back("#!/usr/bin/python3");
     result_py.emplace_back("");
-    result_py.emplace_back("import casadi");
+    result_py.emplace_back("import casadi, numpy");
     result_py.emplace_back("");
     result_py.emplace_back(absl::Substitute("WHEEL_RADIUS = $0", ccode(*rw_)));
     result_py.emplace_back(
         absl::Substitute("ROBOT_WIDTH = $0", ccode(*robot_width_)));
     result_py.emplace_back(absl::Substitute("CASTER = $0", ccode(*caster_)));
+    result_py.emplace_back("STATE_THETAS0 = 0");
+    result_py.emplace_back("STATE_THETAD0 = 1");
+    result_py.emplace_back("STATE_OMEGAS0 = 2");
+    result_py.emplace_back("STATE_OMEGAD0 = 3");
+    result_py.emplace_back("STATE_THETAS1 = 4");
+    result_py.emplace_back("STATE_THETAD1 = 5");
+    result_py.emplace_back("STATE_OMEGAS1 = 6");
+    result_py.emplace_back("STATE_OMEGAD1 = 7");
+    result_py.emplace_back("STATE_THETAS2 = 8");
+    result_py.emplace_back("STATE_THETAD2 = 9");
+    result_py.emplace_back("STATE_OMEGAS2 = 10");
+    result_py.emplace_back("STATE_OMEGAD2 = 11");
+    result_py.emplace_back("STATE_THETAS3 = 12");
+    result_py.emplace_back("STATE_THETAD3 = 13");
+    result_py.emplace_back("STATE_OMEGAS3 = 14");
+    result_py.emplace_back("STATE_OMEGAD3 = 15");
+    result_py.emplace_back("STATE_X = 16");
+    result_py.emplace_back("STATE_Y = 17");
+    result_py.emplace_back("STATE_THETA = 18");
+    result_py.emplace_back("STATE_VX = 19");
+    result_py.emplace_back("STATE_VY = 20");
+    result_py.emplace_back("STATE_OMEGA = 21");
+    result_py.emplace_back("STATE_FX = 22");
+    result_py.emplace_back("STATE_FY = 23");
+    result_py.emplace_back("STATE_MOMENT = 24");
+    result_py.emplace_back("NUM_STATES = 25");
+    result_py.emplace_back("");
+    result_py.emplace_back("VELOCITY_STATE_THETAS0 = 0");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS0 = 1");
+    result_py.emplace_back("VELOCITY_STATE_THETAS1 = 2");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS1 = 3");
+    result_py.emplace_back("VELOCITY_STATE_THETAS2 = 4");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS2 = 5");
+    result_py.emplace_back("VELOCITY_STATE_THETAS3 = 6");
+    result_py.emplace_back("VELOCITY_STATE_OMEGAS3 = 7");
+    result_py.emplace_back("VELOCITY_STATE_THETA = 8");
+    result_py.emplace_back("VELOCITY_STATE_VX = 9");
+    result_py.emplace_back("VELOCITY_STATE_VY = 10");
+    result_py.emplace_back("VELOCITY_STATE_OMEGA = 11");
+    // result_py.emplace_back("VELOCITY_STATE_FX = 16");
+    // result_py.emplace_back("VELOCITY_STATE_FY = 17");
+    // result_py.emplace_back("VELOCITY_STATE_MOMENT = 18");
+    result_py.emplace_back("NUM_VELOCITY_STATES = 12");
+    result_py.emplace_back("");
+    result_py.emplace_back("def to_velocity_state(X):");
+    result_py.emplace_back("    return numpy.array([");
+    result_py.emplace_back("        [X[STATE_THETAS0, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS0, 0]],");
+    result_py.emplace_back("        [X[STATE_THETAS1, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS1, 0]],");
+    result_py.emplace_back("        [X[STATE_THETAS2, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS2, 0]],");
+    result_py.emplace_back("        [X[STATE_THETAS3, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGAS3, 0]],");
+    result_py.emplace_back("        [X[STATE_THETA, 0]],");
+    result_py.emplace_back("        [X[STATE_VX, 0]],");
+    result_py.emplace_back("        [X[STATE_VY, 0]],");
+    result_py.emplace_back("        [X[STATE_OMEGA, 0]],");
+    // result_py.emplace_back("        [X[STATE_FX, 0]],");
+    // result_py.emplace_back("        [X[STATE_FY, 0]],");
+    // result_py.emplace_back("        [X[STATE_MOMENT, 0]],");
+    result_py.emplace_back("    ])");
     result_py.emplace_back("");
 
     result_py.emplace_back("# Returns the derivative of our state vector");
@@ -544,7 +582,7 @@
     result_py.emplace_back("#  thetas3, thetad3, omegas3, omegad3,");
     result_py.emplace_back("#  x, y, theta, vx, vy, omega,");
     result_py.emplace_back("#  Fx, Fy, Moment]");
-    result_py.emplace_back("def swerve_physics(X, U):");
+    result_py.emplace_back("def swerve_full_dynamics(X, U):");
     WriteCasadiVariables(&result_py);
 
     result_py.emplace_back("");
@@ -558,10 +596,12 @@
       result_py.emplace_back(
           absl::Substitute("    result[$0, 0] = omegad$1", m * 4 + 1, m));
 
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
-      result_py.emplace_back(absl::Substitute(
-          "    result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = $1", m * 4 + 2,
+                           ccode(*modules_[m].full.alphas_eqn)));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = $1", m * 4 + 3,
+                           ccode(*modules_[m].full.alphad_eqn)));
     }
 
     result_py.emplace_back(
@@ -573,12 +613,13 @@
 
     result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
                                             kNumModules * 4 + 3,
-                                            ccode(*accel_.get(0, 0))));
+                                            ccode(*full_accel_.get(0, 0))));
     result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
                                             kNumModules * 4 + 4,
-                                            ccode(*accel_.get(1, 0))));
-    result_py.emplace_back(absl::Substitute(
-        "    result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
+                                            ccode(*full_accel_.get(1, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 4 + 5,
+                                            ccode(*full_angular_accel_)));
 
     result_py.emplace_back(
         absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 4 + 6));
@@ -591,6 +632,54 @@
     result_py.emplace_back(
         "    return casadi.Function('xdot', [X, U], [result])");
 
+    result_py.emplace_back("");
+
+    result_py.emplace_back("# Returns the derivative of our state vector");
+    result_py.emplace_back("# [thetas0, omegas0,");
+    result_py.emplace_back("#  thetas1, omegas1,");
+    result_py.emplace_back("#  thetas2, omegas2,");
+    result_py.emplace_back("#  thetas3, omegas3,");
+    result_py.emplace_back("#  theta, vx, vy, omega]");
+    result_py.emplace_back("def velocity_swerve_physics(X, U):");
+    WriteCasadiVelocityVariables(&result_py);
+
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        "    result = casadi.SX.sym('result', NUM_VELOCITY_STATES, 1)");
+    result_py.emplace_back("");
+
+    // And then write out the derivative of each state.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = omegas$1", m * 2 + 0, m));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = $1", m * 2 + 1,
+                           ccode(*modules_[m].direct.alphas_eqn)));
+    }
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = omega", kNumModules * 2 + 0));
+
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 2 + 1,
+                                            ccode(*direct_accel_.get(0, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 2 + 2,
+                                            ccode(*direct_accel_.get(1, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 2 + 3,
+                                            ccode(*direct_angular_accel_)));
+
+    // result_py.emplace_back(
+    // absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 3 + 4));
+    // result_py.emplace_back(
+    // absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 3 + 5));
+    // result_py.emplace_back(
+    // absl::Substitute("    result[$0, 0] = 0.0", kNumModules * 3 + 6));
+
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        "    return casadi.Function('xdot', [X, U], [result])");
+
     DefineVector2dFunction(
         "contact_patch_velocity",
         "# Returns the velocity of the wheel in global coordinates.",
@@ -624,20 +713,24 @@
     DefineScalarFunction(
         "module_angular_accel",
         "Returns the angular acceleration of the robot due to the ith wheel",
-        [](const Module &m) { return ccode(*m.angular_accel); }, &result_py);
+        // Robot angular accel from this module: torque / robot inertia J_.
+        // Js_ is the steering inertia term used by SteerAccel, not the robot's.
+        [this](const Module &m) { return ccode(*div(m.full.torque, J_)); },
+        &result_py);
 
     DefineVector2dFunction(
         "wheel_force",
         "Returns the force on the wheel in steer module coordinates",
         [](const Module &m, int dimension) {
-          return ccode(*std::vector<RCP<const Basic>>{m.Fwx, m.Fwy}[dimension]);
+          return ccode(
+              *std::vector<RCP<const Basic>>{m.full.Fwx, m.Fwy}[dimension]);
         },
         &result_py);
 
     DefineVector2dFunction(
         "F", "Returns the force on the wheel in absolute coordinates",
         [](const Module &m, int dimension) {
-          return ccode(*m.F.get(dimension, 0));
+          return ccode(*m.full.F.get(dimension, 0));
         },
         &result_py);
 
@@ -711,6 +802,19 @@
  private:
   static constexpr uint8_t kNumModules = 4;
 
+  // Solves the steering equation in swerve_notes.tex for the steer angular
+  // acceleration: (Ms + Kts * Is / Gs + lever * (rw / rb2) * -Fwx) divided by
+  // the effective steer inertia (Jsm_ + Js_ / Gs^2).
+  RCP<const Basic> SteerAccel(RCP<const Basic> Fwx, RCP<const Basic> Ms,
+                              RCP<const Basic> Is) {
+    RCP<const Basic> lhms =
+        mul(add(neg(wb_), mul(add(rs_, rp_), sub(integer(1), div(rb1_, rp_)))),
+            mul(div(rw_, rb2_), neg(Fwx)));
+    RCP<const Basic> lhs = add(add(Ms, div(mul(Kts_, Is), Gs_)), lhms);
+    RCP<const Basic> rhs = add(Jsm_, div(div(Js_, Gs_), Gs_));
+    return simplify(div(lhs, rhs));
+  }
+
   Module ModulePhysics(const int m, DenseMatrix mounting_location) {
     VLOG(1) << "Solving module " << m;
 
@@ -726,11 +827,10 @@
 
     result.thetas = symbol(absl::StrFormat("thetas%u", m));
     result.omegas = symbol(absl::StrFormat("omegas%u", m));
-    result.alphas = symbol(absl::StrFormat("alphas%u", m));
+    RCP<const Symbol> alphas = symbol(absl::StrFormat("alphas%u", m));
 
-    result.thetad = symbol(absl::StrFormat("thetad%u", m));
     result.omegad = symbol(absl::StrFormat("omegad%u", m));
-    result.alphad = symbol(absl::StrFormat("alphad%u", m));
+    RCP<const Symbol> alphad = symbol(absl::StrFormat("alphad%u", m));
 
     // Velocity of the module in field coordinates
     DenseMatrix robot_velocity = DenseMatrix(2, 1, {vx_, vy_});
@@ -790,7 +890,7 @@
     VLOG(1);
     VLOG(1) << "Slip ratio " << result.slip_ratio->__str__();
 
-    result.Fwx = simplify(mul(Cx_, result.slip_ratio));
+    result.full.Fwx = simplify(mul(Cx_, result.slip_ratio));
     result.Fwy = simplify(mul(Cy_, result.slip_angle));
 
     // The self-aligning moment needs to flip when the module flips direction.
@@ -802,35 +902,34 @@
     VLOG(1);
     VLOG(1) << "Ms " << result.Ms->__str__();
     VLOG(1);
-    VLOG(1) << "Fwx " << result.Fwx->__str__();
+    VLOG(1) << "full.Fwx " << result.full.Fwx->__str__();
     VLOG(1);
     VLOG(1) << "Fwy " << result.Fwy->__str__();
 
-    // alphas = ...
-    RCP<const Basic> lhms =
-        mul(add(neg(wb_), mul(add(rs_, rp_), sub(integer(1), div(rb1_, rp_)))),
-            mul(div(rw_, rb2_), neg(result.Fwx)));
-    RCP<const Basic> lhs =
-        add(add(result.Ms, div(mul(Kts_, result.Is), Gs_)), lhms);
-    RCP<const Basic> rhs = add(Jsm_, div(div(Js_, Gs_), Gs_));
-    RCP<const Basic> accel_steer_eqn = simplify(div(lhs, rhs));
+    // -K_td * Id / Gd + Fwx * rw = 0
+    // Fwx = K_td * Id / Gd / rw
+    result.direct.Fwx = mul(Ktd_, div(result.Id, mul(Gd_, rw_)));
+
+    result.direct.alphas_eqn =
+        SteerAccel(result.direct.Fwx, result.Ms, result.Is);
+
+    // d/dt omegas = ...
+    result.full.alphas_eqn = SteerAccel(result.full.Fwx, result.Ms, result.Is);
 
     VLOG(1);
-    VLOG(1) << result.alphas->__str__() << " = " << accel_steer_eqn->__str__();
+    VLOG(1) << alphas->__str__() << " = " << result.full.alphas_eqn->__str__();
 
-    lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.omegas),
-              mul(Gd1_, mul(Gd2_, omegamd)));
-    RCP<const Basic> dplanitary_eqn = sub(mul(Gd3_, lhs), result.omegad);
+    RCP<const Basic> lhs =
+        sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), alphas),
+            mul(Gd1_, mul(Gd2_, alphamd)));
+    RCP<const Basic> ddplanitary_eqn = sub(mul(Gd3_, lhs), alphad);
 
-    lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.alphas),
-              mul(Gd1_, mul(Gd2_, alphamd)));
-    RCP<const Basic> ddplanitary_eqn = sub(mul(Gd3_, lhs), result.alphad);
+    RCP<const Basic> full_drive_eqn =
+        sub(add(mul(neg(Jdm_), div(alphamd, Gd_)),
+                mul(Ktd_, div(neg(result.Id), Gd_))),
+            mul(neg(result.full.Fwx), rw_));
 
-    RCP<const Basic> drive_eqn = sub(add(mul(neg(Jdm_), div(alphamd, Gd_)),
-                                         mul(Ktd_, div(neg(result.Id), Gd_))),
-                                     mul(neg(result.Fwx), rw_));
-
-    VLOG(1) << "drive_eqn: " << drive_eqn->__str__();
+    VLOG(1) << "full_drive_eqn: " << full_drive_eqn->__str__();
 
     // Substitute in ddplanitary_eqn so we get rid of alphamd
     map_basic_basic map;
@@ -838,39 +937,39 @@
     RCP<const Set> solve_solution = solve(ddplanitary_eqn, alphamd, reals);
     map[alphamd] = solve_solution->get_args()[1]->get_args()[0];
     VLOG(1) << "temp: " << solve_solution->__str__();
-    RCP<const Basic> drive_eqn_subs = drive_eqn->subs(map);
+    RCP<const Basic> drive_eqn_subs = full_drive_eqn->subs(map);
 
     map.clear();
-    map[result.alphas] = accel_steer_eqn;
+    map[alphas] = result.full.alphas_eqn;
     RCP<const Basic> drive_eqn_subs2 = drive_eqn_subs->subs(map);
     RCP<const Basic> drive_eqn_subs3 = simplify(drive_eqn_subs2);
-    VLOG(1) << "drive_eqn simplified: " << drive_eqn_subs3->__str__();
+    VLOG(1) << "full_drive_eqn simplified: " << drive_eqn_subs3->__str__();
 
-    solve_solution = solve(drive_eqn_subs3, result.alphad, reals);
+    solve_solution = solve(drive_eqn_subs3, alphad, reals);
 
-    RCP<const Basic> drive_accel =
+    result.full.alphad_eqn =
         simplify(solve_solution->get_args()[1]->get_args()[0]);
-    VLOG(1) << "drive_accel: " << drive_accel->__str__();
+    VLOG(1) << "drive_accel: " << result.full.alphad_eqn->__str__();
 
     // Compute the resulting force from the module.
-    result.F = DenseMatrix(2, 1);
+    result.full.F = DenseMatrix(2, 1);
     mul_dense_dense(R(add(theta_, result.thetas)),
-                    DenseMatrix(2, 1, {result.Fwx, result.Fwy}), result.F);
+                    DenseMatrix(2, 1, {result.full.Fwx, result.Fwy}),
+                    result.full.F);
+    result.full.torque = force_cross(result.mounting_location, result.full.F);
 
-    RCP<const Basic> torque = force_cross(result.mounting_location, result.F);
-    result.accel = DenseMatrix(2, 1);
-    mul_dense_scalar(result.F, pow(m_, minus_one), result.accel);
-    result.angular_accel = div(torque, J_);
-    VLOG(1);
-    VLOG(1) << "angular_accel = " << result.angular_accel->__str__();
+    // Repeat with direct.Fwx, which is solved directly from the drive current.
+    result.direct.F = DenseMatrix(2, 1);
+    mul_dense_dense(R(add(theta_, result.thetas)),
+                    DenseMatrix(2, 1, {result.direct.Fwx, result.Fwy}),
+                    result.direct.F);
+    result.direct.torque =
+        force_cross(result.mounting_location, result.direct.F);
 
     VLOG(1);
-    VLOG(1) << "accel(0, 0) = " << result.accel.get(0, 0)->__str__();
-    VLOG(1);
-    VLOG(1) << "accel(1, 0) = " << result.accel.get(1, 0)->__str__();
+    VLOG(1) << "full torque = " << result.full.torque->__str__();
+    VLOG(1) << "direct torque = " << result.direct.torque->__str__();
 
-    result.alphad_eqn = drive_accel;
-    result.alphas_eqn = accel_steer_eqn;
     return result;
   }
 
@@ -938,8 +1036,10 @@
 
   std::array<Module, kNumModules> modules_;
 
-  DenseMatrix accel_;
-  RCP<const Basic> angular_accel_;
+  DenseMatrix full_accel_;
+  RCP<const Basic> full_angular_accel_;
+  DenseMatrix direct_accel_;
+  RCP<const Basic> direct_angular_accel_;
 };
 
 }  // namespace frc971::control_loops::swerve
@@ -954,9 +1054,6 @@
     sim.Write(absl::GetFlag(FLAGS_cc_output_path),
               absl::GetFlag(FLAGS_h_output_path));
   }
-  if (!absl::GetFlag(FLAGS_py_output_path).empty()) {
-    sim.WritePy(absl::GetFlag(FLAGS_py_output_path));
-  }
   if (!absl::GetFlag(FLAGS_casadi_py_output_path).empty()) {
     sim.WriteCasadi(absl::GetFlag(FLAGS_casadi_py_output_path));
   }
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index c7d60e9..664d0aa 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -15,8 +15,42 @@
 class TestSwervePhysics(unittest.TestCase):
     I = numpy.zeros((8, 1))
 
+    def to_velocity_state(self, X):
+        """Returns the velocity-model subset of the full state X."""
+        return dynamics.to_velocity_state(X)
+
+    def swerve_full_dynamics(self, X, U, skip_compare=False):
+        """Returns the full-dynamics Xdot, cross-checking velocity dynamics.
+
+        Unless skip_compare is set, the velocity-state components of Xdot are
+        compared against velocity_swerve_physics evaluated on the projected
+        state, and the test fails unless the 2-norm of their difference is
+        below 2e-2.
+        """
+        X_velocity = self.to_velocity_state(X)
+        Xdot = self.position_swerve_full_dynamics(X, U)
+        if not skip_compare:
+            velocity_states = self.to_velocity_state(Xdot)
+            velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+            self.assertLess(
+                numpy.linalg.norm(velocity_states - velocity_physics),
+                2e-2,
+                msg=
+                f'Norm failed, full physics -> {velocity_states.T}, velocity physics -> {velocity_physics.T}, difference -> {(velocity_physics - velocity_states).T}',
+            )
+
+        return Xdot
+
     def wrap(self, python_module):
-        self.swerve_physics = utils.wrap(python_module.swerve_physics)
+        self.position_swerve_full_dynamics = utils.wrap(
+            python_module.swerve_full_dynamics)
+
+        evaluated_fn = python_module.velocity_swerve_physics(
+            casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
+            casadi.SX.sym("U", 8, 1))
+        self.velocity_swerve_physics = lambda X, U: numpy.array(
+            evaluated_fn(X, U))
+
         self.contact_patch_velocity = [
             utils.wrap_module(python_module.contact_patch_velocity, i)
             for i in range(4)
@@ -172,7 +198,7 @@
             velocity=numpy.array([[1.0], [0.0]]),
             module_angles=[-0.001, -0.001, 0.001, 0.001],
         )
-        xdot_equal = self.swerve_physics(X, self.I)
+        xdot_equal = self.swerve_full_dynamics(X, self.I)
 
         self.assertGreater(xdot_equal[2, 0], 0.0)
         self.assertAlmostEqual(xdot_equal[3, 0], 0.0, places=1)
@@ -199,7 +225,7 @@
             velocity=numpy.array([[1.0], [0.0]]),
             module_angles=[0.01, 0.01, 0.01, 0.01],
         )
-        xdot_left = self.swerve_physics(X, self.I)
+        xdot_left = self.swerve_full_dynamics(X, self.I)
 
         self.assertLess(xdot_left[2, 0], -0.05)
         self.assertLess(xdot_left[3, 0], 0.0)
@@ -227,7 +253,7 @@
             velocity=numpy.array([[1.0], [0.0]]),
             module_angles=[-0.01, -0.01, -0.01, -0.01],
         )
-        xdot_right = self.swerve_physics(X, self.I)
+        xdot_right = self.swerve_full_dynamics(X, self.I)
 
         self.assertGreater(xdot_right[2, 0], 0.05)
         self.assertLess(xdot_right[3, 0], 0.0)
@@ -264,7 +290,7 @@
             ],
             drive_wheel_velocity=-1,
         )
-        xdot_equal = self.swerve_physics(X, self.I)
+        xdot_equal = self.swerve_full_dynamics(X, self.I)
 
         self.assertGreater(xdot_equal[2, 0], 0.0, msg="Steering backwards")
         self.assertAlmostEqual(xdot_equal[3, 0], 0.0, places=1)
@@ -292,7 +318,7 @@
             module_angles=[numpy.pi + 0.01] * 4,
             drive_wheel_velocity=-1,
         )
-        xdot_left = self.swerve_physics(X, self.I)
+        xdot_left = self.swerve_full_dynamics(X, self.I)
 
         self.assertLess(xdot_left[2, 0], -0.05)
         self.assertGreater(xdot_left[3, 0], 0.0)
@@ -321,7 +347,7 @@
             drive_wheel_velocity=-1,
             module_angles=[-0.01 + numpy.pi] * 4,
         )
-        xdot_right = self.swerve_physics(X, self.I)
+        xdot_right = self.swerve_full_dynamics(X, self.I)
 
         self.assertGreater(xdot_right[2, 0], 0.05)
         self.assertGreater(xdot_right[3, 0], 0.0)
@@ -358,7 +384,7 @@
             ],
             drive_wheel_velocity=-1,
         )
-        xdot_equal = self.swerve_physics(X, self.I)
+        xdot_equal = self.swerve_full_dynamics(X, self.I)
 
         self.assertLess(xdot_equal[2, 0], 0.0, msg="Steering backwards")
         self.assertAlmostEqual(xdot_equal[3, 0], 0.0, places=1)
@@ -386,7 +412,7 @@
             module_angles=[numpy.pi + 0.01] * 4,
             drive_wheel_velocity=-1,
         )
-        xdot_left = self.swerve_physics(X, self.I)
+        xdot_left = self.swerve_full_dynamics(X, self.I)
 
         self.assertGreater(xdot_left[2, 0], -0.05)
         self.assertGreater(xdot_left[3, 0], 0.0)
@@ -415,7 +441,7 @@
             drive_wheel_velocity=-1,
             module_angles=[-0.01 + numpy.pi] * 4,
         )
-        xdot_right = self.swerve_physics(X, self.I)
+        xdot_right = self.swerve_full_dynamics(X, self.I)
 
         self.assertLess(xdot_right[2, 0], 0.05)
         self.assertGreater(xdot_right[3, 0], 0.0)
@@ -445,7 +471,7 @@
 
             X = utils.state_vector()
             robot_equal = wheel_force(X, self.I)
-            xdot_equal = self.swerve_physics(X, self.I)
+            xdot_equal = self.swerve_full_dynamics(X, self.I)
             self.assertEqual(robot_equal[0, 0], 0.0)
             self.assertEqual(robot_equal[1, 0], 0.0)
             self.assertEqual(xdot_equal[2 + 4 * i], 0.0)
@@ -454,7 +480,9 @@
             # Robot is moving faster than the wheels, it should decelerate.
             X = utils.state_vector(dx=0.01)
             robot_faster = wheel_force(X, self.I)
-            xdot_faster = self.swerve_physics(X, self.I)
+            xdot_faster = self.swerve_full_dynamics(X,
+                                                    self.I,
+                                                    skip_compare=True)
             self.assertLess(robot_faster[0, 0], -0.1)
             self.assertEqual(robot_faster[1, 0], 0.0)
             self.assertGreater(xdot_faster[3 + 4 * i], 0.0)
@@ -462,7 +490,9 @@
             # Robot is now going slower than the wheels.  It should accelerate.
             X = utils.state_vector(dx=-0.01)
             robot_slower = wheel_force(X, self.I)
-            xdot_slower = self.swerve_physics(X, self.I)
+            xdot_slower = self.swerve_full_dynamics(X,
+                                                    self.I,
+                                                    skip_compare=True)
             self.assertGreater(robot_slower[0, 0], 0.1)
             self.assertEqual(robot_slower[1, 0], 0.0)
             self.assertLess(xdot_slower[3 + 4 * i], 0.0)
diff --git a/frc971/control_loops/swerve/swerve_notes.tex b/frc971/control_loops/swerve/swerve_notes.tex
index 99b5e56..bf83dcc 100644
--- a/frc971/control_loops/swerve/swerve_notes.tex
+++ b/frc971/control_loops/swerve/swerve_notes.tex
@@ -354,11 +354,11 @@
     \tau_md + \tau_d = -F_{wx}r_w\\
     \dot\theta_d = G_{d3}\left(\left(G_{carrier} - 1\right)\dot\theta_s - \dot\theta_{md}G_{motor-to-planet}\right) \label{the_one}
 \end{gather}
-The $-1$ in equation \eqref{the_one} comes from the additional contirbution of the wrapping around the gear with the change of coordinates.\\
-which we can expand into
+The $-1$ in equation \eqref{the_one} comes from the additional contribution of the wrapping around the gear with the change of coordinates.
+We can expand this into
 \begin{gather}
     M_s + \frac{K_tI_s}{G_s} + \left(-w_b+\left(r_s+r_p\right)\left(1-\frac{r_{b1}}{r_p}\right)\right)\frac{r_w}{r_{b2}}\left(-F_{wx}\right) = \left(J_s + \frac{J_{ms}}{G^2_s}\right)\ddot\theta_s\\
-    \left(\frac{-J_{md}\ddot\theta_{md}}{G_d}\right)+\frac{K_{t}I_d}{G_d} = -F_{wx}r_w\\
+    \frac{J_{md}\ddot\theta_{md}}{G_d}+\frac{K_{t}I_d}{G_d} = F_{wx}r_w \label{md_ddot}\\
     \dot\theta_d = G_{d3}\left(\left(\frac{r_p+r_s}{r_p}-1\right)\dot\theta_s - \dot\theta_{md}G_{d1}G_{d2}\right)
 \end{gather}
 where each G represents a separate gear ratio.
@@ -378,6 +378,20 @@
 \begin{gather}
     \ddot{\theta} = \frac{\Sigma\left(\harpoon{r} \times \harpoon{F}_{mod}\right) + \tau_{d}}{J_{robot}}
 \end{gather}
+\subsection{Simplified longitudinal dynamics}
+
+The time constants of the longitudinal (drive) dynamics are significantly faster than the time constants for accelerating the robot.
+This makes the equations stiff and hard to work with.
+Ignoring the mass of the wheel, the time constant for a Kraken is around $4\,\mathrm{ms}$.
+This stiffness also makes the problem hard for the MPC solver.  Therefore, it is useful to have a simplified version of the physics which solves for the wheel force directly as a function of drive current.
+
+Neglecting the drive motor inertia term $\frac{J_{md}\ddot\theta_{md}}{G_d}$, \eqref{md_ddot} then simplifies to
+\begin{gather}
+\frac{K_{t}I_d}{G_d} = F_{wx}r_w
+\end{gather}
+
+TODO(austin): Need to document how the position dynamics work for an EKF.
+
 \newpage
 \printbibliography
 \end{document}