Add JAX version of the physics, and tests to confirm it matches

This sets us up to compute the dynamics very efficiently and to JIT-compile them.

Change-Id: I57aea5c1f480759c8e5e658ff6f4de0d82ef273d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index c7355b5..26828e4 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -197,14 +197,51 @@
     ],
 )
 
+py_library(
+    name = "jax_dynamics",
+    srcs = [
+        "jax_dynamics.py",
+    ],
+    deps = [
+        ":dynamics",
+        "//frc971/control_loops/python:controls",
+        "@pip//jax",
+    ],
+)
+
 py_test(
-    name = "physics_test",
+    name = "physics_test_cpu",
     srcs = [
         "physics_test.py",
     ],
+    env = {
+        "JAX_PLATFORMS": "cpu",
+    },
+    main = "physics_test.py",
     target_compatible_with = ["@platforms//cpu:x86_64"],
     deps = [
         ":dynamics",
+        ":jax_dynamics",
+        ":physics_test_utils",
+        "@pip//casadi",
+        "@pip//numpy",
+        "@pip//scipy",
+    ],
+)
+
+py_test(
+    name = "physics_test_gpu",
+    srcs = [
+        "physics_test.py",
+    ],
+    env = {
+        "JAX_PLATFORMS": "cuda",
+    },
+    main = "physics_test.py",
+    target_compatible_with = ["@platforms//cpu:x86_64"],
+    deps = [
+        ":dynamics",
+        ":jax_dynamics",
         ":physics_test_utils",
         "@pip//casadi",
         "@pip//numpy",