Fix NaNs in physics
File "frc971/control_loops/swerve/jax_dynamics.py", line 116, in softsign
return -2 / (1 + jax.numpy.exp(gain * x)) + 1
jax._src.checkify.NaNError: nan generated by primitive: mul.
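Rewrite softsign using jax.nn.sigmoid, which computes the logistic function directly instead of building it from an explicit exp(), and fixes the issue. The new expression is algebraically identical to the old one, since sigmoid(-z) = 1 / (1 + exp(z)), so 1 - 2 * sigmoid(-gain * x) == -2 / (1 + exp(gain * x)) + 1. While here, soft_atan2 now reuses softsign instead of duplicating the same formula inline, and the wheel-velocity softsign gain is passed as a float (100.0).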
Change-Id: I7dbea6194a40c1f1036934c5e574a5c02cca5068
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 8b05ef4..76a00d5 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -115,14 +115,14 @@
def softsign(x, gain):
- return -2 / (1 + jax.numpy.exp(gain * x)) + 1
+ return 1 - 2.0 * jax.nn.sigmoid(-gain * x)
def soft_atan2(y, x):
kMaxLogGain = 1.0 / 0.05
kAbsLogGain = 1.0 / 0.01
- softabs_x = x * (1.0 - 2.0 / (1 + jax.numpy.exp(kAbsLogGain * x)))
+ softabs_x = x * softsign(x, kAbsLogGain)
return jax.numpy.arctan2(
y,
@@ -272,7 +272,7 @@
Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
Fwy = coefficients.Cy * slip_angle
- softsign_velocity = softsign(wheel_ground_velocity[0], 100)
+ softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)
Ms = -Fwy * (
(softsign_velocity * coefficients.contact_patch_length / 3.0) +