Make JAX dynamics code not depend on casadi code

Fewer dependencies is better.

Change-Id: I60f931347d27b75038387fde11a42ee98c9c1b99
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 1b8dcc3..513bf1b 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -144,12 +144,14 @@
         "dynamics.cc",
         "dynamics.h",
         "dynamics.py",
+        "dynamics_constants.py",
     ],
     args = [
         "--output_base=$(BINDIR)/",
         "--cc_output_path=$(location :dynamics.cc)",
         "--h_output_path=$(location :dynamics.h)",
         "--casadi_py_output_path=$(location :dynamics.py)",
+        "--constants_output_path=$(location :dynamics_constants.py)",
     ],
     tool = ":generate_physics",
 )
@@ -204,10 +206,10 @@
 py_library(
     name = "jax_dynamics",
     srcs = [
+        "dynamics_constants.py",
         "jax_dynamics.py",
     ],
     deps = [
-        ":dynamics",
         "//frc971/control_loops/python:controls",
         "@pip//jax",
     ],
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index af058f6..7ea3db5 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -32,6 +32,8 @@
           "Path to write generated header code to");
 ABSL_FLAG(std::string, casadi_py_output_path, "",
           "Path to write casadi generated py code to");
+ABSL_FLAG(std::string, constants_output_path, "",
+          "Path to write constants python code to");
 ABSL_FLAG(double, caster, 0.01, "Caster in meters for the module.");
 
 ABSL_FLAG(bool, symbolic, false, "If true, write everything out symbolically.");
@@ -625,6 +627,80 @@
     }
   }
 
