Add JAX cost, and test to verify correctness

Turns out there's a bug in the casadi cost too, so fix that while we
are here.  We were using the unrotated radius (the module mounting
location) for computing the cost function torque instead of the
rotated radius.

Split the MPC code out into a separate file as well to make it easier to
pull into other things.

Change-Id: Iee6c9999fa8b6a91d6963af4edba5cf92a085e9b
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 513bf1b..2c3ad50 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -227,11 +227,15 @@
     main = "physics_test.py",
     target_compatible_with = ["@platforms//cpu:x86_64"],
     deps = [
+        ":casadi_velocity_mpc_lib",
         ":dynamics",
         ":jax_dynamics",
         ":physics_test_utils",
+        "@pip//absl_py",
         "@pip//casadi",
+        "@pip//matplotlib",
         "@pip//numpy",
+        "@pip//pygobject",
         "@pip//scipy",
     ],
 )
@@ -248,12 +252,28 @@
     main = "physics_test.py",
     target_compatible_with = ["@platforms//cpu:x86_64"],
     deps = [
+        ":casadi_velocity_mpc_lib",
         ":dynamics",
         ":jax_dynamics",
         ":physics_test_utils",
+        "@pip//absl_py",
+        "@pip//casadi",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//scipy",
+    ],
+)
+
+py_binary(
+    name = "casadi_velocity_mpc_lib",
+    srcs = [
+        "casadi_velocity_mpc_lib.py",
+    ],
+    deps = [
+        ":dynamics",
         "@pip//casadi",
         "@pip//numpy",
-        "@pip//scipy",
     ],
 )
 
