Add JAX version of the physics, and tests to confirm it matches
This sets us up to very efficiently compute dynamics and jit.
Change-Id: I57aea5c1f480759c8e5e658ff6f4de0d82ef273d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index c7355b5..26828e4 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -197,14 +197,51 @@
],
)
+py_library(
+ name = "jax_dynamics",
+ srcs = [
+ "jax_dynamics.py",
+ ],
+ deps = [
+ ":dynamics",
+ "//frc971/control_loops/python:controls",
+ "@pip//jax",
+ ],
+)
+
py_test(
- name = "physics_test",
+ name = "physics_test_cpu",
srcs = [
"physics_test.py",
],
+ env = {
+ "JAX_PLATFORMS": "cpu",
+ },
+ main = "physics_test.py",
target_compatible_with = ["@platforms//cpu:x86_64"],
deps = [
":dynamics",
+ ":jax_dynamics",
+ ":physics_test_utils",
+ "@pip//casadi",
+ "@pip//numpy",
+ "@pip//scipy",
+ ],
+)
+
+py_test(
+ name = "physics_test_gpu",
+ srcs = [
+ "physics_test.py",
+ ],
+ env = {
+ "JAX_PLATFORMS": "cuda",
+ },
+ main = "physics_test.py",
+ target_compatible_with = ["@platforms//cpu:x86_64"],
+ deps = [
+ ":dynamics",
+ ":jax_dynamics",
":physics_test_utils",
"@pip//casadi",
"@pip//numpy",
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index 3c4eaf6..e796ec6 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -142,7 +142,7 @@
rb1_ = symbol("rb1");
rb2_ = symbol("rb2");
- Gd2_ = symbol("Gd3");
+ Gd3_ = symbol("Gd3");
Gd_ = symbol("Gd");
Js_ = symbol("Js");
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
new file mode 100644
index 0000000..8b05ef4
--- /dev/null
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+
+from functools import partial
+from collections import namedtuple
+import jax
+
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.python.control_loop import KrakenFOC
+
+# Note: this physics needs to match the symengine code. We have tests that
+# confirm they match in all the cases we care about.
+
+CoefficientsType = namedtuple('CoefficientsType', [
+ 'Cx',
+ 'Cy',
+ 'rw',
+ 'm',
+ 'J',
+ 'Gd1',
+ 'rs',
+ 'rp',
+ 'Gd2',
+ 'rb1',
+ 'rb2',
+ 'Gd3',
+ 'Gd',
+ 'Js',
+ 'Gs',
+ 'wb',
+ 'Jdm',
+ 'Jsm',
+ 'Kts',
+ 'Ktd',
+ 'robot_width',
+ 'caster',
+ 'contact_patch_length',
+])
+
+
+def Coefficients(
+ Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
+ Cy: float = 5 * 9.8 / 0.05 / 4.0,
+ rw: float = 2 * 0.0254,
+
+ # base is 20 kg without battery
+ m: float = 25.0,
+ J: float = 6.0,
+ Gd1: float = 12.0 / 42.0,
+ rs: float = 28.0 / 20.0 / 2.0,
+ rp: float = 18.0 / 20.0 / 2.0,
+
+ # 15 / 45 bevel ratio, calculated using python script ported over to
+ # GetBevelPitchRadius(double)
+ # TODO(Justin): Use the function instead of computed constantss
+ rb1: float = 0.3805473,
+ rb2: float = 1.14164,
+ Js: float = 0.001,
+ Gs: float = 35.0 / 468.0,
+ wb: float = 0.725,
+ drive_motor=KrakenFOC(),
+ steer_motor=KrakenFOC(),
+ robot_width: float = 24.75 * 0.0254,
+ caster: float = 0.01,
+ contact_patch_length: float = 0.02,
+) -> CoefficientsType:
+
+ Gd2 = rs / rp
+ Gd3 = rb1 / rb2
+ Gd = Gd1 * Gd2 * Gd3
+
+ Jdm = drive_motor.motor_inertia
+ Jsm = steer_motor.motor_inertia
+ Kts = steer_motor.Kt
+ Ktd = drive_motor.Kt
+
+ return CoefficientsType(
+ Cx=Cx,
+ Cy=Cy,
+ rw=rw,
+ m=m,
+ J=J,
+ Gd1=Gd1,
+ rs=rs,
+ rp=rp,
+ Gd2=Gd2,
+ rb1=rb1,
+ rb2=rb2,
+ Gd3=Gd3,
+ Gd=Gd,
+ Js=Js,
+ Gs=Gs,
+ wb=wb,
+ Jdm=Jdm,
+ Jsm=Jsm,
+ Kts=Kts,
+ Ktd=Ktd,
+ robot_width=robot_width,
+ caster=caster,
+ contact_patch_length=contact_patch_length,
+ )
+
+
+def R(theta):
+ stheta = jax.numpy.sin(theta)
+ ctheta = jax.numpy.cos(theta)
+ return jax.numpy.array([[ctheta, -stheta], [stheta, ctheta]])
+
+
+def angle_cross(vector, omega):
+ return jax.numpy.array([-vector[1] * omega, vector[0] * omega])
+
+
+def force_cross(r, f):
+ return r[0] * f[1] - r[1] * f[0]
+
+
+def softsign(x, gain):
+ return -2 / (1 + jax.numpy.exp(gain * x)) + 1
+
+
+def soft_atan2(y, x):
+ kMaxLogGain = 1.0 / 0.05
+ kAbsLogGain = 1.0 / 0.01
+
+ softabs_x = x * (1.0 - 2.0 / (1 + jax.numpy.exp(kAbsLogGain * x)))
+
+ return jax.numpy.arctan2(
+ y,
+ jax.scipy.special.logsumexp(
+ jax.numpy.array([1.0, softabs_x * kMaxLogGain])) / kMaxLogGain)
+
+
+def full_module_physics(coefficients: dict, Rtheta, module_index: int,
+ mounting_location, X, U):
+ X_module = X[module_index * 4:(module_index + 1) * 4]
+ Is = U[2 * module_index + 0]
+ Id = U[2 * module_index + 1]
+
+ Rthetaplusthetas = R(X[dynamics.STATE_THETA] +
+ X_module[dynamics.STATE_THETAS0])
+
+ caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
+
+ robot_velocity = X[dynamics.STATE_VX:dynamics.STATE_VY + 1]
+
+ contact_patch_velocity = (
+ angle_cross(Rtheta @ mounting_location, X[dynamics.STATE_OMEGA]) +
+ robot_velocity + angle_cross(
+ Rthetaplusthetas @ caster_vector,
+ (X[dynamics.STATE_OMEGA] + X_module[dynamics.STATE_OMEGAS0])))
+
+ wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
+
+ wheel_velocity = jax.numpy.array(
+ [coefficients.rw * X_module[dynamics.STATE_OMEGAD0], 0.0])
+
+ wheel_slip_velocity = wheel_velocity - wheel_ground_velocity
+
+ slip_angle = jax.numpy.sin(
+ -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
+
+ slip_ratio = (coefficients.rw * X_module[dynamics.STATE_OMEGAD0] -
+ wheel_ground_velocity[0]) / jax.numpy.max(
+ jax.numpy.array(
+ [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))
+
+ Fwx = coefficients.Cx * slip_ratio
+ Fwy = coefficients.Cy * slip_angle
+
+ softsign_velocity = softsign(wheel_ground_velocity[0], 100)
+
+ Ms = -Fwy * (
+ (softsign_velocity * coefficients.contact_patch_length / 3.0) +
+ coefficients.caster)
+
+ alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
+ (-coefficients.wb + (coefficients.rs + coefficients.rp) *
+ (1 - coefficients.rb1 / coefficients.rp)) *
+ (coefficients.rw / coefficients.rb2 *
+ (-Fwx))) / (coefficients.Jsm +
+ (coefficients.Js /
+ (coefficients.Gs * coefficients.Gs)))
+
+ # Then solve for alphad
+ alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
+ coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
+ coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
+ coefficients.Gd) / (coefficients.rp * coefficients.Jdm)
+
+ F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
+
+ torque = force_cross(mounting_location, F)
+
+ X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
+ (4, )), ) * (module_index) + (jax.numpy.array([
+ X_module[dynamics.STATE_OMEGAS0],
+ X_module[dynamics.STATE_OMEGAD0],
+ alphas,
+ alphad,
+ ]), ) + (jax.numpy.zeros((4, )), ) * (3 - module_index) + (
+ jax.numpy.zeros((3, )),
+ F / coefficients.m,
+ jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
+ ))
+
+ return X_dot_contribution
+
+
+@partial(jax.jit, static_argnames=['coefficients'])
+def full_dynamics(coefficients: CoefficientsType, X, U):
+ Rtheta = R(X[dynamics.STATE_THETA])
+
+ module0 = full_module_physics(
+ coefficients, Rtheta, 0,
+ jax.numpy.array(
+ [coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
+ X, U)
+ module1 = full_module_physics(
+ coefficients, Rtheta, 1,
+ jax.numpy.array(
+ [-coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
+ X, U)
+ module2 = full_module_physics(
+ coefficients, Rtheta, 2,
+ jax.numpy.array(
+ [-coefficients.robot_width / 2.0,
+ -coefficients.robot_width / 2.0]), X, U)
+ module3 = full_module_physics(
+ coefficients, Rtheta, 3,
+ jax.numpy.array(
+ [coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0]),
+ X, U)
+
+ X_dot = module0 + module1 + module2 + module3
+
+ X_dot = X_dot.at[dynamics.STATE_X:dynamics.STATE_THETA + 1].set(
+ jax.numpy.array([
+ X[dynamics.STATE_VX],
+ X[dynamics.STATE_VY],
+ X[dynamics.STATE_OMEGA],
+ ]))
+
+ return X_dot
+
+
+def velocity_module_physics(coefficients: dict, Rtheta, module_index: int,
+ mounting_location, X, U):
+ X_module = X[module_index * 2:(module_index + 1) * 2]
+ Is = U[2 * module_index + 0]
+ Id = U[2 * module_index + 1]
+
+ Rthetaplusthetas = R(X[dynamics.VELOCITY_STATE_THETA] +
+ X_module[dynamics.VELOCITY_STATE_THETAS0])
+
+ caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
+
+ robot_velocity = X[dynamics.VELOCITY_STATE_VX:dynamics.VELOCITY_STATE_VY +
+ 1]
+
+ contact_patch_velocity = (
+ angle_cross(Rtheta @ mounting_location,
+ X[dynamics.VELOCITY_STATE_OMEGA]) + robot_velocity +
+ angle_cross(Rthetaplusthetas @ caster_vector,
+ (X[dynamics.VELOCITY_STATE_OMEGA] +
+ X_module[dynamics.VELOCITY_STATE_OMEGAS0])))
+
+ wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
+
+ slip_angle = jax.numpy.sin(
+ -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
+
+ Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
+ Fwy = coefficients.Cy * slip_angle
+
+ softsign_velocity = softsign(wheel_ground_velocity[0], 100)
+
+ Ms = -Fwy * (
+ (softsign_velocity * coefficients.contact_patch_length / 3.0) +
+ coefficients.caster)
+
+ alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
+ (-coefficients.wb + (coefficients.rs + coefficients.rp) *
+ (1 - coefficients.rb1 / coefficients.rp)) *
+ (coefficients.rw / coefficients.rb2 *
+ (-Fwx))) / (coefficients.Jsm +
+ (coefficients.Js /
+ (coefficients.Gs * coefficients.Gs)))
+
+ F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
+
+ torque = force_cross(mounting_location, F)
+
+ X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
+ (2, )), ) * (module_index) + (jax.numpy.array([
+ X_module[dynamics.VELOCITY_STATE_OMEGAS0],
+ alphas,
+ ]), ) + (jax.numpy.zeros((2, )), ) * (3 - module_index) + (
+ jax.numpy.zeros((1, )),
+ F / coefficients.m,
+ jax.numpy.array([torque / coefficients.J]),
+ ))
+
+ return X_dot_contribution
+
+
+@partial(jax.jit, static_argnames=['coefficients'])
+def velocity_dynamics(coefficients: CoefficientsType, X, U):
+ Rtheta = R(X[dynamics.VELOCITY_STATE_THETA])
+
+ module0 = velocity_module_physics(
+ coefficients, Rtheta, 0,
+ jax.numpy.array(
+ [coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
+ X, U)
+ module1 = velocity_module_physics(
+ coefficients, Rtheta, 1,
+ jax.numpy.array(
+ [-coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
+ X, U)
+ module2 = velocity_module_physics(
+ coefficients, Rtheta, 2,
+ jax.numpy.array(
+ [-coefficients.robot_width / 2.0,
+ -coefficients.robot_width / 2.0]), X, U)
+ module3 = velocity_module_physics(
+ coefficients, Rtheta, 3,
+ jax.numpy.array(
+ [coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0]),
+ X, U)
+
+ X_dot = module0 + module1 + module2 + module3
+
+ return X_dot.at[dynamics.VELOCITY_STATE_THETA].set(
+ X[dynamics.VELOCITY_STATE_OMEGA])
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 0a3fdb6..af3dc17 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -1,6 +1,13 @@
-#!/usr/bin/python3
+#!/usr/bin/env python3
+import jax
+
+# casadi uses doubles. jax likes floats. We want to make sure the physics
+# matches really precisely, so force doubles for the tests.
+jax.config.update("jax_enable_x64", True)
import numpy
+
+numpy.set_printoptions(precision=20)
import sys, os
import casadi
import scipy
@@ -11,10 +18,12 @@
from frc971.control_loops.swerve import dynamics
from frc971.control_loops.swerve import nocaster_dynamics
from frc971.control_loops.swerve import physics_test_utils as utils
+from frc971.control_loops.swerve import jax_dynamics
class TestSwervePhysics(unittest.TestCase):
I = numpy.zeros((8, 1))
+ coefficients = jax_dynamics.Coefficients()
def to_velocity_state(self, X):
return dynamics.to_velocity_state(X)
@@ -22,19 +31,74 @@
def swerve_full_dynamics(self, X, U, skip_compare=False):
X_velocity = self.to_velocity_state(X)
Xdot = self.position_swerve_full_dynamics(X, U)
+
if not skip_compare:
- velocity_states = self.to_velocity_state(Xdot)
+ Xdot_velocity = self.to_velocity_state(Xdot)
velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+
self.assertLess(
- numpy.linalg.norm(velocity_states - velocity_physics),
+ numpy.linalg.norm(Xdot_velocity - velocity_physics),
2e-2,
msg=
- f'Norm failed, full physics -> {velocity_states.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - velocity_states}',
+ f'Norm failed, full physics -> {X_velocity.T}, velocity physics -> {velocity_physics}, difference -> {velocity_physics - X_velocity}',
)
+ self.validate_dynamics_equality(X, U)
+
return Xdot
+ def validate_dynamics_equality(self, X, U):
+ """Tests that both the JAX code and casadi code produce identical answers.
+
+ Note:
+ If the symengine code has been updated, you likely need to update the JAX
+ by hand. We had trouble code generating it with good performance.
+ """
+ X_velocity = self.to_velocity_state(X)
+
+ Xdot = self.position_swerve_full_dynamics(X, U)
+ Xdot_jax = jax_dynamics.full_dynamics(self.coefficients, X[:, 0], U[:,
+ 0])
+
+ self.assertLess(
+ numpy.linalg.norm(Xdot[:, 0] - Xdot_jax),
+ 2e-8,
+ msg=
+ f'Xdot: {Xdot[:, 0]}, Xdot_jax: {Xdot_jax}, diff: {(Xdot[:, 0] - Xdot_jax)}',
+ )
+
+ velocity_physics = self.velocity_swerve_physics(X_velocity, U)
+ velocity_physics_jax = jax_dynamics.velocity_dynamics(
+ self.coefficients, X_velocity[:, 0], U[:, 0])
+
+ self.assertLess(
+ numpy.linalg.norm(velocity_physics[:, 0] - velocity_physics_jax),
+ 2e-8,
+ msg=
+ f'Xdot: {velocity_physics[:, 0]}, Xdot_jax: {velocity_physics_jax}, diff: {(velocity_physics[:, 0] - velocity_physics_jax)}',
+ )
+
+ def wrap_and_validate(self, function, i):
+ """Wraps a function, and validates JAX and casadi agree.
+
+ We want to do it every time we check any intermediate, since the tests
+ are designed to test all the corner cases, but they don't all do it
+ through the main dynamics function above.
+ """
+ wrapped_fn = utils.wrap_module(function, i)
+
+ def do(X, U):
+ self.validate_dynamics_equality(X, U)
+ return wrapped_fn(X, U)
+
+ return do
+
def wrap(self, python_module):
+ # Only update on change to avoid re-jiting things.
+ if self.coefficients.caster != python_module.CASTER:
+ self.coefficients = self.coefficients._replace(
+ caster=python_module.CASTER)
+
self.position_swerve_full_dynamics = utils.wrap(
python_module.swerve_full_dynamics)
@@ -45,37 +109,42 @@
evaluated_fn(X, U))
self.contact_patch_velocity = [
- utils.wrap_module(python_module.contact_patch_velocity, i)
+ self.wrap_and_validate(python_module.contact_patch_velocity, i)
for i in range(4)
]
self.wheel_ground_velocity = [
- utils.wrap_module(python_module.wheel_ground_velocity, i)
+ self.wrap_and_validate(python_module.wheel_ground_velocity, i)
for i in range(4)
]
self.wheel_slip_velocity = [
- utils.wrap_module(python_module.wheel_slip_velocity, i)
+ self.wrap_and_validate(python_module.wheel_slip_velocity, i)
for i in range(4)
]
self.wheel_force = [
- utils.wrap_module(python_module.wheel_force, i) for i in range(4)
- ]
- self.module_angular_accel = [
- utils.wrap_module(python_module.module_angular_accel, i)
+ self.wrap_and_validate(python_module.wheel_force, i)
for i in range(4)
]
- self.F = [utils.wrap_module(python_module.F, i) for i in range(4)]
+ self.module_angular_accel = [
+ self.wrap_and_validate(python_module.module_angular_accel, i)
+ for i in range(4)
+ ]
+ self.F = [self.wrap_and_validate(python_module.F, i) for i in range(4)]
self.mounting_location = [
- utils.wrap_module(python_module.mounting_location, i)
+ self.wrap_and_validate(python_module.mounting_location, i)
for i in range(4)
]
self.slip_angle = [
- utils.wrap_module(python_module.slip_angle, i) for i in range(4)
+ self.wrap_and_validate(python_module.slip_angle, i)
+ for i in range(4)
]
self.slip_ratio = [
- utils.wrap_module(python_module.slip_ratio, i) for i in range(4)
+ self.wrap_and_validate(python_module.slip_ratio, i)
+ for i in range(4)
]
- self.Ms = [utils.wrap_module(python_module.Ms, i) for i in range(4)]
+ self.Ms = [
+ self.wrap_and_validate(python_module.Ms, i) for i in range(4)
+ ]
def setUp(self):
self.wrap(dynamics)
@@ -192,6 +261,15 @@
scipy.special.logsumexp([1.0, abs(vx) * loggain]) /
loggain))
+ jax_expected = jax.numpy.sin(
+ -jax_dynamics.soft_atan2(vy, vx))
+
+ self.assertAlmostEqual(
+ expected,
+ jax_expected,
+ msg=f"Trying wrap {wrap} theta {theta}",
+ )
+
self.assertAlmostEqual(
expected,
computed_angle,
diff --git a/frc971/control_loops/swerve/physics_test_utils.py b/frc971/control_loops/swerve/physics_test_utils.py
index 30b352f..99abe22 100644
--- a/frc971/control_loops/swerve/physics_test_utils.py
+++ b/frc971/control_loops/swerve/physics_test_utils.py
@@ -21,35 +21,39 @@
# All the wheels are spinning at the speed needed to hit the velocity in m/s
drive_wheel_velocity = drive_wheel_velocity or numpy.linalg.norm(velocity)
- X_initial[2, 0] = module_omega
- X_initial[3, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+ X_initial[dynamics.STATE_OMEGAS0, 0] = module_omega
+ X_initial[dynamics.STATE_OMEGAD0,
+ 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
- X_initial[6, 0] = module_omega
- X_initial[7, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+ X_initial[dynamics.STATE_OMEGAS1, 0] = module_omega
+ X_initial[dynamics.STATE_OMEGAD1,
+ 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
- X_initial[10, 0] = module_omega
- X_initial[11, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+ X_initial[dynamics.STATE_OMEGAS2, 0] = module_omega
+ X_initial[dynamics.STATE_OMEGAD2,
+ 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
- X_initial[14, 0] = module_omega
- X_initial[15, 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
+ X_initial[dynamics.STATE_OMEGAS3, 0] = module_omega
+ X_initial[dynamics.STATE_OMEGAD3,
+ 0] = drive_wheel_velocity / (dynamics.WHEEL_RADIUS)
- X_initial[0, 0] = module_angle
- X_initial[4, 0] = module_angle
- X_initial[8, 0] = module_angle
- X_initial[12, 0] = module_angle
+ X_initial[dynamics.STATE_THETAS0, 0] = module_angle
+ X_initial[dynamics.STATE_THETAS1, 0] = module_angle
+ X_initial[dynamics.STATE_THETAS2, 0] = module_angle
+ X_initial[dynamics.STATE_THETAS3, 0] = module_angle
if module_angles is not None:
assert len(module_angles) == 4
- X_initial[0, 0] = module_angles[0]
- X_initial[4, 0] = module_angles[1]
- X_initial[8, 0] = module_angles[2]
- X_initial[12, 0] = module_angles[3]
+ X_initial[dynamics.STATE_THETAS0, 0] = module_angles[0]
+ X_initial[dynamics.STATE_THETAS1, 0] = module_angles[1]
+ X_initial[dynamics.STATE_THETAS2, 0] = module_angles[2]
+ X_initial[dynamics.STATE_THETAS3, 0] = module_angles[3]
- X_initial[18, 0] = theta
+ X_initial[dynamics.STATE_THETA, 0] = theta
- X_initial[19, 0] = velocity[0, 0] + dx
- X_initial[20, 0] = velocity[1, 0] + dy
- X_initial[21, 0] = omega
+ X_initial[dynamics.STATE_VX, 0] = velocity[0, 0] + dx
+ X_initial[dynamics.STATE_VY, 0] = velocity[1, 0] + dy
+ X_initial[dynamics.STATE_OMEGA, 0] = omega
return X_initial