+  void WriteConstantsFile(std::string_view path) {
+    std::vector<std::string> result_py;
+
+    // Write out the header.
+    result_py.emplace_back("#!/usr/bin/env python3");
+    result_py.emplace_back("");
+
+    WriteConstants(&result_py);
+
+    aos::util::WriteStringToFileOrDie(path, absl::StrJoin(result_py, "\n"));
+  }
+
+  void WriteConstants(std::vector<std::string> *result_py) {
+    result_py->emplace_back(absl::Substitute("WHEEL_RADIUS = $0", ccode(*rw_)));
+    result_py->emplace_back(
+        absl::Substitute("ROBOT_WIDTH = $0", ccode(*robot_width_)));
+    result_py->emplace_back(absl::Substitute("CASTER = $0", ccode(*caster_)));
+    result_py->emplace_back("STATE_THETAS0 = 0");
+    result_py->emplace_back("STATE_THETAD0 = 1");
+    result_py->emplace_back("STATE_OMEGAS0 = 2");
+    result_py->emplace_back("STATE_OMEGAD0 = 3");
+    result_py->emplace_back("STATE_THETAS1 = 4");
+    result_py->emplace_back("STATE_THETAD1 = 5");
+    result_py->emplace_back("STATE_OMEGAS1 = 6");
+    result_py->emplace_back("STATE_OMEGAD1 = 7");
+    result_py->emplace_back("STATE_THETAS2 = 8");
+    result_py->emplace_back("STATE_THETAD2 = 9");
+    result_py->emplace_back("STATE_OMEGAS2 = 10");
+    result_py->emplace_back("STATE_OMEGAD2 = 11");
+    result_py->emplace_back("STATE_THETAS3 = 12");
+    result_py->emplace_back("STATE_THETAD3 = 13");
+    result_py->emplace_back("STATE_OMEGAS3 = 14");
+    result_py->emplace_back("STATE_OMEGAD3 = 15");
+    result_py->emplace_back("STATE_X = 16");
+    result_py->emplace_back("STATE_Y = 17");
+    result_py->emplace_back("STATE_THETA = 18");
+    result_py->emplace_back("STATE_VX = 19");
+    result_py->emplace_back("STATE_VY = 20");
+    result_py->emplace_back("STATE_OMEGA = 21");
+    result_py->emplace_back("STATE_FX = 22");
+    result_py->emplace_back("STATE_FY = 23");
+    result_py->emplace_back("STATE_MOMENT = 24");
+    result_py->emplace_back("NUM_STATES = 25");
+    result_py->emplace_back("");
+    result_py->emplace_back("VELOCITY_STATE_THETAS0 = 0");
+    result_py->emplace_back("VELOCITY_STATE_OMEGAS0 = 1");
+    result_py->emplace_back("VELOCITY_STATE_THETAS1 = 2");
+    result_py->emplace_back("VELOCITY_STATE_OMEGAS1 = 3");
+    result_py->emplace_back("VELOCITY_STATE_THETAS2 = 4");
+    result_py->emplace_back("VELOCITY_STATE_OMEGAS2 = 5");
+    result_py->emplace_back("VELOCITY_STATE_THETAS3 = 6");
+    result_py->emplace_back("VELOCITY_STATE_OMEGAS3 = 7");
+    result_py->emplace_back("VELOCITY_STATE_THETA = 8");
+    result_py->emplace_back("VELOCITY_STATE_VX = 9");
+    result_py->emplace_back("VELOCITY_STATE_VY = 10");
+    result_py->emplace_back("VELOCITY_STATE_OMEGA = 11");
+    // result_py->emplace_back("VELOCITY_STATE_FX = 16");
+    // result_py->emplace_back("VELOCITY_STATE_FY = 17");
+    // result_py->emplace_back("VELOCITY_STATE_MOMENT = 18");
+    result_py->emplace_back("NUM_VELOCITY_STATES = 12");
+    result_py->emplace_back("");
+    result_py->emplace_back("");
+    result_py->emplace_back("# Is = STEER_CURRENT_COUPLING_FACTOR * Id");
+    result_py->emplace_back(absl::Substitute(
+        "STEER_CURRENT_COUPLING_FACTOR = $0",
+        ccode(*(neg(
+            mul(div(Gs_, Kts_),
+                mul(div(Ktd_, mul(Gd_, rw_)),
+                    neg(mul(add(neg(wb_), mul(add(rs_, rp_),
+                                              sub(integer(1), div(rb1_, rp_)))),
+                            div(rw_, rb2_))))))))));
+    result_py->emplace_back("");
+  }
+
   // Writes the physics out to the provided .cc and .h path.
   void WriteCasadi(std::string_view py_path) {
     std::vector<std::string> result_py;
@@ -634,54 +710,9 @@
     result_py.emplace_back("");
     result_py.emplace_back("import casadi, numpy");
     result_py.emplace_back("");
-    result_py.emplace_back(absl::Substitute("WHEEL_RADIUS = $0", ccode(*rw_)));
-    result_py.emplace_back(
-        absl::Substitute("ROBOT_WIDTH = $0", ccode(*robot_width_)));
-    result_py.emplace_back(absl::Substitute("CASTER = $0", ccode(*caster_)));
-    result_py.emplace_back("STATE_THETAS0 = 0");
-    result_py.emplace_back("STATE_THETAD0 = 1");
-    result_py.emplace_back("STATE_OMEGAS0 = 2");
-    result_py.emplace_back("STATE_OMEGAD0 = 3");
-    result_py.emplace_back("STATE_THETAS1 = 4");
-    result_py.emplace_back("STATE_THETAD1 = 5");
-    result_py.emplace_back("STATE_OMEGAS1 = 6");
-    result_py.emplace_back("STATE_OMEGAD1 = 7");
-    result_py.emplace_back("STATE_THETAS2 = 8");
-    result_py.emplace_back("STATE_THETAD2 = 9");
-    result_py.emplace_back("STATE_OMEGAS2 = 10");
-    result_py.emplace_back("STATE_OMEGAD2 = 11");
-    result_py.emplace_back("STATE_THETAS3 = 12");
-    result_py.emplace_back("STATE_THETAD3 = 13");
-    result_py.emplace_back("STATE_OMEGAS3 = 14");
-    result_py.emplace_back("STATE_OMEGAD3 = 15");
-    result_py.emplace_back("STATE_X = 16");
-    result_py.emplace_back("STATE_Y = 17");
-    result_py.emplace_back("STATE_THETA = 18");
-    result_py.emplace_back("STATE_VX = 19");
-    result_py.emplace_back("STATE_VY = 20");
-    result_py.emplace_back("STATE_OMEGA = 21");
-    result_py.emplace_back("STATE_FX = 22");
-    result_py.emplace_back("STATE_FY = 23");
-    result_py.emplace_back("STATE_MOMENT = 24");
-    result_py.emplace_back("NUM_STATES = 25");
-    result_py.emplace_back("");
-    result_py.emplace_back("VELOCITY_STATE_THETAS0 = 0");
-    result_py.emplace_back("VELOCITY_STATE_OMEGAS0 = 1");
-    result_py.emplace_back("VELOCITY_STATE_THETAS1 = 2");
-    result_py.emplace_back("VELOCITY_STATE_OMEGAS1 = 3");
-    result_py.emplace_back("VELOCITY_STATE_THETAS2 = 4");
-    result_py.emplace_back("VELOCITY_STATE_OMEGAS2 = 5");
-    result_py.emplace_back("VELOCITY_STATE_THETAS3 = 6");
-    result_py.emplace_back("VELOCITY_STATE_OMEGAS3 = 7");
-    result_py.emplace_back("VELOCITY_STATE_THETA = 8");
-    result_py.emplace_back("VELOCITY_STATE_VX = 9");
-    result_py.emplace_back("VELOCITY_STATE_VY = 10");
-    result_py.emplace_back("VELOCITY_STATE_OMEGA = 11");
-    // result_py.emplace_back("VELOCITY_STATE_FX = 16");
-    // result_py.emplace_back("VELOCITY_STATE_FY = 17");
-    // result_py.emplace_back("VELOCITY_STATE_MOMENT = 18");
-    result_py.emplace_back("NUM_VELOCITY_STATES = 12");
-    result_py.emplace_back("");
+
+    WriteConstants(&result_py);
+
     result_py.emplace_back("def to_velocity_state(X):");
     result_py.emplace_back("    return numpy.array([");
     result_py.emplace_back("        [X[STATE_THETAS0, 0]],");
@@ -732,27 +763,6 @@
                            "casadi.exp($1.0 * x))) * $0.0]))) / $0.0)",
                            kLogGain, kAbsGain));
     }
