Add JAX version of the physics, and tests to confirm it matches

This lets us compute the dynamics very efficiently and JIT-compile them.
Also fix generate_physics.cc, which assigned the "Gd3" symbol to Gd2_
instead of Gd3_.

Change-Id: I57aea5c1f480759c8e5e658ff6f4de0d82ef273d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index c7355b5..26828e4 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -197,14 +197,51 @@
     ],
 )
 
+py_library(
+    name = "jax_dynamics",
+    srcs = [
+        "jax_dynamics.py",
+    ],
+    deps = [
+        ":dynamics",  # dynamics.STATE_* / VELOCITY_STATE_* indices
+        "//frc971/control_loops/python:controls",  # control_loop.KrakenFOC
+        "@pip//jax",
+    ],
+)
+
 py_test(
-    name = "physics_test",
+    name = "physics_test_cpu",
     srcs = [
         "physics_test.py",
     ],
+    env = {
+        "JAX_PLATFORMS": "cpu",  # run the JAX comparisons on the CPU backend
+    },
+    main = "physics_test.py",  # name no longer matches the source file
     target_compatible_with = ["@platforms//cpu:x86_64"],
     deps = [
         ":dynamics",
+        ":jax_dynamics",
+        ":physics_test_utils",
+        "@pip//casadi",
+        "@pip//numpy",
+        "@pip//scipy",
+    ],
+)
+
+py_test(
+    name = "physics_test_gpu",
+    srcs = [
+        "physics_test.py",
+    ],
+    env = {
+        "JAX_PLATFORMS": "cuda",  # NOTE(review): needs a CUDA host; tag manual?
+    },
+    main = "physics_test.py",  # name no longer matches the source file
+    target_compatible_with = ["@platforms//cpu:x86_64"],
+    deps = [
+        ":dynamics",
+        ":jax_dynamics",
         ":physics_test_utils",
         "@pip//casadi",
         "@pip//numpy",
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index 3c4eaf6..e796ec6 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -142,7 +142,7 @@
       rb1_ = symbol("rb1");
       rb2_ = symbol("rb2");
 
-      Gd2_ = symbol("Gd3");
+      Gd3_ = symbol("Gd3");
       Gd_ = symbol("Gd");
 
       Js_ = symbol("Js");
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
new file mode 100644
index 0000000..8b05ef4
--- /dev/null
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+
+from functools import partial
+from collections import namedtuple
+import jax
+
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.python.control_loop import KrakenFOC
+
+# Note: this physics needs to match the symengine code.  We have tests that
+# confirm they match in all the cases we care about.
+
+CoefficientsType = namedtuple('CoefficientsType', [  # see Coefficients()
+    'Cx',  # Fwx = Cx * slip_ratio
+    'Cy',  # Fwy = Cy * slip_angle
+    'rw',  # rw * omegad is compared against the wheel ground velocity
+    'm',  # robot acceleration contribution is F / m
+    'J',  # robot angular acceleration contribution is torque / J
+    'Gd1',
+    'rs',
+    'rp',
+    'Gd2',  # rs / rp
+    'rb1',
+    'rb2',
+    'Gd3',  # rb1 / rb2
+    'Gd',  # Gd1 * Gd2 * Gd3
+    'Js',
+    'Gs',
+    'wb',
+    'Jdm',  # drive_motor.motor_inertia
+    'Jsm',  # steer_motor.motor_inertia
+    'Kts',  # steer_motor.Kt
+    'Ktd',  # drive_motor.Kt
+    'robot_width',  # modules at (+/-robot_width/2, +/-robot_width/2)
+    'caster',  # contact patch offset: caster_vector = [-caster, 0]
+    'contact_patch_length',
+])
+
+
+def Coefficients(
+    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
+    Cy: float = 5 * 9.8 / 0.05 / 4.0,
+    rw: float = 2 * 0.0254,
+
+    # base is 20 kg without battery
+    m: float = 25.0,
+    J: float = 6.0,
+    Gd1: float = 12.0 / 42.0,
+    rs: float = 28.0 / 20.0 / 2.0,
+    rp: float = 18.0 / 20.0 / 2.0,
+
+    # 15 / 45 bevel ratio, calculated using python script ported over to
+    # GetBevelPitchRadius(double)
+    # TODO(Justin): Use the function instead of computed constants
+    rb1: float = 0.3805473,
+    rb2: float = 1.14164,
+    Js: float = 0.001,
+    Gs: float = 35.0 / 468.0,
+    wb: float = 0.725,
+    drive_motor=KrakenFOC(),
+    steer_motor=KrakenFOC(),
+    robot_width: float = 24.75 * 0.0254,
+    caster: float = 0.01,
+    contact_patch_length: float = 0.02,
+) -> CoefficientsType:
+    """Returns CoefficientsType, deriving Gd2, Gd3, Gd and motor constants."""
+    Gd2 = rs / rp
+    Gd3 = rb1 / rb2
+    Gd = Gd1 * Gd2 * Gd3
+
+    Jdm = drive_motor.motor_inertia
+    Jsm = steer_motor.motor_inertia
+    Kts = steer_motor.Kt
+    Ktd = drive_motor.Kt
+
+    return CoefficientsType(
+        Cx=Cx,
+        Cy=Cy,
+        rw=rw,
+        m=m,
+        J=J,
+        Gd1=Gd1,
+        rs=rs,
+        rp=rp,
+        Gd2=Gd2,
+        rb1=rb1,
+        rb2=rb2,
+        Gd3=Gd3,
+        Gd=Gd,
+        Js=Js,
+        Gs=Gs,
+        wb=wb,
+        Jdm=Jdm,
+        Jsm=Jsm,
+        Kts=Kts,
+        Ktd=Ktd,
+        robot_width=robot_width,
+        caster=caster,
+        contact_patch_length=contact_patch_length,
+    )
+
+
+def R(theta):
+    cos_theta, sin_theta = jax.numpy.cos(theta), jax.numpy.sin(theta)
+    # Rotates 2D column vectors counter-clockwise by theta.
+    return jax.numpy.array([[cos_theta, -sin_theta], [sin_theta, cos_theta]])
+
+
+def angle_cross(vector, omega):  # (omega about z) cross vector, in 2D.
+    return jax.numpy.array([-omega * vector[1], omega * vector[0]])
+
+
+def force_cross(r, f):  # z component of the 2D cross product r x f.
+    return f[1] * r[0] - f[0] * r[1]
+
+
+def softsign(x, gain):  # Smooth sign(x) in (-1, 1); sharper as gain grows.
+    return 1 - 2 / (1 + jax.numpy.exp(gain * x))
+
+
+def soft_atan2(y, x):
+    """Smooth stand-in for atan2(y, max(0.05, abs(x)))."""
+    max_log_gain = 1.0 / 0.05
+    abs_log_gain = 1.0 / 0.01
+    # x * softsign(x) is a differentiable approximation of abs(x).
+    soft_abs_x = x * softsign(x, abs_log_gain)
+    # logsumexp([1, g * a]) / g smoothly approximates max(1 / g, a).
+    soft_max_x = jax.scipy.special.logsumexp(
+        jax.numpy.array([1.0, soft_abs_x * max_log_gain])) / max_log_gain
+    return jax.numpy.arctan2(y, soft_max_x)
+
+
+def full_module_physics(coefficients: CoefficientsType, Rtheta,
+                        module_index: int, mounting_location, X, U):
+    """Returns module module_index's contribution to the full-state X_dot."""
+    X_module = X[module_index * 4:(module_index + 1) * 4]
+    Is = U[2 * module_index + 0]
+    Id = U[2 * module_index + 1]
+    Rthetaplusthetas = R(X[dynamics.STATE_THETA] +
+                         X_module[dynamics.STATE_THETAS0])
+
+    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
+
+    robot_velocity = X[dynamics.STATE_VX:dynamics.STATE_VY + 1]
+    # Patch velocity: v + omega x mount + (omega + omegas) x caster.
+    contact_patch_velocity = (
+        angle_cross(Rtheta @ mounting_location, X[dynamics.STATE_OMEGA]) +
+        robot_velocity + angle_cross(
+            Rthetaplusthetas @ caster_vector,
+            (X[dynamics.STATE_OMEGA] + X_module[dynamics.STATE_OMEGAS0])))
+
+    # Contact patch velocity rotated into the wheel's frame.
+    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
+
+    slip_angle = jax.numpy.sin(
+        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
+
+    # Slip ratio; the denominator is clamped to at least 0.02 to avoid / 0.
+    slip_ratio = (coefficients.rw * X_module[dynamics.STATE_OMEGAD0] -
+                  wheel_ground_velocity[0]) / jax.numpy.maximum(
+                      0.02, jax.numpy.abs(wheel_ground_velocity[0]))
+
+    Fwx = coefficients.Cx * slip_ratio
+    Fwy = coefficients.Cy * slip_angle
+
+    softsign_velocity = softsign(wheel_ground_velocity[0], 100)
+
+    Ms = -Fwy * (
+        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
+        coefficients.caster)
+
+    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
+              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
+               (1 - coefficients.rb1 / coefficients.rp)) *
+              (coefficients.rw / coefficients.rb2 *
+               (-Fwx))) / (coefficients.Jsm +
+                           (coefficients.Js /
+                            (coefficients.Gs * coefficients.Gs)))
+
+    # Then solve for alphad
+    alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
+              coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
+              coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
+              coefficients.Gd) / (coefficients.rp * coefficients.Jdm)
+
+    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
+
+    # NOTE(review): F is rotated by theta + thetas, but mounting_location is
+    # not rotated by theta here; presumably mirrors symengine. Confirm frames.
+    torque = force_cross(mounting_location, F)
+
+    # Layout: zeros for earlier modules, this module's 4 derivatives, zeros
+    # for later modules, zeros(3), F / m, then [torque / J, 0, 0, 0].
+    X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
+        (4, )), ) * (module_index) + (jax.numpy.array([
+            X_module[dynamics.STATE_OMEGAS0],
+            X_module[dynamics.STATE_OMEGAD0],
+            alphas,
+            alphad,
+        ]), ) + (jax.numpy.zeros((4, )), ) * (3 - module_index) + (
+            jax.numpy.zeros((3, )),
+            F / coefficients.m,
+            jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
+        ))
+
+    return X_dot_contribution
+
+
+@partial(jax.jit, static_argnames=['coefficients'])
+def full_dynamics(coefficients: CoefficientsType, X, U):
+    """Returns dX/dt for the full swerve state X under the input U.
+
+    Each module contributes its own state derivatives plus its force and
+    torque on the robot; the four contributions are summed, and then the
+    x, y and theta rates are overwritten with vx, vy and omega.
+    """
+    half_width = coefficients.robot_width / 2.0
+    Rtheta = R(X[dynamics.STATE_THETA])
+
+    # Mounting locations of modules 0 through 3 relative to the robot center.
+    location0 = jax.numpy.array([half_width, half_width])
+    location1 = jax.numpy.array([-half_width, half_width])
+    location2 = jax.numpy.array([-half_width, -half_width])
+    location3 = jax.numpy.array([half_width, -half_width])
+
+    contribution0 = full_module_physics(coefficients, Rtheta, 0, location0,
+                                        X, U)
+    contribution1 = full_module_physics(coefficients, Rtheta, 1, location1,
+                                        X, U)
+    contribution2 = full_module_physics(coefficients, Rtheta, 2, location2,
+                                        X, U)
+    contribution3 = full_module_physics(coefficients, Rtheta, 3, location3,
+                                        X, U)
+
+    X_dot = contribution0 + contribution1 + contribution2 + contribution3
+
+    # The x, y and theta rates are just the robot velocity states.
+    position_rates = jax.numpy.array(
+        [X[dynamics.STATE_VX], X[dynamics.STATE_VY], X[dynamics.STATE_OMEGA]])
+    X_dot = X_dot.at[dynamics.STATE_X:dynamics.STATE_THETA + 1].set(
+        position_rates)
+
+    return X_dot
+
+
+def velocity_module_physics(coefficients: CoefficientsType, Rtheta,
+                            module_index: int, mounting_location, X, U):
+    """Returns module module_index's contribution to the velocity X_dot."""
+    X_module = X[module_index * 2:(module_index + 1) * 2]
+    Is = U[2 * module_index + 0]
+    Id = U[2 * module_index + 1]
+    Rthetaplusthetas = R(X[dynamics.VELOCITY_STATE_THETA] +
+                         X_module[dynamics.VELOCITY_STATE_THETAS0])
+
+    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
+
+    robot_velocity = X[dynamics.VELOCITY_STATE_VX:dynamics.VELOCITY_STATE_VY +
+                       1]
+    # Patch velocity: v + omega x mount + (omega + omegas) x caster.
+    contact_patch_velocity = (
+        angle_cross(Rtheta @ mounting_location,
+                    X[dynamics.VELOCITY_STATE_OMEGA]) + robot_velocity +
+        angle_cross(Rthetaplusthetas @ caster_vector,
+                    (X[dynamics.VELOCITY_STATE_OMEGA] +
+                     X_module[dynamics.VELOCITY_STATE_OMEGAS0])))
+    # Contact patch velocity rotated into the wheel's frame.
+    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
+
+    slip_angle = jax.numpy.sin(
+        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
+    # Unlike full_module_physics, Fwx comes straight from the drive input Id.
+    Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
+    Fwy = coefficients.Cy * slip_angle
+
+    softsign_velocity = softsign(wheel_ground_velocity[0], 100)
+
+    Ms = -Fwy * (
+        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
+        coefficients.caster)
+
+    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
+              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
+               (1 - coefficients.rb1 / coefficients.rp)) *
+              (coefficients.rw / coefficients.rb2 *
+               (-Fwx))) / (coefficients.Jsm +
+                           (coefficients.Js /
+                            (coefficients.Gs * coefficients.Gs)))
+
+    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
+    # NOTE(review): F is rotated by theta + thetas, but mounting_location is
+    # not rotated by theta here; presumably mirrors symengine. Confirm frames.
+    torque = force_cross(mounting_location, F)
+
+    X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
+        (2, )), ) * (module_index) + (jax.numpy.array([
+            X_module[dynamics.VELOCITY_STATE_OMEGAS0],
+            alphas,
+        ]), ) + (jax.numpy.zeros((2, )), ) * (3 - module_index) + (
+            jax.numpy.zeros((1, )),
+            F / coefficients.m,
+            jax.numpy.array([torque / coefficients.J]),
+        ))
+    return X_dot_contribution
+
+
+@partial(jax.jit, static_argnames=['coefficients'])
+def velocity_dynamics(coefficients: CoefficientsType, X, U):
+    """Returns dX/dt for the velocity-model swerve state X under input U.
+
+    Sums the four module contributions, then sets the theta rate to omega.
+    """
+    half_width = coefficients.robot_width / 2.0
+    Rtheta = R(X[dynamics.VELOCITY_STATE_THETA])
+
+    # Mounting locations of modules 0 through 3 relative to the robot center.
+    location0 = jax.numpy.array([half_width, half_width])
+    location1 = jax.numpy.array([-half_width, half_width])
+    location2 = jax.numpy.array([-half_width, -half_width])
+    location3 = jax.numpy.array([half_width, -half_width])
+
+    contribution0 = velocity_module_physics(coefficients, Rtheta, 0,
+                                            location0, X, U)
+    contribution1 = velocity_module_physics(coefficients, Rtheta, 1,
+                                            location1, X, U)
+    contribution2 = velocity_module_physics(coefficients, Rtheta, 2,
+                                            location2, X, U)
+    contribution3 = velocity_module_physics(coefficients, Rtheta, 3,
+                                            location3, X, U)
+
+    X_dot = contribution0 + contribution1 + contribution2 + contribution3
+
+    # theta's rate is the robot angular velocity omega.
+    return X_dot.at[dynamics.VELOCITY_STATE_THETA].set(
+        X[dynamics.VELOCITY_STATE_OMEGA])
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 0a3fdb6..af3dc17 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -1,6 +1,13 @@
-#!/usr/bin/python3
+#!/usr/bin/env python3
+import jax
+
+# casadi uses doubles.  jax likes floats.  We want to make sure the physics
+# matches really precisely, so force doubles for the tests.
+jax.config.update("jax_enable_x64", True)
 
 import numpy