@@ -263,9 +283,8 @@
         "casadi_velocity_mpc.py",
     ],
     deps = [
-        ":dynamics",
+        ":casadi_velocity_mpc_lib",
         "@pip//absl_py",
-        "@pip//casadi",
         "@pip//matplotlib",
         "@pip//numpy",
         "@pip//pygobject",
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc.py b/frc971/control_loops/swerve/casadi_velocity_mpc.py
index 517de5a..d62f7ed 100644
--- a/frc971/control_loops/swerve/casadi_velocity_mpc.py
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve.casadi_velocity_mpc_lib import MPC
 import pickle
 import matplotlib.pyplot as pyplot
 import matplotlib
@@ -23,296 +24,10 @@
 flags.DEFINE_bool('pickle', False, 'Write optimization results.')
 flags.DEFINE_string('outputdir', None, 'Directory to write problem results to')
 
-matplotlib.use("GTK3Agg")
-
 # Full print level on ipopt. Piping to a file and using a filter or search method is suggested
 # grad_x prints out the gradient at each iteration in the following sequence: U0, X1, U1, etc.
-full_debug = False
-
-
-class MPC(object):
-
-    def __init__(self, solver='fatrop'):
-        self.fdot = dynamics.swerve_full_dynamics(
-            casadi.SX.sym("X", dynamics.NUM_STATES, 1),
-            casadi.SX.sym("U", 8, 1))
-        self.velocity_fdot = dynamics.velocity_swerve_physics(
-            casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
-            casadi.SX.sym("U", 8, 1))
-
-        self.wrapped_swerve_physics = lambda X, U: numpy.array(self.fdot(X, U))
-
-        self.dt = 0.005
-
-        # TODO(austin): Do we need a disturbance torque per module to account for friction?
-        # Do it only in the observer/post?
-
-        self.force = [
-            dynamics.F(i, casadi.SX.sym("X", 25, 1), casadi.SX.sym("U", 8, 1))
-            for i in range(4)
-        ]
-        self.force_vel = [
-            dynamics.F_vel(i,
-                           casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
-                           casadi.SX.sym("U", 8, 1)) for i in range(4)
-        ]
-        self.slip_angle = [
-            dynamics.slip_angle(i, casadi.SX.sym("X", 25, 1),
-                                casadi.SX.sym("U", 8, 1)) for i in range(4)
-        ]
-        self.mounting_location = [
-            dynamics.mounting_location(
-                i, casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
-                casadi.SX.sym("U", 8, 1)) for i in range(4)
-        ]
-        self.torque_cross = self.torque_cross_func(casadi.SX.sym("r", 2, 1),
-                                                   casadi.SX.sym("F", 2, 1))
-        self.force_cross = self.force_cross_func(casadi.SX.sym("Tau", 1, 1),
-                                                 casadi.SX.sym("r", 2, 1))
-        self.next_X = self.make_physics()
-        self.cost = self.make_cost()
-
-        self.N = 200
-
-        # Start with an empty nonlinear program.
-        self.w = []
-        self.lbw = []
-        self.ubw = []
-        J = 0
-        self.g = []
-        self.lbg = []
-        self.ubg = []
-
-        self.X0 = casadi.MX.sym('X0', dynamics.NUM_VELOCITY_STATES)
-
-        # We care about the linear and angular velocities only.
-        self.R = casadi.MX.sym('R', 3)
-
-        # Make Xn and U for each step.  fatrop wants us to interleave the control variables and
-        # states so that it can produce a banded/structured problem which it can solve a lot
-        # faster.
-        Xn_variables = []
-        U_variables = []
-        for i in range(self.N):
-            U_variables.append(casadi.MX.sym(f'U{i}', 8))
-
-            if i == 0:
-                continue
-
-            Xn_variables.append(
-                casadi.MX.sym(f'X{i}', dynamics.NUM_VELOCITY_STATES))
-
-        Xn = casadi.horzcat(*Xn_variables)
-        U = casadi.horzcat(*U_variables)
-
-        # printme(number) is the debug.
-        Xk_begin = casadi.horzcat(self.X0, Xn)
-        Xk_end = self.next_X.map(self.N, "thread")(Xk_begin, U)
-        J = casadi.sum2(self.cost.map(self.N, "thread")(Xk_end, U, self.R))
-
-        # Put U and Xn interleaved into w to go fast.
-        for i in range(self.N):
-            self.w += [U_variables[i]]
-            self.ubw += [100] * 8
-            self.lbw += [-100] * 8
-
-            if i == self.N - 1:
-                continue
-
-            self.w += [Xn_variables[i]]
-            self.ubw += [casadi.inf] * dynamics.NUM_VELOCITY_STATES
-            self.lbw += [-casadi.inf] * dynamics.NUM_VELOCITY_STATES
-
-        self.g += [
-            casadi.reshape(Xn - Xk_end[:, 0:(self.N - 1)],
-                           dynamics.NUM_VELOCITY_STATES * (self.N - 1), 1)
-        ]
-
-        self.lbg += [0] * dynamics.NUM_VELOCITY_STATES * (self.N - 1)
-        self.ubg += [0] * dynamics.NUM_VELOCITY_STATES * (self.N - 1)
-
-        prob = {
-            'f': J,
-            # lbx <= x <= ubx
-            'x': casadi.vertcat(*self.w),
-            # lbg <= g(x, p) <= ubg
-            'g': casadi.vertcat(*self.g),
-            # Input parameters (initial position + goal)
-            'p': casadi.vertcat(self.X0, self.R),
-        }
-
-        compiler = "ccache clang"
-        flags = ["-O3"]
-        jit_options = {
-            "flags": flags,
-            "verbose": False,
-            "compiler": compiler,
-            "temp_suffix": False,
-        }
-
-        if solver == 'fatrop':
-            equality = [
-                True
-                for _ in range(dynamics.NUM_VELOCITY_STATES * (self.N - 1))
-            ]
-            options = {
-                "jit": True,
-                "jit_cleanup": False,
-                "jit_temp_suffix": False,
-                "compiler": "shell",
-                "jit_options": jit_options,
-                "structure_detection": "auto",
-                "fatrop": {
-                    "tol": 1e-7
-                },
-                "debug": True,
-                "equality": equality,
-            }
-        else:
-            options = {
-                "jit": True,
-                "jit_cleanup": False,
-                "jit_temp_suffix": False,
-                "compiler": "shell",
-                "jit_options": jit_options,
-            }
-            if full_debug:
-                options["jit"] = False
-                options["ipopt"] = {
-                    "print_level": 12,
-                }
-
-        self.solver = casadi.nlpsol('solver', solver, prob, options)
-
-    # TODO(austin): Vary the number of sub steps to be more short term and fewer long term?
-    def make_physics(self):
-        X0 = casadi.MX.sym('X0', dynamics.NUM_VELOCITY_STATES)
-        U = casadi.MX.sym('U', 8)
-
-        X = X0
-        M = 2  # RK4 steps per interval
-        DT = self.dt / M
-
-        for j in range(M):
-            k1 = self.velocity_fdot(X, U)
-            k2 = self.velocity_fdot(X + DT / 2 * k1, U)
-            k3 = self.velocity_fdot(X + DT / 2 * k2, U)
-            k4 = self.velocity_fdot(X + DT * k3, U)
-            X = X + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
-
-        return casadi.Function("next_X", [X0, U], [X])
-
-    def make_cost(self):
-        # TODO(austin): tune cost fn?
-        # Do we want to penalize slipping tires?
-        # Do we want to penalize powers unevenly applied across the motors?
-        # Need to do some simulations to see what works well.
-
-        X = casadi.MX.sym('X', dynamics.NUM_VELOCITY_STATES)
-        U = casadi.MX.sym('U', 8)
-        R = casadi.MX.sym('R', 3)
-
-        J = 0
-        vnorm = casadi.sqrt(R[0]**2.0 + R[1]**2.0)
-
-        vnormx = casadi.if_else(vnorm > 0.0001, R[0] / vnorm, 1.0)
-        vnormy = casadi.if_else(vnorm > 0.0001, R[1] / vnorm, 0.0)
-
-        vperpx = casadi.if_else(vnorm > 0.0001, -vnormy, 0.0)
-        vperpy = casadi.if_else(vnorm > 0.0001, vnormx, 1.0)
-
-        # TODO(austin): Do we want to do something more special for 0?
-
-        # cost velocity a lot more in the perpendicular direction to allow tire to spin up
-        # we also only want to get moving in the correct direction as fast as possible
-        J += 75 * ((R[0] - X[dynamics.VELOCITY_STATE_VX]) * vnormx +
-                   (R[1] - X[dynamics.VELOCITY_STATE_VY]) * vnormy)**2.0
-
-        J += 1500 * ((R[0] - X[dynamics.VELOCITY_STATE_VX]) * vperpx +
-                     (R[1] - X[dynamics.VELOCITY_STATE_VY]) * vperpy)**2.0
-
-        J += 1000 * (R[2] - X[dynamics.VELOCITY_STATE_OMEGA])**2.0
-
-        kSteerPositionGain = 0
-        kSteerVelocityGain = 0.10
-        J += kSteerPositionGain * (X[dynamics.VELOCITY_STATE_THETAS0])**2.0
-        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS0])**2.0
-
-        J += kSteerPositionGain * (X[dynamics.VELOCITY_STATE_THETAS1])**2.0
-        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS1])**2.0
-
-        J += kSteerPositionGain * (X[dynamics.VELOCITY_STATE_THETAS2])**2.0
-        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS2])**2.0
-
-        J += kSteerPositionGain * (X[dynamics.VELOCITY_STATE_THETAS3])**2.0
-        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS3])**2.0
-
-        # cost variance of the force by a tire and the expected average force and torque on it
-        total_force = self.force_vel[0](X, U)
-        total_torque = self.torque_cross(self.mounting_location[0](X, U),
-                                         self.force_vel[0](X, U))
-        for i in range(3):
-            total_force += self.force_vel[i + 1](X, U)
-            total_torque += self.torque_cross(
-                self.mounting_location[i + 1](X, U), self.force_vel[i + 1](X,
-                                                                           U))
-
-        total_force /= 4
-        total_torque /= 4
-        for i in range(4):
-            f_diff = (total_force +
-                      self.force_cross(total_torque, self.mounting_location[i]
-                                       (X, U))) - self.force_vel[i](X, U)
-            J += 0.01 * (f_diff[0, 0]**2.0 + f_diff[1, 0]**2.0)
-
-        # TODO(austin): Don't penalize torque steering current.
-        for i in range(4):
-            Is = U[2 * i + 0]
-            Id = U[2 * i + 1]
-            # Steer, cost it a lot less than drive to be more agressive in steering
-            J += ((Is + dynamics.STEER_CURRENT_COUPLING_FACTOR * Id)**
-                  2.0) / 100000.0
-            # Drive
-            J += Id * Id / 1000.0
-
-        return casadi.Function("Jn", [X, U, R], [J])
-
-    def torque_cross_func(self, r, F):
-        result = casadi.SX.sym('Tau', 1, 1)
-        result[0, 0] = r[0, 0] * F[1, 0] - r[1, 0] * F[0, 0]
-        return casadi.Function('Tau', [r, F], [result])
-
-    def force_cross_func(self, Tau, r):
-        result = casadi.SX.sym('F', 2, 1)
-        result[0, 0] = -r[1, 0] * Tau[0, 0] / casadi.norm_2(r)**2.0
-        result[1, 0] = r[0, 0] * Tau[0, 0] / casadi.norm_2(r)**2.0
-        return casadi.Function('F', [Tau, r], [result])
-
-    def solve(self, p, seed=None):
-        if seed is None:
-            seed = []
-
-            for i in range(self.N):
-                seed += [0, 0] * 4
-                if i < self.N - 1:
-                    seed += list(p[:dynamics.NUM_VELOCITY_STATES, 0])
-
-        return self.solver(x0=seed,
-                           lbx=self.lbw,
-                           ubx=self.ubw,
-                           lbg=self.lbg,
-                           ubg=self.ubg,
-                           p=casadi.DM(p))
-
-    def unpack_u(self, sol, i):
-        return sol['x'].full().flatten()[
-            (8 + dynamics.NUM_VELOCITY_STATES) *
-            i:((8 + dynamics.NUM_VELOCITY_STATES) * i + 8)]
-
-    def unpack_x(self, sol, i):
-        return sol['x'].full().flatten(
-        )[8 + (8 + dynamics.NUM_VELOCITY_STATES) *
-          (i - 1):(8 + dynamics.NUM_VELOCITY_STATES) * i]
+flags.DEFINE_bool('full_debug', False,
+                  'If true, turn on all the debugging in the solver.')
 
 
 class Solver(object):
