Make the JAX dynamics code independent of the CasADi code
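Fewer dependencies are better: jax_dynamics.py now imports its state-index constants from dynamics_constants instead of importing the dynamics module. velocity_module_physics also now returns the module's force and torque alongside its state-derivative contribution, and a to_velocity_state() helper is added.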
Change-Id: I60f931347d27b75038387fde11a42ee98c9c1b99
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index dfadbda..6e1305d 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -4,8 +4,8 @@
from collections import namedtuple
import jax
-from frc971.control_loops.swerve import dynamics
from frc971.control_loops.python.control_loop import KrakenFOC
+from frc971.control_loops.swerve.dynamics_constants import *
# Note: this physics needs to match the symengine code. We have tests that
# confirm they match in all the cases we care about.
@@ -130,36 +130,35 @@
jax.numpy.array([1.0, softabs_x * kMaxLogGain])) / kMaxLogGain)
-def full_module_physics(coefficients: dict, Rtheta, module_index: int,
- mounting_location, X, U):
+def full_module_physics(coefficients: CoefficientsType, Rtheta,
+ module_index: int, mounting_location, X, U):
X_module = X[module_index * 4:(module_index + 1) * 4]
Is = U[2 * module_index + 0]
Id = U[2 * module_index + 1]
- Rthetaplusthetas = R(X[dynamics.STATE_THETA] +
- X_module[dynamics.STATE_THETAS0])
+ Rthetaplusthetas = R(X[STATE_THETA] + X_module[STATE_THETAS0])
caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
- robot_velocity = X[dynamics.STATE_VX:dynamics.STATE_VY + 1]
+ robot_velocity = X[STATE_VX:STATE_VY + 1]
contact_patch_velocity = (
- angle_cross(Rtheta @ mounting_location, X[dynamics.STATE_OMEGA]) +
- robot_velocity + angle_cross(
- Rthetaplusthetas @ caster_vector,
- (X[dynamics.STATE_OMEGA] + X_module[dynamics.STATE_OMEGAS0])))
+ angle_cross(Rtheta @ mounting_location, X[STATE_OMEGA]) +
+ robot_velocity +
+ angle_cross(Rthetaplusthetas @ caster_vector,
+ (X[STATE_OMEGA] + X_module[STATE_OMEGAS0])))
wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
wheel_velocity = jax.numpy.array(
- [coefficients.rw * X_module[dynamics.STATE_OMEGAD0], 0.0])
+ [coefficients.rw * X_module[STATE_OMEGAD0], 0.0])
wheel_slip_velocity = wheel_velocity - wheel_ground_velocity
slip_angle = jax.numpy.sin(
-soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))
- slip_ratio = (coefficients.rw * X_module[dynamics.STATE_OMEGAD0] -
+ slip_ratio = (coefficients.rw * X_module[STATE_OMEGAD0] -
wheel_ground_velocity[0]) / jax.numpy.max(
jax.numpy.array(
[0.02, jax.numpy.abs(wheel_ground_velocity[0])]))
@@ -193,8 +192,8 @@
X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
(4, )), ) * (module_index) + (jax.numpy.array([
- X_module[dynamics.STATE_OMEGAS0],
- X_module[dynamics.STATE_OMEGAD0],
+ X_module[STATE_OMEGAS0],
+ X_module[STATE_OMEGAD0],
alphas,
alphad,
]), ) + (jax.numpy.zeros((4, )), ) * (3 - module_index) + (
@@ -208,7 +207,7 @@
@partial(jax.jit, static_argnames=['coefficients'])
def full_dynamics(coefficients: CoefficientsType, X, U):
- Rtheta = R(X[dynamics.STATE_THETA])
+ Rtheta = R(X[STATE_THETA])
module0 = full_module_physics(
coefficients, Rtheta, 0,
@@ -233,37 +232,41 @@
X_dot = module0 + module1 + module2 + module3
- X_dot = X_dot.at[dynamics.STATE_X:dynamics.STATE_THETA + 1].set(
+ X_dot = X_dot.at[STATE_X:STATE_THETA + 1].set(
jax.numpy.array([
- X[dynamics.STATE_VX],
- X[dynamics.STATE_VY],
- X[dynamics.STATE_OMEGA],
+ X[STATE_VX],
+ X[STATE_VY],
+ X[STATE_OMEGA],
]))
return X_dot
-def velocity_module_physics(coefficients: dict, Rtheta, module_index: int,
- mounting_location, X, U):
+def velocity_module_physics(coefficients: CoefficientsType,
+ Rtheta: jax.typing.ArrayLike, module_index: int,
+ mounting_location: jax.typing.ArrayLike,
+ X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
X_module = X[module_index * 2:(module_index + 1) * 2]
Is = U[2 * module_index + 0]
Id = U[2 * module_index + 1]
- Rthetaplusthetas = R(X[dynamics.VELOCITY_STATE_THETA] +
- X_module[dynamics.VELOCITY_STATE_THETAS0])
+ rotated_mounting_location = Rtheta @ mounting_location
+
+ Rthetaplusthetas = R(X[VELOCITY_STATE_THETA] +
+ X_module[VELOCITY_STATE_THETAS0])
caster_vector = jax.numpy.array([-coefficients.caster, 0.0])
- robot_velocity = X[dynamics.VELOCITY_STATE_VX:dynamics.VELOCITY_STATE_VY +
- 1]
+ robot_velocity = X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]
contact_patch_velocity = (
- angle_cross(Rtheta @ mounting_location,
- X[dynamics.VELOCITY_STATE_OMEGA]) + robot_velocity +
- angle_cross(Rthetaplusthetas @ caster_vector,
- (X[dynamics.VELOCITY_STATE_OMEGA] +
- X_module[dynamics.VELOCITY_STATE_OMEGAS0])))
+ angle_cross(rotated_mounting_location, X[VELOCITY_STATE_OMEGA]) +
+ robot_velocity + angle_cross(
+ Rthetaplusthetas @ caster_vector,
+ (X[VELOCITY_STATE_OMEGA] + X_module[VELOCITY_STATE_OMEGAS0])))
+ # Velocity of the contact patch over the field projected into the direction
+ # of the wheel.
wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity
slip_angle = jax.numpy.sin(
@@ -288,11 +291,11 @@
F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
- torque = force_cross(Rtheta @ mounting_location, F)
+ torque = force_cross(rotated_mounting_location, F)
X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
(2, )), ) * (module_index) + (jax.numpy.array([
- X_module[dynamics.VELOCITY_STATE_OMEGAS0],
+ X_module[VELOCITY_STATE_OMEGAS0],
alphas,
]), ) + (jax.numpy.zeros((2, )), ) * (3 - module_index) + (
jax.numpy.zeros((1, )),
@@ -300,29 +303,30 @@
jax.numpy.array([torque / coefficients.J]),
))
- return X_dot_contribution
+ return X_dot_contribution, F, torque
@partial(jax.jit, static_argnames=['coefficients'])
-def velocity_dynamics(coefficients: CoefficientsType, X, U):
- Rtheta = R(X[dynamics.VELOCITY_STATE_THETA])
+def velocity_dynamics(coefficients: CoefficientsType, X: jax.typing.ArrayLike,
+ U: jax.typing.ArrayLike):
+ Rtheta = R(X[VELOCITY_STATE_THETA])
- module0 = velocity_module_physics(
+ module0, _, _ = velocity_module_physics(
coefficients, Rtheta, 0,
jax.numpy.array(
[coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
X, U)
- module1 = velocity_module_physics(
+ module1, _, _ = velocity_module_physics(
coefficients, Rtheta, 1,
jax.numpy.array(
[-coefficients.robot_width / 2.0, coefficients.robot_width / 2.0]),
X, U)
- module2 = velocity_module_physics(
+ module2, _, _ = velocity_module_physics(
coefficients, Rtheta, 2,
jax.numpy.array(
[-coefficients.robot_width / 2.0,
-coefficients.robot_width / 2.0]), X, U)
- module3 = velocity_module_physics(
+ module3, _, _ = velocity_module_physics(
coefficients, Rtheta, 3,
jax.numpy.array(
[coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0]),
@@ -330,5 +334,21 @@
X_dot = module0 + module1 + module2 + module3
- return X_dot.at[dynamics.VELOCITY_STATE_THETA].set(
- X[dynamics.VELOCITY_STATE_OMEGA])
+ return X_dot.at[VELOCITY_STATE_THETA].set(X[VELOCITY_STATE_OMEGA])
+
+
+def to_velocity_state(X):  # Reduce a full-model state X to the velocity-model state: each module's steer angle/rate, then robot theta, vx, vy, omega; other per-module entries (e.g. STATE_OMEGAD*) are dropped.
+ return jax.numpy.array([  # NOTE(review): this order is assumed to match the VELOCITY_STATE_* indices in dynamics_constants - confirm.
+ X[STATE_THETAS0],  # Module 0 steer angle relative to the robot (combined with STATE_THETA via R(...) in the physics).
+ X[STATE_OMEGAS0],  # Module 0 steer angular rate.
+ X[STATE_THETAS1],  # Module 1 steer angle.
+ X[STATE_OMEGAS1],  # Module 1 steer angular rate.
+ X[STATE_THETAS2],  # Module 2 steer angle.
+ X[STATE_OMEGAS2],  # Module 2 steer angular rate.
+ X[STATE_THETAS3],  # Module 3 steer angle.
+ X[STATE_OMEGAS3],  # Module 3 steer angular rate.
+ X[STATE_THETA],  # Robot heading.
+ X[STATE_VX],  # Robot x velocity.
+ X[STATE_VY],  # Robot y velocity.
+ X[STATE_OMEGA],  # Robot yaw rate.
+ ])