Add a JAX version of the swerve physics and tests that it matches casadi

This sets us up to compute the dynamics very efficiently and to JIT-compile them.

Change-Id: I57aea5c1f480759c8e5e658ff6f4de0d82ef273d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 0a3fdb6..af3dc17 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -1,6 +1,13 @@
-#!/usr/bin/python3
+#!/usr/bin/env python3
+import jax
+
+# casadi uses doubles.  jax likes floats.  We want to make sure the physics
+# matches really precisely, so force doubles for the tests.
+jax.config.update("jax_enable_x64", True)
 
 import numpy
+
+numpy.set_printoptions(precision=20)
 import sys, os
 import casadi
 import scipy
@@ -11,10 +18,12 @@
 from frc971.control_loops.swerve import dynamics
 from frc971.control_loops.swerve import nocaster_dynamics
 from frc971.control_loops.swerve import physics_test_utils as utils
+from frc971.control_loops.swerve import jax_dynamics
 
 
 class TestSwervePhysics(unittest.TestCase):
     I = numpy.zeros((8, 1))
+    coefficients = jax_dynamics.Coefficients()
 
     def to_velocity_state(self, X):
         return dynamics.to_velocity_state(X)
@@ -22,19 +31,74 @@
     def swerve_full_dynamics(self, X, U, skip_compare=False):
         X_velocity = self.to_velocity_state(X)
         Xdot = self.position_swerve_full_dynamics(X, U)
+
         if not skip_compare:
-            velocity_states = self.to_velocity_state(Xdot)
+            Xdot_velocity = self.to_velocity_state(Xdot)
             velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+
             self.assertLess(
-                numpy.linalg.norm(velocity_states - velocity_physics),
+                numpy.linalg.norm(Xdot_velocity - velocity_physics),
                 2e-2,
                 msg=
-                f'Norm failed, full physics -> {velocity_states.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - velocity_states}',
+                f'Norm failed, full physics -> {Xdot_velocity.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - Xdot_velocity}',
             )
 
+        self.validate_dynamics_equality(X, U)  # Also cross-check JAX vs. casadi.
+
         return Xdot
 
+    def validate_dynamics_equality(self, X, U):
+        """Asserts the JAX and casadi full and velocity dynamics agree at (X, U).
+
+        Note:
+          If the symengine code has been updated, you likely need to update the JAX
+          by hand.  We had trouble code generating it with good performance.
+        """
+        X_velocity = self.to_velocity_state(X)
+        # Full-state dynamics: casadi vs. JAX.
+        Xdot = self.position_swerve_full_dynamics(X, U)
+        Xdot_jax = jax_dynamics.full_dynamics(self.coefficients, X[:, 0],
+                                              U[:, 0])
+
+        self.assertLess(
+            numpy.linalg.norm(Xdot[:, 0] - Xdot_jax),
+            2e-8,
+            msg=
+            f'Xdot: {Xdot[:, 0]}, Xdot_jax: {Xdot_jax}, diff: {(Xdot[:, 0] - Xdot_jax)}',
+        )
+        # Velocity-only dynamics: casadi vs. JAX.
+        velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+        velocity_physics_jax = jax_dynamics.velocity_dynamics(
+            self.coefficients, X_velocity[:, 0], U[:, 0])
+
+        self.assertLess(
+            numpy.linalg.norm(velocity_physics[:, 0] - velocity_physics_jax),
+            2e-8,
+            msg=
+            f'velocity_physics: {velocity_physics[:, 0]}, velocity_physics_jax: {velocity_physics_jax}, diff: {(velocity_physics[:, 0] - velocity_physics_jax)}',
+        )
+
+    def wrap_and_validate(self, function, i):
+        """Wraps module i's `function` so every call first cross-checks JAX.
+
+        Each intermediate the tests evaluate also runs the JAX vs. casadi
+        comparison at the same (X, U), because the corner-case tests don't all
+        go through swerve_full_dynamics.
+        """
+        module_fn = utils.wrap_module(function, i)
+
+        def validated(X, U):
+            self.validate_dynamics_equality(X, U)
+            return module_fn(X, U)
+
+        return validated
+
     def wrap(self, python_module):
+        # Only update on change to avoid re-jiting things.
+        if self.coefficients.caster != python_module.CASTER:
+            self.coefficients = self.coefficients._replace(
+                caster=python_module.CASTER)
+
         self.position_swerve_full_dynamics = utils.wrap(
             python_module.swerve_full_dynamics)
 
@@ -45,37 +109,42 @@
             evaluated_fn(X, U))
 
         self.contact_patch_velocity = [
-            utils.wrap_module(python_module.contact_patch_velocity, i)
+            self.wrap_and_validate(python_module.contact_patch_velocity, i)
             for i in range(4)
         ]
         self.wheel_ground_velocity = [
-            utils.wrap_module(python_module.wheel_ground_velocity, i)
+            self.wrap_and_validate(python_module.wheel_ground_velocity, i)
             for i in range(4)
         ]
         self.wheel_slip_velocity = [
-            utils.wrap_module(python_module.wheel_slip_velocity, i)
+            self.wrap_and_validate(python_module.wheel_slip_velocity, i)
             for i in range(4)
         ]
         self.wheel_force = [
-            utils.wrap_module(python_module.wheel_force, i) for i in range(4)
-        ]
-        self.module_angular_accel = [
-            utils.wrap_module(python_module.module_angular_accel, i)
+            self.wrap_and_validate(python_module.wheel_force, i)
             for i in range(4)
         ]
-        self.F = [utils.wrap_module(python_module.F, i) for i in range(4)]
+        self.module_angular_accel = [
+            self.wrap_and_validate(python_module.module_angular_accel, i)
+            for i in range(4)
+        ]
+        self.F = [self.wrap_and_validate(python_module.F, i) for i in range(4)]
         self.mounting_location = [
-            utils.wrap_module(python_module.mounting_location, i)
+            self.wrap_and_validate(python_module.mounting_location, i)
             for i in range(4)
         ]
 
         self.slip_angle = [
-            utils.wrap_module(python_module.slip_angle, i) for i in range(4)
+            self.wrap_and_validate(python_module.slip_angle, i)
+            for i in range(4)
         ]
         self.slip_ratio = [
-            utils.wrap_module(python_module.slip_ratio, i) for i in range(4)
+            self.wrap_and_validate(python_module.slip_ratio, i)
+            for i in range(4)
         ]
-        self.Ms = [utils.wrap_module(python_module.Ms, i) for i in range(4)]
+        self.Ms = [
+            self.wrap_and_validate(python_module.Ms, i) for i in range(4)
+        ]
 
     def setUp(self):
         self.wrap(dynamics)
@@ -192,6 +261,15 @@
                         scipy.special.logsumexp([1.0, abs(vx) * loggain]) /
                         loggain))
 
+                    jax_expected = jax.numpy.sin(
+                        -jax_dynamics.soft_atan2(vy, vx))
+
+                    self.assertAlmostEqual(
+                        expected,
+                        jax_expected,
+                        msg=f"Trying wrap {wrap} theta {theta}",
+                    )
+
                     self.assertAlmostEqual(
                         expected,
                         computed_angle,