+
+numpy.set_printoptions(precision=20)
 import sys, os
 import casadi
 import scipy
@@ -11,10 +18,12 @@
 from frc971.control_loops.swerve import dynamics
 from frc971.control_loops.swerve import nocaster_dynamics
 from frc971.control_loops.swerve import physics_test_utils as utils
+from frc971.control_loops.swerve import jax_dynamics
 
 
 class TestSwervePhysics(unittest.TestCase):
     I = numpy.zeros((8, 1))
+    coefficients = jax_dynamics.Coefficients()
 
     def to_velocity_state(self, X):
         return dynamics.to_velocity_state(X)
@@ -22,19 +31,74 @@
     def swerve_full_dynamics(self, X, U, skip_compare=False):
         X_velocity = self.to_velocity_state(X)
         Xdot = self.position_swerve_full_dynamics(X, U)
+        # Unless skipped, compare the full physics to the velocity physics.
         if not skip_compare:
-            velocity_states = self.to_velocity_state(Xdot)
+            Xdot_velocity = self.to_velocity_state(Xdot)
             velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+
             self.assertLess(
-                numpy.linalg.norm(velocity_states - velocity_physics),
+                numpy.linalg.norm(Xdot_velocity - velocity_physics),
                 2e-2,
                 msg=
-                f'Norm failed, full physics -> {velocity_states.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - velocity_states}',
+                f'Norm failed, full physics -> {Xdot_velocity.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - Xdot_velocity}',
             )
 