@@ -423,10 +138,12 @@
 
                 last_time = time.time()
 
-        print(f"Tool {overall_time} seconds overall to solve.")
+        print(f"Took {overall_time} seconds overall to solve.")
 
 
 def main(argv):
+    matplotlib.use("GTK3Agg")
+
     if FLAGS.outputdir:
         os.chdir(FLAGS.outputdir)
 
@@ -457,7 +174,7 @@
     R_goal[1, 0] = FLAGS.vy
     R_goal[2, 0] = FLAGS.omega
 
-    mpc = MPC(solver='fatrop') if not full_debug else MPC(solver='ipopt')
+    mpc = MPC(solver='fatrop') if not FLAGS.full_debug else MPC(solver='ipopt')
     solver = Solver()
     if not FLAGS.compileonly:
         results = solver.solve(mpc=mpc,
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
new file mode 100644
index 0000000..f464268
--- /dev/null
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python3
+
+from frc971.control_loops.swerve import dynamics
+import casadi
+import numpy
+
+
+class MPC(object):
+
+    def __init__(self, solver='fatrop', jit=True, full_debug=False):
+        """Builds the velocity MPC nonlinear program and its casadi solver.
+
+        solver names the casadi.nlpsol plugin and jit toggles casadi's JIT.
+        full_debug (non-fatrop only) turns jit off and ipopt print_level up.
+        """
+        self.fdot = dynamics.swerve_full_dynamics(
+            casadi.SX.sym("X", dynamics.NUM_STATES, 1),
+            casadi.SX.sym("U", 8, 1))
+        self.velocity_fdot = dynamics.velocity_swerve_physics(
+            casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
+            casadi.SX.sym("U", 8, 1))
+
+        self.wrapped_swerve_physics = lambda X, U: numpy.array(self.fdot(X, U))
+
+        self.dt = 0.005
+
+        # TODO(austin): Do we need a disturbance torque per module to account for friction?
+        # Do it only in the observer/post?
+
+        self.force = [
+            dynamics.F(i, casadi.SX.sym("X", 25, 1), casadi.SX.sym("U", 8, 1))
+            for i in range(4)
+        ]
+        self.force_vel = [
+            dynamics.F_vel(i,
+                           casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
+                           casadi.SX.sym("U", 8, 1)) for i in range(4)
+        ]
+        self.slip_angle = [
+            dynamics.slip_angle(i, casadi.SX.sym("X", 25, 1),
+                                casadi.SX.sym("U", 8, 1)) for i in range(4)
+        ]
+        self.rotated_mounting_location = [
+            dynamics.rotated_mounting_location(
+                i, casadi.SX.sym("X", dynamics.NUM_VELOCITY_STATES, 1),
+                casadi.SX.sym("U", 8, 1)) for i in range(4)
+        ]
+        self.torque_cross = self.torque_cross_func(casadi.SX.sym("r", 2, 1),
+                                                   casadi.SX.sym("F", 2, 1))
+        self.force_cross = self.force_cross_func(casadi.SX.sym("Tau", 1, 1),
+                                                 casadi.SX.sym("r", 2, 1))
+        self.next_X = self.make_physics()
+        self.cost = self.make_cost()
+
+        self.N = 200
+
+        # Start with an empty nonlinear program.
+        self.w = []
+        self.lbw = []
+        self.ubw = []
+        self.g = []
+        self.lbg = []
+        self.ubg = []
+
+        self.X0 = casadi.MX.sym('X0', dynamics.NUM_VELOCITY_STATES)
+
+        # We care about the linear and angular velocities only.
+        self.R = casadi.MX.sym('R', 3)
+
+        # Make Xn and U for each step.  fatrop wants us to interleave the control variables and
+        # states so that it can produce a banded/structured problem which it can solve a lot
+        # faster.
+        Xn_variables = []
+        U_variables = []
+        for i in range(self.N):
+            U_variables.append(casadi.MX.sym(f'U{i}', 8))
+
+            if i == 0:
+                continue
+
+            Xn_variables.append(
+                casadi.MX.sym(f'X{i}', dynamics.NUM_VELOCITY_STATES))
+
+        Xn = casadi.horzcat(*Xn_variables)
+        U = casadi.horzcat(*U_variables)
+
+        # printme(number) is the debug.
+        Xk_begin = casadi.horzcat(self.X0, Xn)
+        Xk_end = self.next_X.map(self.N, "thread")(Xk_begin, U)
+        J = casadi.sum2(self.cost.map(self.N, "thread")(Xk_end, U, self.R))
+
+        # Put U and Xn interleaved into w to go fast.
+        for i in range(self.N):
+            self.w += [U_variables[i]]
+            self.ubw += [100] * 8
+            self.lbw += [-100] * 8
+
+            if i == self.N - 1:
+                continue
+
+            self.w += [Xn_variables[i]]
+            self.ubw += [casadi.inf] * dynamics.NUM_VELOCITY_STATES
+            self.lbw += [-casadi.inf] * dynamics.NUM_VELOCITY_STATES
+
+        self.g += [
+            casadi.reshape(Xn - Xk_end[:, 0:(self.N - 1)],
+                           dynamics.NUM_VELOCITY_STATES * (self.N - 1), 1)
+        ]
+
+        self.lbg += [0] * dynamics.NUM_VELOCITY_STATES * (self.N - 1)
+        self.ubg += [0] * dynamics.NUM_VELOCITY_STATES * (self.N - 1)
+
+        prob = {
+            'f': J,
+            # lbx <= x <= ubx
+            'x': casadi.vertcat(*self.w),
+            # lbg <= g(x, p) <= ubg
+            'g': casadi.vertcat(*self.g),
+            # Input parameters (initial state + goal)
+            'p': casadi.vertcat(self.X0, self.R),
+        }
+
+        jit_options = {
+            "flags": ["-O3"],
+            "verbose": True,
+            "compiler": "ccache clang",
+            "temp_suffix": False,
+        }
+
+        if solver == 'fatrop':
+            equality = [True] * (dynamics.NUM_VELOCITY_STATES * (self.N - 1))
+            options = {
+                "jit": jit,
+                "jit_cleanup": False,
+                "jit_temp_suffix": False,
+                "compiler": "shell",
+                "jit_options": jit_options,
+                "structure_detection": "auto",
+                "fatrop": {
+                    "tol": 1e-7
+                },
+                "debug": True,
+                "equality": equality,
+            }
+        else:
+            options = {
+                "jit": jit,
+                "jit_cleanup": False,
+                "jit_temp_suffix": False,
+                "compiler": "shell",
+                "jit_options": jit_options,
+            }
+            # Full debug: skip the JIT and use ipopt's full print level.
+            if full_debug:
+                options["jit"] = False
+                options["ipopt"] = {
+                    "print_level": 12,
+                }
+
+        self.solver = casadi.nlpsol('solver', solver, prob, options)
+
+    # TODO(austin): Vary the number of sub steps to be more short term and fewer long term?
+    def make_physics(self):  # Builds next_X(X0, U): one dt of RK4.
+        X0 = casadi.MX.sym('X0', dynamics.NUM_VELOCITY_STATES)
+        U = casadi.MX.sym('U', 8)
+
+        X = X0
+        M = 2  # RK4 steps per interval
+        DT = self.dt / M  # Length of each RK4 sub-step.
+
+        for j in range(M):  # U is held constant across the sub-steps.
+            k1 = self.velocity_fdot(X, U)
+            k2 = self.velocity_fdot(X + DT / 2 * k1, U)
+            k3 = self.velocity_fdot(X + DT / 2 * k2, U)
+            k4 = self.velocity_fdot(X + DT * k3, U)
+            X = X + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
+
+        return casadi.Function("next_X", [X0, U], [X])
+
+    def make_cost(self):
+        # Builds the casadi Function Jn(X, U, R) -> J for one state/input pair.
+        # TODO(austin): tune cost fn?  Need simulations to see what works well.
+
+        X = casadi.MX.sym('X', dynamics.NUM_VELOCITY_STATES)
+        U = casadi.MX.sym('U', 8)
+        R = casadi.MX.sym('R', 3)  # Goal: vx, vy, omega.
+
+        J = 0
+        vnorm = casadi.sqrt(R[0]**2.0 + R[1]**2.0)
+
+        vnormx = casadi.if_else(vnorm > 0.0001, R[0] / vnorm, 1.0)
+        vnormy = casadi.if_else(vnorm > 0.0001, R[1] / vnorm, 0.0)
+
+        vperpx = casadi.if_else(vnorm > 0.0001, -vnormy, 0.0)
+        vperpy = casadi.if_else(vnorm > 0.0001, vnormx, 1.0)
+
+        # TODO(austin): Do we want to do something more special for 0?
+
+        # Cost velocity error a lot more in the perpendicular direction (1500 vs 75)
+        # to allow the tires to spin up and get moving in the correct direction fast.
+        J += 75 * ((R[0] - X[dynamics.VELOCITY_STATE_VX]) * vnormx +
+                   (R[1] - X[dynamics.VELOCITY_STATE_VY]) * vnormy)**2.0
+
+        J += 1500 * ((R[0] - X[dynamics.VELOCITY_STATE_VX]) * vperpx +
+                     (R[1] - X[dynamics.VELOCITY_STATE_VY]) * vperpy)**2.0
+
+        J += 1000 * (R[2] - X[dynamics.VELOCITY_STATE_OMEGA])**2.0
+
+        kSteerVelocityGain = 0.10
+        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS0])**2.0
+        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS1])**2.0
+        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS2])**2.0
+        J += kSteerVelocityGain * (X[dynamics.VELOCITY_STATE_OMEGAS3])**2.0
+
+        # Cost each tire's deviation from the force implied by the average force and torque.
+        total_force = self.force_vel[0](X, U)
+        total_torque = self.torque_cross(
+            self.rotated_mounting_location[0](X, U), self.force_vel[0](X, U))
+        for i in range(3):
+            total_force += self.force_vel[i + 1](X, U)
+            total_torque += self.torque_cross(
+                self.rotated_mounting_location[i + 1](X, U),
+                self.force_vel[i + 1](X, U))
+
+        total_force /= 4
+        total_torque /= 4
+        for i in range(4):
+            f_diff = (total_force + self.force_cross(
+                total_torque, self.rotated_mounting_location[i]
+                (X, U))) - self.force_vel[i](X, U)
+            J += 0.01 * (f_diff[0, 0]**2.0 + f_diff[1, 0]**2.0)
+
+        # TODO(austin): Don't penalize torque steering current.
+        for i in range(4):
+            Is = U[2 * i + 0]
+            Id = U[2 * i + 1]
+            # Steer, cost it a lot less than drive to be more aggressive in steering
+            J += ((Is + dynamics.STEER_CURRENT_COUPLING_FACTOR * Id)**
+                  2.0) / 100000.0
+            # Drive
+            J += Id * Id / 1000.0
+
+        return casadi.Function("Jn", [X, U, R], [J])
+
+    def torque_cross_func(self, r, F):  # Function for the 2D cross r x F.
+        result = casadi.SX.sym('Tau', 1, 1)
+        result[0, 0] = r[0, 0] * F[1, 0] - r[1, 0] * F[0, 0]
+        return casadi.Function('Tau', [r, F], [result])
+
+    def force_cross_func(self, Tau, r):  # Force at r, normal to r, giving Tau.
+        result = casadi.SX.sym('F', 2, 1)
+        result[0, 0] = -r[1, 0] * Tau[0, 0] / casadi.norm_2(r)**2.0
+        result[1, 0] = r[0, 0] * Tau[0, 0] / casadi.norm_2(r)**2.0
+        return casadi.Function('F', [Tau, r], [result])
+
+    def solve(self, p, seed=None):
+        if seed is None:  # Default: zero inputs, states seeded from p's X0.
+            seed = []
+
+            for i in range(self.N):
+                seed += [0, 0] * 4
+                if i < self.N - 1:
+                    seed += list(p[:dynamics.NUM_VELOCITY_STATES, 0])
+
+        return self.solver(x0=seed,
+                           lbx=self.lbw,
+                           ubx=self.ubw,
+                           lbg=self.lbg,
+                           ubg=self.ubg,
+                           p=casadi.DM(p))
+
+    def unpack_u(self, sol, i):  # U_i from the interleaved [U0, X1, U1, ...].
+        return sol['x'].full().flatten()[
+            (8 + dynamics.NUM_VELOCITY_STATES) *
+            i:((8 + dynamics.NUM_VELOCITY_STATES) * i + 8)]
+
+    def unpack_x(self, sol, i):  # X_i for i >= 1; X0 is a parameter, not in x.
+        return sol['x'].full().flatten(
+        )[8 + (8 + dynamics.NUM_VELOCITY_STATES) *
+          (i - 1):(8 + dynamics.NUM_VELOCITY_STATES) * i]
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index 7ea3db5..de52d79 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -70,6 +70,7 @@
 // State per module.
 struct Module {
   DenseMatrix mounting_location;
+  DenseMatrix rotated_mounting_location;
 
   RCP<const Symbol> Is;
 
@@ -938,6 +939,14 @@
         },
         &result_py);
 
