Add JAX version of the physics, and tests to confirm it matches
This sets us up to compute the dynamics very efficiently and to JIT-compile them.
Change-Id: I57aea5c1f480759c8e5e658ff6f4de0d82ef273d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index c7355b5..26828e4 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -197,14 +197,51 @@
],
)
+py_library(
+ name = "jax_dynamics",  # JAX version of the swerve physics.
+ srcs = [
+ "jax_dynamics.py",
+ ],
+ deps = [
+ ":dynamics",
+ "//frc971/control_loops/python:controls",
+ "@pip//jax",
+ ],
+)
+
py_test(
- name = "physics_test",
+ name = "physics_test_cpu",
srcs = [
"physics_test.py",
],
+ env = {
+ "JAX_PLATFORMS": "cpu",  # Restrict JAX to the CPU backend for this variant.
+ },
+ main = "physics_test.py",  # Required: target name no longer matches the source file name.
target_compatible_with = ["@platforms//cpu:x86_64"],
deps = [
":dynamics",
+ ":jax_dynamics",
+ ":physics_test_utils",
+ "@pip//casadi",
+ "@pip//numpy",
+ "@pip//scipy",
+ ],
+)
+
+py_test(
+ name = "physics_test_gpu",
+ srcs = [
+ "physics_test.py",
+ ],
+ env = {
+ "JAX_PLATFORMS": "cuda",
+ },
+ main = "physics_test.py",
+ target_compatible_with = ["@platforms//cpu:x86_64"],
+ deps = [
+ ":dynamics",
+ ":jax_dynamics",
":physics_test_utils",
"@pip//casadi",
"@pip//numpy",