+        self.validate_dynamics_equality(X, U)
+
         return Xdot
 
+    def validate_dynamics_equality(self, X, U):
+        """Asserts the JAX and casadi full/velocity dynamics agree within 2e-8.
+
+        Note:
+          If the symengine code has been updated, you likely need to update the JAX
+          by hand.  We had trouble code generating it with good performance.
+        """
+        X_velocity = self.to_velocity_state(X)
+
+        Xdot = self.position_swerve_full_dynamics(X, U)
+        Xdot_jax = jax_dynamics.full_dynamics(self.coefficients, X[:, 0],
+                                              U[:, 0])
+
+        self.assertLess(
+            numpy.linalg.norm(Xdot[:, 0] - Xdot_jax),
+            2e-8,
+            msg=
+            f'Xdot: {Xdot[:, 0]}, Xdot_jax: {Xdot_jax}, diff: {(Xdot[:, 0] - Xdot_jax)}',
+        )
+
+        velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+        velocity_physics_jax = jax_dynamics.velocity_dynamics(
+            self.coefficients, X_velocity[:, 0], U[:, 0])
+
+        self.assertLess(
+            numpy.linalg.norm(velocity_physics[:, 0] - velocity_physics_jax),
+            2e-8,
+            msg=
+            f'velocity_physics: {velocity_physics[:, 0]}, velocity_physics_jax: {velocity_physics_jax}, diff: {(velocity_physics[:, 0] - velocity_physics_jax)}',
+        )
+
+    def wrap_and_validate(self, function, i):
+        """Returns utils.wrap_module(function, i), checking JAX on every call.
+
+        Not every corner-case test goes through swerve_full_dynamics, so each
+        per-module intermediate also compares the JAX and casadi physics.
+        """
+        module_fn = utils.wrap_module(function, i)
+
+        def validated(X, U):
+            # Cross-check JAX against casadi before evaluating the module.
+            self.validate_dynamics_equality(X, U)
+            return module_fn(X, U)
+
+        return validated
+
     def wrap(self, python_module):