-    result_py.emplace_back("");
-    result_py.emplace_back("# Is = STEER_CURRENT_COUPLING_FACTOR * Id");
-    result_py.emplace_back(absl::Substitute(
-        "STEER_CURRENT_COUPLING_FACTOR = $0",
-        ccode(*(neg(
-            mul(div(Gs_, Kts_),
-                mul(div(Ktd_, mul(Gd_, rw_)),
-                    neg(mul(add(neg(wb_), mul(add(rs_, rp_),
-                                              sub(integer(1), div(rb1_, rp_)))),
-                            div(rw_, rb2_))))))))));
-    result_py.emplace_back("");
-    result_py.emplace_back("# Is = STEER_CURRENT_COUPLING_FACTOR * Id");
-    result_py.emplace_back(absl::Substitute(
-        "STEER_CURRENT_COUPLING_FACTOR = $0",
-        ccode(*(neg(
-            mul(div(Gs_, Kts_),
-                mul(div(Ktd_, mul(Gd_, rw_)),
-                    neg(mul(add(neg(wb_), mul(add(rs_, rp_),
-                                              sub(integer(1), div(rb1_, rp_)))),
-                            div(rw_, rb2_))))))))));
-    result_py.emplace_back("");
 
     result_py.emplace_back("# Returns the derivative of our state vector");
     result_py.emplace_back("# [thetas0, thetad0, omegas0, omegad0,");
@@ -1284,5 +1294,9 @@
     sim.WriteCasadi(absl::GetFlag(FLAGS_casadi_py_output_path));
   }
 
+  if (!absl::GetFlag(FLAGS_constants_output_path).empty()) {
+    sim.WriteConstantsFile(absl::GetFlag(FLAGS_constants_output_path));
+  }
+
   return 0;
 }
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index dfadbda..6e1305d 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -4,8 +4,8 @@
 from collections import namedtuple
 import jax
 
-from frc971.control_loops.swerve import dynamics
 from frc971.control_loops.python.control_loop import KrakenFOC
+from frc971.control_loops.swerve.dynamics_constants import *
 
 # Note: this physics needs to match the symengine code.  We have tests that
 # confirm they match in all the cases we care about.
@@ -130,36 +130,35 @@
             jax.numpy.array([1.0, softabs_x * kMaxLogGain])) / kMaxLogGain)
 
 
