Add JAX MPC cost and a test to verify correctness
There is a bug in the casadi implementation too, so fix that while we
are here: it used the unrotated radius instead of the rotated radius
when computing the cost function torque.
Also split the MPC code out into a separate file to make it easier to
pull into other code.
Change-Id: Iee6c9999fa8b6a91d6963af4edba5cf92a085e9b
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 6e1305d..19a48de 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -352,3 +352,73 @@
X[STATE_VY],
X[STATE_OMEGA],
])
+
+
+def mpc_cost(coefficients: CoefficientsType, X, U, goal):
+    """Returns the scalar MPC cost of velocity state X under input U.
+
+    goal[0:2] is compared against (vx, vy) and goal[2] against omega; U is
+    read as four (steer current, drive current) pairs.
+    """
+    goal_velocity = goal[0:2]
+    goal_speed = jax.numpy.linalg.norm(goal_velocity)
+    # TODO(austin): Do we want to do something more special for 0?
+    # Near zero goal speed the direction is undefined, so use the x/y axes.
+    # NOTE(review): 0/0 is still evaluated; select masks the NaN in the value
+    # but gradients w.r.t. goal may still be NaN there -- confirm unused.
+    has_direction = goal_speed > 0.0001
+    along = jax.lax.select(has_direction, goal_velocity / goal_speed,
+                           jax.numpy.array([1.0, 0.0]))
+    across = jax.lax.select(has_direction,
+                            jax.numpy.array([-along[1], along[0]]),
+                            jax.numpy.array([0.0, 1.0]))
+
+    v_err = goal_velocity - X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]
+    cost = 0
+    cost += 75 * (jax.numpy.dot(v_err, along)**2.0)
+    cost += 1500 * (jax.numpy.dot(v_err, across)**2.0)
+    cost += 1000 * (goal[2] - X[VELOCITY_STATE_OMEGA])**2.0
+
+    # Penalize each module's steering velocity.
+    steer_gain = 0.10
+    for omega_index in (VELOCITY_STATE_OMEGAS0, VELOCITY_STATE_OMEGAS1,
+                        VELOCITY_STATE_OMEGAS2, VELOCITY_STATE_OMEGAS3):
+        cost += steer_gain * (X[omega_index])**2.0
+
+    half_width = coefficients.robot_width / 2.0
+    mounting_locations = jax.numpy.array([[half_width, half_width],
+                                          [-half_width, half_width],
+                                          [-half_width, -half_width],
+                                          [half_width, -half_width]])
+
+    Rtheta = R(X[VELOCITY_STATE_THETA])
+    outputs = [velocity_module_physics(coefficients, Rtheta, m,
+                                       mounting_locations[m], X, U)
+               for m in range(4)]
+    forces = [force for _, force, _ in outputs]
+    torques = [tq for _, _, tq in outputs]
+    F = forces[0] + forces[1] + forces[2] + forces[3]
+    torque = torques[0] + torques[1] + torques[2] + torques[3]
+
+    def tangential_force(moment, lever):
+        # Force perpendicular to lever whose moment (lever x force) is moment.
+        lever_sq = jax.numpy.inner(lever, lever)
+        return jax.numpy.array(
+            [-lever[1] * moment / lever_sq, lever[0] * moment / lever_sq])
+
+    # Penalize each module's deviation from a quarter of the net force plus a
+    # quarter of the net torque applied at its theta-rotated mounting location.
+    # TODO(austin): Are these penalties reasonable? Do they give us a decent time constant?
+    for module in range(4):
+        share = F / 4.0 + tangential_force(
+            torque / 4.0, Rtheta @ mounting_locations[module, :])
+        error = share - forces[module]
+        cost += 0.01 * jax.numpy.inner(error, error)
+
+    for module in range(4):
+        steer_current = U[2 * module + 0]
+        drive_current = U[2 * module + 1]
+        cost += ((steer_current + STEER_CURRENT_COUPLING_FACTOR *
+                  drive_current)**2.0) / 100000.0
+        cost += (drive_current**2.0) / 1000.0
+    return cost