+        # Only update on change to avoid re-jiting things.
+        if self.coefficients.caster != python_module.CASTER:
+            self.coefficients = self.coefficients._replace(
+                caster=python_module.CASTER)
+
         self.position_swerve_full_dynamics = utils.wrap(
             python_module.swerve_full_dynamics)
 
@@ -45,37 +109,42 @@
             evaluated_fn(X, U))
 
         self.contact_patch_velocity = [
-            utils.wrap_module(python_module.contact_patch_velocity, i)
+            self.wrap_and_validate(python_module.contact_patch_velocity, i)
             for i in range(4)
         ]
         self.wheel_ground_velocity = [
-            utils.wrap_module(python_module.wheel_ground_velocity, i)
+            self.wrap_and_validate(python_module.wheel_ground_velocity, i)
             for i in range(4)
         ]
         self.wheel_slip_velocity = [
-            utils.wrap_module(python_module.wheel_slip_velocity, i)
+            self.wrap_and_validate(python_module.wheel_slip_velocity, i)
             for i in range(4)
         ]
         self.wheel_force = [
-            utils.wrap_module(python_module.wheel_force, i) for i in range(4)
-        ]
-        self.module_angular_accel = [
-            utils.wrap_module(python_module.module_angular_accel, i)
+            self.wrap_and_validate(python_module.wheel_force, i)
             for i in range(4)
         ]