+    DefineVector2dVelocityFunction(
+        "rotated_mounting_location",
+        "Returns the mounting location of wheel in field aligned coordinates",
+        [](const Module &m, int dimension) {
+          return ccode(*m.rotated_mounting_location.get(dimension, 0));
+        },
+        &result_py);
+
     DefineScalarFunction(
         "Ms", "Returns the self aligning moment of the ith wheel",
         [this](const Module &m) {
@@ -1189,17 +1198,18 @@
                     DenseMatrix(2, 1, {result.full.Fwx, result.Fwy}),
                     result.full.F);
 
-    DenseMatrix rotated_mounting_location = DenseMatrix(2, 1);
+    result.rotated_mounting_location = DenseMatrix(2, 1);
     mul_dense_dense(R(theta_), result.mounting_location,
-                    rotated_mounting_location);
-    result.full.torque = force_cross(rotated_mounting_location, result.full.F);
+                    result.rotated_mounting_location);
+    result.full.torque =
+        force_cross(result.rotated_mounting_location, result.full.F);
 
     result.direct.F = DenseMatrix(2, 1);
     mul_dense_dense(R(add(theta_, result.thetas)),
                     DenseMatrix(2, 1, {result.direct.Fwx, result.Fwy}),
                     result.direct.F);
     result.direct.torque =
