Add a JAX version of the physics, and tests to confirm it matches casadi
This sets us up to compute the dynamics very efficiently and to JIT-compile them.
Change-Id: I57aea5c1f480759c8e5e658ff6f4de0d82ef273d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 0a3fdb6..af3dc17 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -1,6 +1,13 @@
-#!/usr/bin/python3
+#!/usr/bin/env python3
+import jax
+
+# casadi uses doubles, while jax defaults to 32-bit floats. We want to make sure
+# the physics matches really precisely, so force 64-bit doubles for the tests.
+jax.config.update("jax_enable_x64", True)
import numpy
+
+numpy.set_printoptions(precision=20)
import sys, os
import casadi
import scipy
@@ -11,10 +18,12 @@
from frc971.control_loops.swerve import dynamics
from frc971.control_loops.swerve import nocaster_dynamics
from frc971.control_loops.swerve import physics_test_utils as utils
+from frc971.control_loops.swerve import jax_dynamics
class TestSwervePhysics(unittest.TestCase):
I = numpy.zeros((8, 1))
+ coefficients = jax_dynamics.Coefficients()
def to_velocity_state(self, X):
return dynamics.to_velocity_state(X)
@@ -22,19 +31,74 @@
def swerve_full_dynamics(self, X, U, skip_compare=False):
X_velocity = self.to_velocity_state(X)
Xdot = self.position_swerve_full_dynamics(X, U)
+
if not skip_compare:
- velocity_states = self.to_velocity_state(Xdot)
+ Xdot_velocity = self.to_velocity_state(Xdot)
velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+
self.assertLess(
- numpy.linalg.norm(velocity_states - velocity_physics),
+ numpy.linalg.norm(Xdot_velocity - velocity_physics),
2e-2,
msg=
- f'Norm failed, full physics -> {velocity_states.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - velocity_states}',
+ f'Norm failed, full physics -> {Xdot_velocity.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - Xdot_velocity}',
)
+ self.validate_dynamics_equality(X, U)  # Also cross-check the JAX dynamics against casadi.
+
return Xdot
+ def validate_dynamics_equality(self, X, U):
+ """Asserts the JAX and casadi dynamics agree for X, U (diff norm < 2e-8).
+
+ Checks the full dynamics of X and the velocity dynamics of its velocity state.
+ Note: If the symengine code has been updated, you likely need to update the JAX
+ by hand. We had trouble code generating it with good performance.
+ """
+ X_velocity = self.to_velocity_state(X)
+
+ Xdot = self.position_swerve_full_dynamics(X, U)
+ Xdot_jax = jax_dynamics.full_dynamics(self.coefficients, X[:, 0], U[:,
+ 0])
+
+ self.assertLess(
+ numpy.linalg.norm(Xdot[:, 0] - Xdot_jax),
+ 2e-8,
+ msg=
+ f'Xdot: {Xdot[:, 0]}, Xdot_jax: {Xdot_jax}, diff: {(Xdot[:, 0] - Xdot_jax)}',
+ )
+
+ velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+ velocity_physics_jax = jax_dynamics.velocity_dynamics(
+ self.coefficients, X_velocity[:, 0], U[:, 0])
+
+ self.assertLess(
+ numpy.linalg.norm(velocity_physics[:, 0] - velocity_physics_jax),
+ 2e-8,
+ msg=
+ f'Xdot: {velocity_physics[:, 0]}, Xdot_jax: {velocity_physics_jax}, diff: {(velocity_physics[:, 0] - velocity_physics_jax)}',
+ )
+
+ def wrap_and_validate(self, function, i):
+ """Returns utils.wrap_module(function, i), validating JAX vs casadi per call.
+
+ We want to do it every time we check any intermediate, since the tests
+ are designed to test all the corner cases, but they don't all do it
+ through the main dynamics function above.
+ """
+ wrapped_fn = utils.wrap_module(function, i)
+
+ def do(X, U):
+ self.validate_dynamics_equality(X, U)  # Cross-check before evaluating the intermediate.
+ return wrapped_fn(X, U)
+
+ return do
+
def wrap(self, python_module):
+ # Only update on change to avoid re-JITing things.
+ if self.coefficients.caster != python_module.CASTER:
+ self.coefficients = self.coefficients._replace(
+ caster=python_module.CASTER)
+
self.position_swerve_full_dynamics = utils.wrap(
python_module.swerve_full_dynamics)
@@ -45,37 +109,42 @@
evaluated_fn(X, U))
self.contact_patch_velocity = [
- utils.wrap_module(python_module.contact_patch_velocity, i)
+ self.wrap_and_validate(python_module.contact_patch_velocity, i)
for i in range(4)
]
self.wheel_ground_velocity = [
- utils.wrap_module(python_module.wheel_ground_velocity, i)
+ self.wrap_and_validate(python_module.wheel_ground_velocity, i)
for i in range(4)
]
self.wheel_slip_velocity = [
- utils.wrap_module(python_module.wheel_slip_velocity, i)
+ self.wrap_and_validate(python_module.wheel_slip_velocity, i)
for i in range(4)
]
self.wheel_force = [
- utils.wrap_module(python_module.wheel_force, i) for i in range(4)
- ]
- self.module_angular_accel = [
- utils.wrap_module(python_module.module_angular_accel, i)
+ self.wrap_and_validate(python_module.wheel_force, i)
for i in range(4)
]
- self.F = [utils.wrap_module(python_module.F, i) for i in range(4)]
+ self.module_angular_accel = [
+ self.wrap_and_validate(python_module.module_angular_accel, i)
+ for i in range(4)
+ ]
+ self.F = [self.wrap_and_validate(python_module.F, i) for i in range(4)]
self.mounting_location = [
- utils.wrap_module(python_module.mounting_location, i)
+ self.wrap_and_validate(python_module.mounting_location, i)
for i in range(4)
]
self.slip_angle = [
- utils.wrap_module(python_module.slip_angle, i) for i in range(4)
+ self.wrap_and_validate(python_module.slip_angle, i)
+ for i in range(4)
]
self.slip_ratio = [
- utils.wrap_module(python_module.slip_ratio, i) for i in range(4)
+ self.wrap_and_validate(python_module.slip_ratio, i)
+ for i in range(4)
]
- self.Ms = [utils.wrap_module(python_module.Ms, i) for i in range(4)]
+ self.Ms = [
+ self.wrap_and_validate(python_module.Ms, i) for i in range(4)
+ ]
def setUp(self):
self.wrap(dynamics)
@@ -192,6 +261,15 @@
scipy.special.logsumexp([1.0, abs(vx) * loggain]) /
loggain))
+ jax_expected = jax.numpy.sin(
+ -jax_dynamics.soft_atan2(vy, vx))
+
+ self.assertAlmostEqual(
+ expected,
+ jax_expected,
+ msg=f"Trying wrap {wrap} theta {theta}",
+ )
+
self.assertAlmostEqual(
expected,
computed_angle,