-        self.F = [utils.wrap_module(python_module.F, i) for i in range(4)]
+        self.module_angular_accel = [
+            self.wrap_and_validate(python_module.module_angular_accel, i)
+            for i in range(4)
+        ]
+        self.F = [self.wrap_and_validate(python_module.F, i) for i in range(4)]
         self.mounting_location = [
-            utils.wrap_module(python_module.mounting_location, i)
+            self.wrap_and_validate(python_module.mounting_location, i)
             for i in range(4)
         ]
 
         self.slip_angle = [
-            utils.wrap_module(python_module.slip_angle, i) for i in range(4)
+            self.wrap_and_validate(python_module.slip_angle, i)
+            for i in range(4)
         ]
         self.slip_ratio = [
-            utils.wrap_module(python_module.slip_ratio, i) for i in range(4)
+            self.wrap_and_validate(python_module.slip_ratio, i)
+            for i in range(4)
         ]
-        self.Ms = [utils.wrap_module(python_module.Ms, i) for i in range(4)]
+        self.Ms = [
+            self.wrap_and_validate(python_module.Ms, i) for i in range(4)
+        ]
 
     def setUp(self):
         self.wrap(dynamics)
@@ -192,6 +261,15 @@
                         scipy.special.logsumexp([1.0, abs(vx) * loggain]) /
                         loggain))
 