-        force_cross(rotated_mounting_location, result.direct.F);
+        force_cross(result.rotated_mounting_location, result.direct.F);
 
     VLOG(1);
     VLOG(1) << "full torque = " << result.full.torque->__str__();
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 6e1305d..19a48de 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -352,3 +352,73 @@
         X[STATE_VY],
         X[STATE_OMEGA],
     ])
+
+
+def mpc_cost(coefficients: CoefficientsType, X, U, goal):
+    J = 0  # Running total; meant to match casadi's MPC.make_cost().
+
+    rnorm = jax.numpy.linalg.norm(goal[0:2])
+
+    vnorm = jax.lax.select(rnorm > 0.0001, goal[0:2] / rnorm,
+                           jax.numpy.array([1.0, 0.0]))
+    vperp = jax.lax.select(rnorm > 0.0001,
+                           jax.numpy.array([-vnorm[1], vnorm[0]]),
+                           jax.numpy.array([0.0, 1.0]))
+
+    velocity_error = goal[0:2] - X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]
+
+    # TODO(austin): Do we want to do something more special for 0?
+
+    J += 75 * (jax.numpy.dot(velocity_error, vnorm)**2.0)
+    J += 1500 * (jax.numpy.dot(velocity_error, vperp)**2.0)
+    J += 1000 * (goal[2] - X[VELOCITY_STATE_OMEGA])**2.0
+
+    kSteerVelocityGain = 0.10
+    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS0])**2.0
+    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS1])**2.0
+    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS2])**2.0
+    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS3])**2.0
+
+    mounting_locations = jax.numpy.array(
+        [[coefficients.robot_width / 2.0, coefficients.robot_width / 2.0],
+         [-coefficients.robot_width / 2.0, coefficients.robot_width / 2.0],
+         [-coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0],
+         [coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0]])
+
+    Rtheta = R(X[VELOCITY_STATE_THETA])
+    _, F0, torque0 = velocity_module_physics(coefficients, Rtheta, 0,
+                                             mounting_locations[0], X, U)
+    _, F1, torque1 = velocity_module_physics(coefficients, Rtheta, 1,
+                                             mounting_locations[1], X, U)
+    _, F2, torque2 = velocity_module_physics(coefficients, Rtheta, 2,
+                                             mounting_locations[2], X, U)
+    _, F3, torque3 = velocity_module_physics(coefficients, Rtheta, 3,
+                                             mounting_locations[3], X, U)
+
+    forces = [F0, F1, F2, F3]
+
+    F = (F0 + F1 + F2 + F3)
+    torque = (torque0 + torque1 + torque2 + torque3)
+
+    def force_cross(torque, r):  # Force at r, normal to r, giving this torque.
+        r_squared_norm = jax.numpy.inner(r, r)
+
+        return jax.numpy.array(
+            [-r[1] * torque / r_squared_norm, r[0] * torque / r_squared_norm])
+
+    # TODO(austin): Are these penalties reasonable?  Do they give us a decent time constant?
+    for i in range(4):
+        desired_force = F / 4.0 + force_cross(
+            torque / 4.0, Rtheta @ mounting_locations[i, :])
+        force_error = desired_force - forces[i]
+        J += 0.01 * jax.numpy.inner(force_error, force_error)
+
+    for i in range(4):
+        Is = U[2 * i + 0]
+        Id = U[2 * i + 1]
+        # Steer, weighted 100x less than drive.
+        J += ((Is + STEER_CURRENT_COUPLING_FACTOR * Id)**2.0) / 100000.0
+        # Drive
+        J += (Id**2.0) / 1000.0
+
+    return J
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 1da2a31..bc366fa 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -18,6 +18,7 @@
 from frc971.control_loops.swerve import dynamics
 from frc971.control_loops.swerve import nocaster_dynamics
 from frc971.control_loops.swerve import physics_test_utils as utils
