Add JAX cost and a test to verify correctness

Turns out there's a bug in the casadi cost too, so fix that while we are
here.  We were using the unrotated radius for computing the cost
function torque instead of the rotated radius.

Split the MPC code out into a separate file as well to make it easier to
pull into other things.

Change-Id: Iee6c9999fa8b6a91d6963af4edba5cf92a085e9b
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 1da2a31..bc366fa 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -18,6 +18,7 @@
 from frc971.control_loops.swerve import dynamics
 from frc971.control_loops.swerve import nocaster_dynamics
 from frc971.control_loops.swerve import physics_test_utils as utils
+from frc971.control_loops.swerve import casadi_velocity_mpc_lib
 from frc971.control_loops.swerve import jax_dynamics
 from frc971.control_loops.swerve.cpp_dynamics import swerve_dynamics as cpp_dynamics
 
@@ -707,8 +708,56 @@
         Xdot_rot = self.swerve_full_dynamics(X_rot, steer_I, skip_compare=True)
 
         self.assertGreater(Xdot[dynamics.STATE_OMEGA, 0], 0.0)
-        self.assertAlmostEquals(Xdot[dynamics.STATE_OMEGA, 0],
-                                Xdot_rot[dynamics.STATE_OMEGA, 0])
+        self.assertAlmostEqual(Xdot[dynamics.STATE_OMEGA, 0],
+                               Xdot_rot[dynamics.STATE_OMEGA, 0])
+
+    def test_cost_equality(self):
+        """Tests that the casadi and jax cost functions match."""
+        mpc = casadi_velocity_mpc_lib.MPC(jit=False)
+        cost = mpc.make_cost()
+
+        for i in range(10):
+            X = numpy.random.uniform(size=(dynamics.NUM_VELOCITY_STATES, ))
+            U = numpy.random.uniform(low=-10, high=10, size=(8, ))
+            R = numpy.random.uniform(low=-1, high=1, size=(3, ))
+
+            J = numpy.array(cost(X, U, R))[0, 0]
+            jax_J = jax_dynamics.mpc_cost(self.coefficients, X, U, R)
+
+            self.assertAlmostEqual(J, jax_J)
+
+        R = jax.numpy.array([0.0, 0.0, 1.0])
+
+        # Now try spinning in place and make sure the cost doesn't change.
+        # This tells us if we got our rotations right.
+        steer_I = numpy.array([(i % 2) * 20 for i in range(8)])
+
+        X = utils.state_vector(velocity=numpy.array([[0.0], [0.0]]),
+                               omega=0.0,
+                               module_angles=[
+                                   3 * numpy.pi / 4.0, -3 * numpy.pi / 4.0,
+                                   -numpy.pi / 4.0, numpy.pi / 4.0
+                               ],
+                               drive_wheel_velocity=1.0)
+
+        jax_J_orig = jax_dynamics.mpc_cost(self.coefficients,
+                                           self.to_velocity_state(X)[:, 0],
+                                           steer_I, R)
+
+        X_rotated = utils.state_vector(velocity=numpy.array([[0.0], [0.0]]),
+                                       omega=0.0,
+                                       theta=numpy.pi,
+                                       module_angles=[
+                                           3 * numpy.pi / 4.0,
+                                           -3 * numpy.pi / 4.0,
+                                           -numpy.pi / 4.0, numpy.pi / 4.0
+                                       ],
+                                       drive_wheel_velocity=1.0)
+        jax_J_rotated = jax_dynamics.mpc_cost(
+            self.coefficients,
+            self.to_velocity_state(X_rotated)[:, 0], steer_I, R)
+
+        self.assertAlmostEqual(jax_J_orig, jax_J_rotated)
 
     def test_cpp_consistency(self):
         """Tests that the C++ physics are consistent with the Python physics."""