Add a JAX version of the swerve physics, plus tests confirming it matches the symengine model
This sets us up to compute the dynamics very efficiently and to JIT-compile them.
Change-Id: I57aea5c1f480759c8e5e658ff6f4de0d82ef273d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
new file mode 100644
index 0000000..8b05ef4
--- /dev/null
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+
+from functools import partial
+from collections import namedtuple
+import jax
+
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.python.control_loop import KrakenFOC
+
+# Note: this physics needs to match the symengine code. We have tests that
+# confirm they match in all the cases we care about.
+
+# Physical constants consumed by the JAX swerve dynamics. Build instances with
+# Coefficients(), which also fills in the derived entries.
+CoefficientsType = namedtuple(
+    'CoefficientsType',
+    (
+        # Tire force coefficients (per unit slip) and wheel radius.
+        'Cx', 'Cy', 'rw',
+        # Robot mass and moment of inertia.
+        'm', 'J',
+        # Drive gearing; Coefficients() derives Gd2, Gd3 and Gd.
+        'Gd1', 'rs', 'rp', 'Gd2', 'rb1', 'rb2', 'Gd3', 'Gd',
+        # Steering inertia and ratio, plus wb (used in the steer acceleration).
+        'Js', 'Gs', 'wb',
+        # Drive/steer motor inertias and steer/drive torque constants.
+        'Jdm', 'Jsm', 'Kts', 'Ktd',
+        # Geometry.
+        'robot_width', 'caster', 'contact_patch_length',
+    ),
+)
+
+
+def Coefficients(
+    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
+    Cy: float = 5 * 9.8 / 0.05 / 4.0,
+    rw: float = 2 * 0.0254,
+
+    # base is 20 kg without battery
+    m: float = 25.0,
+    J: float = 6.0,
+    Gd1: float = 12.0 / 42.0,
+    rs: float = 28.0 / 20.0 / 2.0,
+    rp: float = 18.0 / 20.0 / 2.0,
+
+    # 15 / 45 bevel ratio, calculated using python script ported over to
+    # GetBevelPitchRadius(double)
+    # TODO(Justin): Use the function instead of computed constants
+    rb1: float = 0.3805473,
+    rb2: float = 1.14164,
+    Js: float = 0.001,
+    Gs: float = 35.0 / 468.0,
+    wb: float = 0.725,
+    drive_motor=None,
+    steer_motor=None,
+    robot_width: float = 24.75 * 0.0254,
+    caster: float = 0.01,
+    contact_patch_length: float = 0.02,
+) -> CoefficientsType:
+    """Returns the swerve physics constants, filling in the derived ones.
+
+    The drive gear ratios Gd2 = rs / rp, Gd3 = rb1 / rb2 and
+    Gd = Gd1 * Gd2 * Gd3 are computed here, and the motor inertias
+    (Jdm, Jsm) and torque constants (Ktd, Kts) are read from the motor
+    models. Every other argument is passed through unchanged.
+
+    Args:
+      drive_motor: drive motor model exposing `motor_inertia` and `Kt`.
+        Defaults to a new KrakenFOC().
+      steer_motor: steer motor model exposing `motor_inertia` and `Kt`.
+        Defaults to a new KrakenFOC().
+      The remaining arguments are stored under the same names in the
+      result (e.g. rw is the wheel radius, m the robot mass and J its
+      moment of inertia, as used by the dynamics).
+
+    Returns:
+      A CoefficientsType. Being a tuple, it can serve as the static
+      `coefficients` argument of the jitted dynamics functions as long as
+      its fields are hashable.
+    """
+    # Build the default motor models per call instead of sharing instances
+    # that a default argument would create once, at import time.
+    if drive_motor is None:
+        drive_motor = KrakenFOC()
+    if steer_motor is None:
+        steer_motor = KrakenFOC()
+
+    Gd2 = rs / rp
+    Gd3 = rb1 / rb2
+    Gd = Gd1 * Gd2 * Gd3
+
+    Jdm = drive_motor.motor_inertia
+    Jsm = steer_motor.motor_inertia
+    Kts = steer_motor.Kt
+    Ktd = drive_motor.Kt
+
+    return CoefficientsType(
+        Cx=Cx,
+        Cy=Cy,
+        rw=rw,
+        m=m,
+        J=J,
+        Gd1=Gd1,
+        rs=rs,
+        rp=rp,
+        Gd2=Gd2,
+        rb1=rb1,
+        rb2=rb2,
+        Gd3=Gd3,
+        Gd=Gd,
+        Js=Js,
+        Gs=Gs,
+        wb=wb,
+        Jdm=Jdm,
+        Jsm=Jsm,
+        Kts=Kts,
+        Ktd=Ktd,
+        robot_width=robot_width,
+        caster=caster,
+        contact_patch_length=contact_patch_length,
+    )
+
+
+def R(theta):
+    """Returns the 2x2 matrix that rotates a planar vector by theta (CCW)."""
+    cos_t = jax.numpy.cos(theta)
+    sin_t = jax.numpy.sin(theta)
+    rows = [
+        [cos_t, -sin_t],
+        [sin_t, cos_t],
+    ]
+    return jax.numpy.array(rows)
+
+
+def angle_cross(vector, omega):
+    """Returns omega (a rate about +z) crossed with the planar `vector`.
+
+    This is the velocity of a point at `vector` on a body spinning at omega.
+    """
+    vx = vector[0]
+    vy = vector[1]
+    return jax.numpy.array([-vy * omega, vx * omega])
+
+
+def force_cross(r, f):
+    """Returns the z component of the planar cross product r x f (a torque)."""
+    rx, ry = r[0], r[1]
+    fx, fy = f[0], f[1]
+    return rx * fy - ry * fx
+
+
+def softsign(x, gain):
+    """Smooth approximation of sign(x), taking values in (-1, 1).
+
+    Larger `gain` makes the transition through x = 0 sharper; softsign(0) is 0.
+    """
+    return 1 - 2 / (1 + jax.numpy.exp(gain * x))
+
+
+def soft_atan2(y, x):
+    """Differentiable approximation of atan2(y, |x|).
+
+    |x| is replaced by a smooth x * sign(x) estimate (gain 1 / 0.01), which is
+    then passed through logsumexp([1, g * v]) / g with g = 1 / 0.05, a smooth
+    max of v and 0.05. The second argument to arctan2 therefore stays at or
+    above 0.05, so the result lies in (-pi/2, pi/2) even as x goes to 0.
+    """
+    abs_gain = 1.0 / 0.01
+    max_gain = 1.0 / 0.05
+
+    # Smooth |x|: x times a sigmoid-based estimate of sign(x).
+    smooth_abs_x = x * (1.0 - 2.0 / (1 + jax.numpy.exp(abs_gain * x)))
+    # Smooth max(1 / max_gain, smooth_abs_x).
+    denominator = jax.scipy.special.logsumexp(
+        jax.numpy.array([1.0, smooth_abs_x * max_gain])) / max_gain
+    return jax.numpy.arctan2(y, denominator)
+
+
+def full_module_physics(coefficients: CoefficientsType, Rtheta,
+                        module_index: int, mounting_location, X, U):
+    """Returns one swerve module's contribution to the full-model X_dot.
+
+    Args:
+      coefficients: physical constants, as returned by Coefficients().
+      Rtheta: R(theta) for the robot heading X[dynamics.STATE_THETA].
+      module_index: which module to evaluate, 0 through 3.
+      mounting_location: 2-vector from the robot center to the module's
+        steering axis, before rotation by Rtheta.
+      X: full state; this module's four states are
+        X[4 * module_index:4 * (module_index + 1)], indexed with the
+        dynamics.STATE_*S0 / STATE_*D0 offsets.
+      U: inputs; U[2 * module_index] is the steer input Is and
+        U[2 * module_index + 1] the drive input Id (both are scaled by a
+        motor Kt below, so presumably motor currents).
+
+    Returns:
+      A vector holding [omegas, omegad, alphas, alphad] in this module's
+      slot, zeros in the other three modules' slots, then three zeros,
+      F / m and [torque / J, 0, 0, 0]. full_dynamics sums the four module
+      vectors. Assumes the state layout is four 4-element module slots
+      followed by 9 robot states -- TODO confirm against dynamics.
+    """
+    X_module = X[module_index * 4:(module_index + 1) * 4]
+    Is = U[2 * module_index + 0]
+    Id = U[2 * module_index + 1]
+
+    Rthetaplusthetas = R(X[dynamics.STATE_THETA] +
+                         X_module[dynamics.STATE_THETAS0])
+
+    # The contact patch sits at -caster along the wheel's x axis, measured
+    # from the steering axis.
+    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
+
+    robot_velocity = X[dynamics.STATE_VX:dynamics.STATE_VY + 1]
+
+    # Contact patch velocity = mounting point velocity from robot rotation
+    # + robot translation + rotation of the patch about the steering axis at
+    # the combined robot and steer rate.
+    contact_patch_velocity = (
+        angle_cross(Rtheta @ mounting_location, X[dynamics.STATE_OMEGA]) +
+        robot_velocity + angle_cross(
+            Rthetaplusthetas @ caster_vector,
+            (X[dynamics.STATE_OMEGA] + X_module[dynamics.STATE_OMEGAS0])))
+
+    # Same velocity in the wheel frame: [0] along the wheel, [1] across it.
+    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
+
+    slip_angle = jax.numpy.sin(
+        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
+
+    # Longitudinal slip: wheel surface speed minus ground speed, divided by
+    # the ground speed floored at 0.02 so it stays finite near standstill.
+    slip_ratio = (coefficients.rw * X_module[dynamics.STATE_OMEGAD0] -
+                  wheel_ground_velocity[0]) / jax.numpy.maximum(
+                      0.02, jax.numpy.abs(wheel_ground_velocity[0]))
+
+    Fwx = coefficients.Cx * slip_ratio
+    Fwy = coefficients.Cy * slip_angle
+
+    softsign_velocity = softsign(wheel_ground_velocity[0], 100)
+
+    # Steering moment from the lateral force acting at a lever arm of the
+    # caster plus a third of the contact patch length, the latter signed by
+    # the direction of travel along the wheel.
+    Ms = -Fwy * (
+        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
+        coefficients.caster)
+
+    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
+              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
+               (1 - coefficients.rb1 / coefficients.rp)) *
+              (coefficients.rw / coefficients.rb2 * (-Fwx))) / (
+                  coefficients.Jsm +
+                  (coefficients.Js / (coefficients.Gs * coefficients.Gs)))
+
+    # alphad depends on alphas (the rs * Jdm * Gd3 * alphas term), so it is
+    # solved second.
+    alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
+              coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
+              coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
+              coefficients.Gd) / (coefficients.rp * coefficients.Jdm)
+
+    # Tire force rotated out of the wheel frame.
+    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
+
+    # NOTE(review): contact_patch_velocity uses Rtheta @ mounting_location,
+    # but the torque arm here is the unrotated mounting_location while F is
+    # in the rotated frame. Confirm whether the arm should be
+    # Rtheta @ mounting_location; any fix must also go into the symengine
+    # model this code is tested against.
+    torque = force_cross(mounting_location, F)
+
+    # The four module contributions are summed, so every module slot except
+    # this one's is zero.
+    empty_module = (jax.numpy.zeros((4, )), )
+    module_derivative = jax.numpy.array([
+        X_module[dynamics.STATE_OMEGAS0],
+        X_module[dynamics.STATE_OMEGAD0],
+        alphas,
+        alphad,
+    ])
+    X_dot_contribution = jax.numpy.hstack(
+        empty_module * module_index + (module_derivative, ) +
+        empty_module * (3 - module_index) + (
+            jax.numpy.zeros((3, )),
+            F / coefficients.m,
+            jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
+        ))
+
+    return X_dot_contribution
+
+
+@partial(jax.jit, static_argnames=['coefficients'])
+def full_dynamics(coefficients: CoefficientsType, X, U):
+    """Returns dX/dt for the full swerve state X under inputs U.
+
+    Adds up full_module_physics for the four modules, mounted at
+    (+w/2, +w/2), (-w/2, +w/2), (-w/2, -w/2) and (+w/2, -w/2) with
+    w = robot_width, then overwrites the X, Y and theta rates with the
+    current VX, VY and OMEGA. JIT-compiled with `coefficients` static.
+    """
+    half_width = coefficients.robot_width / 2.0
+    # Mounting point of each module, in module-index order.
+    mounting_points = (
+        (half_width, half_width),
+        (-half_width, half_width),
+        (-half_width, -half_width),
+        (half_width, -half_width),
+    )
+    heading_rotation = R(X[dynamics.STATE_THETA])
+
+    per_module = [
+        full_module_physics(coefficients, heading_rotation, index,
+                            jax.numpy.array(point), X, U)
+        for index, point in enumerate(mounting_points)
+    ]
+    X_dot = per_module[0]
+    for contribution in per_module[1:]:
+        X_dot = X_dot + contribution
+
+    pose_rates = jax.numpy.array([
+        X[dynamics.STATE_VX],
+        X[dynamics.STATE_VY],
+        X[dynamics.STATE_OMEGA],
+    ])
+    return X_dot.at[dynamics.STATE_X:dynamics.STATE_THETA + 1].set(pose_rates)
+
+
+def velocity_module_physics(coefficients: CoefficientsType, Rtheta,
+                            module_index: int, mounting_location, X, U):
+    """Returns one module's contribution to the velocity-model X_dot.
+
+    Unlike full_module_physics, there is no drive state or longitudinal slip:
+    the wheel force along the wheel comes straight from the drive input.
+
+    Args:
+      coefficients: physical constants, as returned by Coefficients().
+      Rtheta: R(theta) for the robot heading X[dynamics.VELOCITY_STATE_THETA].
+      module_index: which module to evaluate, 0 through 3.
+      mounting_location: 2-vector from the robot center to the module's
+        steering axis, before rotation by Rtheta.
+      X: velocity-model state; this module's two states are
+        X[2 * module_index:2 * (module_index + 1)], indexed with
+        dynamics.VELOCITY_STATE_THETAS0 / VELOCITY_STATE_OMEGAS0.
+      U: inputs; U[2 * module_index] is the steer input Is and
+        U[2 * module_index + 1] the drive input Id.
+
+    Returns:
+      A vector holding [omegas, alphas] in this module's slot, zeros in the
+      other three modules' slots, then one zero, F / m and torque / J.
+      velocity_dynamics sums the four module vectors.
+    """
+    X_module = X[module_index * 2:(module_index + 1) * 2]
+    Is = U[2 * module_index + 0]
+    Id = U[2 * module_index + 1]
+
+    Rthetaplusthetas = R(X[dynamics.VELOCITY_STATE_THETA] +
+                         X_module[dynamics.VELOCITY_STATE_THETAS0])
+
+    # The contact patch sits at -caster along the wheel's x axis, measured
+    # from the steering axis.
+    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
+
+    robot_velocity = X[dynamics.VELOCITY_STATE_VX:dynamics.VELOCITY_STATE_VY +
+                       1]
+
+    # Contact patch velocity = mounting point velocity from robot rotation
+    # + robot translation + rotation of the patch about the steering axis at
+    # the combined robot and steer rate.
+    contact_patch_velocity = (
+        angle_cross(Rtheta @ mounting_location,
+                    X[dynamics.VELOCITY_STATE_OMEGA]) + robot_velocity +
+        angle_cross(Rthetaplusthetas @ caster_vector,
+                    (X[dynamics.VELOCITY_STATE_OMEGA] +
+                     X_module[dynamics.VELOCITY_STATE_OMEGAS0])))
+
+    # Same velocity in the wheel frame: [0] along the wheel, [1] across it.
+    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
+
+    slip_angle = jax.numpy.sin(
+        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
+
+    # Drive force: drive motor torque (Ktd * Id) through the gearing Gd and
+    # the wheel radius rw.
+    Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
+    Fwy = coefficients.Cy * slip_angle
+
+    softsign_velocity = softsign(wheel_ground_velocity[0], 100)
+
+    # Steering moment from the lateral force acting at a lever arm of the
+    # caster plus a third of the contact patch length, the latter signed by
+    # the direction of travel along the wheel.
+    Ms = -Fwy * (
+        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
+        coefficients.caster)
+
+    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
+              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
+               (1 - coefficients.rb1 / coefficients.rp)) *
+              (coefficients.rw / coefficients.rb2 * (-Fwx))) / (
+                  coefficients.Jsm +
+                  (coefficients.Js / (coefficients.Gs * coefficients.Gs)))
+
+    # Tire force rotated out of the wheel frame.
+    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
+
+    # NOTE(review): contact_patch_velocity uses Rtheta @ mounting_location,
+    # but the torque arm here is the unrotated mounting_location while F is
+    # in the rotated frame. Confirm whether the arm should be
+    # Rtheta @ mounting_location; any fix must also go into the symengine
+    # model this code is tested against.
+    torque = force_cross(mounting_location, F)
+
+    # The four module contributions are summed, so every module slot except
+    # this one's is zero.
+    empty_module = (jax.numpy.zeros((2, )), )
+    module_derivative = jax.numpy.array([
+        X_module[dynamics.VELOCITY_STATE_OMEGAS0],
+        alphas,
+    ])
+    X_dot_contribution = jax.numpy.hstack(
+        empty_module * module_index + (module_derivative, ) +
+        empty_module * (3 - module_index) + (
+            jax.numpy.zeros((1, )),
+            F / coefficients.m,
+            jax.numpy.array([torque / coefficients.J]),
+        ))
+
+    return X_dot_contribution
+
+
+@partial(jax.jit, static_argnames=['coefficients'])
+def velocity_dynamics(coefficients: CoefficientsType, X, U):
+    """Returns dX/dt for the velocity-model state X under inputs U.
+
+    Adds up velocity_module_physics for the four modules, mounted at
+    (+w/2, +w/2), (-w/2, +w/2), (-w/2, -w/2) and (+w/2, -w/2) with
+    w = robot_width, then sets the heading rate to the current omega.
+    JIT-compiled with `coefficients` static.
+    """
+    half_width = coefficients.robot_width / 2.0
+    # Mounting point of each module, in module-index order.
+    mounting_points = (
+        (half_width, half_width),
+        (-half_width, half_width),
+        (-half_width, -half_width),
+        (half_width, -half_width),
+    )
+    heading_rotation = R(X[dynamics.VELOCITY_STATE_THETA])
+
+    per_module = [
+        velocity_module_physics(coefficients, heading_rotation, index,
+                                jax.numpy.array(point), X, U)
+        for index, point in enumerate(mounting_points)
+    ]
+    X_dot = per_module[0]
+    for contribution in per_module[1:]:
+        X_dot = X_dot + contribution
+
+    heading_rate = X[dynamics.VELOCITY_STATE_OMEGA]
+    return X_dot.at[dynamics.VELOCITY_STATE_THETA].set(heading_rate)