+from frc971.control_loops.swerve import casadi_velocity_mpc_lib
 from frc971.control_loops.swerve import jax_dynamics
 from frc971.control_loops.swerve.cpp_dynamics import swerve_dynamics as cpp_dynamics
 
@@ -707,8 +708,56 @@
         Xdot_rot = self.swerve_full_dynamics(X_rot, steer_I, skip_compare=True)
 
         self.assertGreater(Xdot[dynamics.STATE_OMEGA, 0], 0.0)
-        self.assertAlmostEquals(Xdot[dynamics.STATE_OMEGA, 0],
-                                Xdot_rot[dynamics.STATE_OMEGA, 0])
+        self.assertAlmostEqual(Xdot[dynamics.STATE_OMEGA, 0],
+                               Xdot_rot[dynamics.STATE_OMEGA, 0])
+
+    def test_cost_equality(self):
+        """Tests that the casadi and jax cost functions match."""
+        mpc = casadi_velocity_mpc_lib.MPC(jit=False)
+        cost = mpc.make_cost()
+
+        for i in range(10):  # NOTE(review): random inputs, no seed set here.
+            X = numpy.random.uniform(size=(dynamics.NUM_VELOCITY_STATES, ))
+            U = numpy.random.uniform(low=-10, high=10, size=(8, ))
+            R = numpy.random.uniform(low=-1, high=1, size=(3, ))
+
+            J = numpy.array(cost(X, U, R))[0, 0]
+            jax_J = jax_dynamics.mpc_cost(self.coefficients, X, U, R)
+
+            self.assertAlmostEqual(J, jax_J)
+
+        R = jax.numpy.array([0.0, 0.0, 1.0])
+
+        # Now try spinning in place and make sure the cost doesn't change.
+        # This tells us if we got our rotations right.
+        steer_I = numpy.array([(i % 2) * 20 for i in range(8)])
+
+        X = utils.state_vector(velocity=numpy.array([[0.0], [0.0]]),
+                               omega=0.0,
+                               module_angles=[
+                                   3 * numpy.pi / 4.0, -3 * numpy.pi / 4.0,
+                                   -numpy.pi / 4.0, numpy.pi / 4.0
+                               ],
+                               drive_wheel_velocity=1.0)
+
+        jax_J_orig = jax_dynamics.mpc_cost(self.coefficients,
+                                           self.to_velocity_state(X)[:, 0],
+                                           steer_I, R)
+
+        X_rotated = utils.state_vector(velocity=numpy.array([[0.0], [0.0]]),
+                                       omega=0.0,
+                                       theta=numpy.pi,
+                                       module_angles=[
+                                           3 * numpy.pi / 4.0,
+                                           -3 * numpy.pi / 4.0,
+                                           -numpy.pi / 4.0, numpy.pi / 4.0
+                                       ],
+                                       drive_wheel_velocity=1.0)
+        jax_J_rotated = jax_dynamics.mpc_cost(
+            self.coefficients,
+            self.to_velocity_state(X_rotated)[:, 0], steer_I, R)
+
+        self.assertAlmostEqual(jax_J_orig, jax_J_rotated)
 
     def test_cpp_consistency(self):
         """Tests that the C++ physics are consistent with the Python physics."""
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index d6515fe..b0da0b4 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -1,7 +1,7 @@
 import jax, numpy
 from functools import partial
 from absl import logging
-from frc971.control_loops.swerve import dynamics, jax_dynamics
+from frc971.control_loops.swerve import jax_dynamics
 from frc971.control_loops.python import controls
 from flax.typing import PRNGKey
 
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index f5ba092..b1ad9a8 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -20,7 +20,6 @@
 from functools import partial
 import flashbax
 from jax.experimental.ode import odeint
-from frc971.control_loops.swerve import dynamics
 import orbax.checkpoint
 from frc971.control_loops.swerve.velocity_controller.model import *
 from frc971.control_loops.swerve.velocity_controller.physics import *