-def full_module_physics(coefficients: dict, Rtheta, module_index: int,
-                        mounting_location, X, U):
+def full_module_physics(coefficients: CoefficientsType, Rtheta,
+                        module_index: int, mounting_location, X, U):
     X_module = X[module_index * 4:(module_index + 1) * 4]
     Is = U[2 * module_index + 0]
     Id = U[2 * module_index + 1]
 
-    Rthetaplusthetas = R(X[dynamics.STATE_THETA] +
-                         X_module[dynamics.STATE_THETAS0])
+    Rthetaplusthetas = R(X[STATE_THETA] + X_module[STATE_THETAS0])
 
     caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
 
-    robot_velocity = X[dynamics.STATE_VX:dynamics.STATE_VY + 1]
+    robot_velocity = X[STATE_VX:STATE_VY + 1]
 
     contact_patch_velocity = (
-        angle_cross(Rtheta @ mounting_location, X[dynamics.STATE_OMEGA]) +
-        robot_velocity + angle_cross(
-            Rthetaplusthetas @ caster_vector,
-            (X[dynamics.STATE_OMEGA] + X_module[dynamics.STATE_OMEGAS0])))
+        angle_cross(Rtheta @ mounting_location, X[STATE_OMEGA]) +
+        robot_velocity +
+        angle_cross(Rthetaplusthetas @ caster_vector,
+                    (X[STATE_OMEGA] + X_module[STATE_OMEGAS0])))
 
     wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
 
     wheel_velocity = jax.numpy.array(
-        [coefficients.rw * X_module[dynamics.STATE_OMEGAD0], 0.0])
+        [coefficients.rw * X_module[STATE_OMEGAD0], 0.0])
 
     wheel_slip_velocity = wheel_velocity - wheel_ground_velocity
 
     slip_angle = jax.numpy.sin(
         -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
 
-    slip_ratio = (coefficients.rw * X_module[dynamics.STATE_OMEGAD0] -
+    slip_ratio = (coefficients.rw * X_module[STATE_OMEGAD0] -
                   wheel_ground_velocity[0]) / jax.numpy.max(
                       jax.numpy.array(
                           [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))
@@ -193,8 +192,8 @@
 
     X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
         (4, )), ) * (module_index) + (jax.numpy.array([
-            X_module[dynamics.STATE_OMEGAS0],
-            X_module[dynamics.STATE_OMEGAD0],
+            X_module[STATE_OMEGAS0],
+            X_module[STATE_OMEGAD0],
             alphas,
             alphad,
         ]), ) + (jax.numpy.zeros((4, )), ) * (3 - module_index) + (
@@ -208,7 +207,7 @@
 
 @partial(jax.jit, static_argnames=['coefficients'])
 def full_dynamics(coefficients: CoefficientsType, X, U):
-    Rtheta = R(X[dynamics.STATE_THETA])
+    Rtheta = R(X[STATE_THETA])
 
     module0 = full_module_physics(
         coefficients, Rtheta, 0,
@@ -233,37 +232,41 @@
 
     X_dot = module0 + module1 + module2 + module3
 
-    X_dot = X_dot.at[dynamics.STATE_X:dynamics.STATE_THETA + 1].set(
+    X_dot = X_dot.at[STATE_X:STATE_THETA + 1].set(
         jax.numpy.array([
-            X[dynamics.STATE_VX],
-            X[dynamics.STATE_VY],
-            X[dynamics.STATE_OMEGA],
+            X[STATE_VX],
+            X[STATE_VY],
+            X[STATE_OMEGA],
         ]))
 
     return X_dot
 
 
-def velocity_module_physics(coefficients: dict, Rtheta, module_index: int,
-                            mounting_location, X, U):
+def velocity_module_physics(coefficients: CoefficientsType,
+                            Rtheta: jax.typing.ArrayLike, module_index: int,
+                            mounting_location: jax.typing.ArrayLike,
+                            X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
     X_module = X[module_index * 2:(module_index + 1) * 2]
     Is = U[2 * module_index + 0]
     Id = U[2 * module_index + 1]
 
-    Rthetaplusthetas = R(X[dynamics.VELOCITY_STATE_THETA] +
-                         X_module[dynamics.VELOCITY_STATE_THETAS0])
+    rotated_mounting_location = Rtheta @ mounting_location
+
+    Rthetaplusthetas = R(X[VELOCITY_STATE_THETA] +
+                         X_module[VELOCITY_STATE_THETAS0])
 
     caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
 
-    robot_velocity = X[dynamics.VELOCITY_STATE_VX:dynamics.VELOCITY_STATE_VY +
-                       1]
+    robot_velocity = X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]
 
     contact_patch_velocity = (
-        angle_cross(Rtheta @ mounting_location,
-                    X[dynamics.VELOCITY_STATE_OMEGA]) + robot_velocity +
-        angle_cross(Rthetaplusthetas @ caster_vector,
-                    (X[dynamics.VELOCITY_STATE_OMEGA] +
-                     X_module[dynamics.VELOCITY_STATE_OMEGAS0])))
+        angle_cross(rotated_mounting_location, X[VELOCITY_STATE_OMEGA]) +
+        robot_velocity + angle_cross(
+            Rthetaplusthetas @ caster_vector,
+            (X[VELOCITY_STATE_OMEGA] + X_module[VELOCITY_STATE_OMEGAS0])))
 
+    # Velocity of the contact patch over the field projected into the direction
+    # of the wheel.
     wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
 
     slip_angle = jax.numpy.sin(
@@ -288,11 +291,11 @@
 
     F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
 
-    torque = force_cross(Rtheta @ mounting_location, F)
+    torque = force_cross(rotated_mounting_location, F)
 
     X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
         (2, )), ) * (module_index) + (jax.numpy.array([
-            X_module[dynamics.VELOCITY_STATE_OMEGAS0],
+            X_module[VELOCITY_STATE_OMEGAS0],
             alphas,
         ]), ) + (jax.numpy.zeros((2, )), ) * (3 - module_index) + (
             jax.numpy.zeros((1, )),
@@ -300,29 +303,30 @@
             jax.numpy.array([torque / coefficients.J]),
         ))
 
-    return X_dot_contribution
+    return X_dot_contribution, F, torque
 
 
 @partial(jax.jit, static_argnames=['coefficients'])
-def velocity_dynamics(coefficients: CoefficientsType, X, U):
-    Rtheta = R(X[dynamics.VELOCITY_STATE_THETA])
+def velocity_dynamics(coefficients: CoefficientsType, X: jax.typing.ArrayLike,
+                      U: jax.typing.ArrayLike):
+    Rtheta = R(X[VELOCITY_STATE_THETA])
 
-    module0 = velocity_module_physics(
+    module0, _, _ = velocity_module_physics(
         coefficients, Rtheta, 0,
         jax.numpy.array(
             [coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
         X, U)
-    module1 = velocity_module_physics(
+    module1, _, _ = velocity_module_physics(
         coefficients, Rtheta, 1,
         jax.numpy.array(
             [-coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
         X, U)
-    module2 = velocity_module_physics(
+    module2, _, _ = velocity_module_physics(
         coefficients, Rtheta, 2,
         jax.numpy.array(
             [-coefficients.robot_width / 2.0,
              -coefficients.robot_width / 2.0]), X, U)
-    module3 = velocity_module_physics(
+    module3, _, _ = velocity_module_physics(
         coefficients, Rtheta, 3,
         jax.numpy.array(
             [coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0]),
@@ -330,5 +334,21 @@
 
     X_dot = module0 + module1 + module2 + module3
 
-    return X_dot.at[dynamics.VELOCITY_STATE_THETA].set(
-        X[dynamics.VELOCITY_STATE_OMEGA])
+    return X_dot.at[VELOCITY_STATE_THETA].set(X[VELOCITY_STATE_OMEGA])
+
+
+def to_velocity_state(X):
+    return jax.numpy.array([
+        X[STATE_THETAS0],
+        X[STATE_OMEGAS0],
+        X[STATE_THETAS1],
+        X[STATE_OMEGAS1],
+        X[STATE_THETAS2],
+        X[STATE_OMEGAS2],
+        X[STATE_THETAS3],
+        X[STATE_OMEGAS3],
+        X[STATE_THETA],
+        X[STATE_VX],
+        X[STATE_VY],
+        X[STATE_OMEGA],
+    ])