+                    jax_expected = jax.numpy.sin(
+                        -jax_dynamics.soft_atan2(vy, vx))
+
+                    self.assertAlmostEqual(
+                        expected,
+                        jax_expected,
+                        msg=f"Trying wrap {wrap} theta {theta}",
+                    )
+
                     self.assertAlmostEqual(
                         expected,
                         computed_angle,
diff --git a/frc971/control_loops/swerve/physics_test_utils.py b/frc971/control_loops/swerve/physics_test_utils.py
index 30b352f..99abe22 100644
--- a/frc971/control_loops/swerve/physics_test_utils.py
+++ b/frc971/control_loops/swerve/physics_test_utils.py
@@ -21,35 +21,39 @@
     # All the wheels are spinning at the speed needed to hit the velocity in m/s
     drive_wheel_velocity = drive_wheel_velocity or numpy.linalg.norm(velocity)
 
-    X_initial[2, 0] = module_omega
-    X_initial[3, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+    X_initial[dynamics.STATE_OMEGAS0, 0] = module_omega
+    X_initial[dynamics.STATE_OMEGAD0,
+              0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
 
-    X_initial[6, 0] = module_omega
-    X_initial[7, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+    X_initial[dynamics.STATE_OMEGAS1, 0] = module_omega
+    X_initial[dynamics.STATE_OMEGAD1,
+              0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
 
-    X_initial[10, 0] = module_omega
-    X_initial[11, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+    X_initial[dynamics.STATE_OMEGAS2, 0] = module_omega
+    X_initial[dynamics.STATE_OMEGAD2,
+              0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
 
-    X_initial[14, 0] = module_omega
-    X_initial[15, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+    X_initial[dynamics.STATE_OMEGAS3, 0] = module_omega
+    X_initial[dynamics.STATE_OMEGAD3,
+              0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
 
-    X_initial[0, 0] = module_angle
-    X_initial[4, 0] = module_angle
-    X_initial[8, 0] = module_angle
-    X_initial[12, 0] = module_angle
+    X_initial[dynamics.STATE_THETAS0, 0] = module_angle
+    X_initial[dynamics.STATE_THETAS1, 0] = module_angle
+    X_initial[dynamics.STATE_THETAS2, 0] = module_angle
+    X_initial[dynamics.STATE_THETAS3, 0] = module_angle
 
     if module_angles is not None:
         assert len(module_angles) == 4
-        X_initial[0, 0] = module_angles[0]
-        X_initial[4, 0] = module_angles[1]
-        X_initial[8, 0] = module_angles[2]
-        X_initial[12, 0] = module_angles[3]
+        X_initial[dynamics.STATE_THETAS0, 0] = module_angles[0]
+        X_initial[dynamics.STATE_THETAS1, 0] = module_angles[1]
+        X_initial[dynamics.STATE_THETAS2, 0] = module_angles[2]
+        X_initial[dynamics.STATE_THETAS3, 0] = module_angles[3]
 
-    X_initial[18, 0] = theta
+    X_initial[dynamics.STATE_THETA, 0] = theta
 
-    X_initial[19, 0] = velocity[0, 0] + dx
-    X_initial[20, 0] = velocity[1, 0] + dy
-    X_initial[21, 0] = omega
+    X_initial[dynamics.STATE_VX, 0] = velocity[0, 0] + dx
+    X_initial[dynamics.STATE_VY, 0] = velocity[1, 0] + dy
+    X_initial[dynamics.STATE_OMEGA, 0] = omega
 
     return X_initial