| #!/usr/bin/env python3 |
| |
| from functools import partial |
| from collections import namedtuple |
| import jax |
| |
| from frc971.control_loops.python.control_loop import KrakenFOC |
| from frc971.control_loops.swerve.dynamics_constants import * |
| |
| # Note: this physics needs to match the symengine code. We have tests that |
| # confirm they match in all the cases we care about. |
| |
# Immutable record of the physical constants consumed by the dynamics functions
# in this file.  Instances are passed as the static `coefficients` argument of
# the jitted full_dynamics() and velocity_dynamics().
CoefficientsType = namedtuple(
    'CoefficientsType',
    (
        # Gains from slip ratio / slip angle to the wheel-frame forces Fwx / Fwy.
        'Cx',
        'Cy',
        # Scales the drive angular velocity into a wheel surface speed
        # (presumably the wheel radius).
        'rw',
        # Divisors turning net force / torque into robot accelerations.
        'm',
        'J',
        # Drive train: Gd2 = rs / rp, Gd3 = rb1 / rb2, Gd = Gd1 * Gd2 * Gd3
        # (see Coefficients()).
        'Gd1',
        'rs',
        'rp',
        'Gd2',
        'rb1',
        'rb2',
        'Gd3',
        'Gd',
        # Steer-side inertia and ratio; they enter the steer acceleration as
        # Js / Gs**2 and Kts * Is / Gs.
        'Js',
        'Gs',
        # Offset used in the drive-to-steer coupling term of the steer
        # acceleration.
        'wb',
        # motor_inertia of the drive and steer motors.
        'Jdm',
        'Jsm',
        # Kt of the steer and drive motors.
        'Kts',
        'Ktd',
        # Modules are mounted at (+/- robot_width / 2, +/- robot_width / 2).
        'robot_width',
        # The contact patch sits at [-caster, 0] in the wheel frame.
        'caster',
        # Enters the steer moment as contact_patch_length / 3.
        'contact_patch_length',
    ),
)
| |
| |
def Coefficients(
    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
    Cy: float = 5 * 9.8 / 0.05 / 4.0,
    rw: float = 2 * 0.0254,

    # base is 20 kg without battery
    m: float = 25.0,
    J: float = 6.0,
    Gd1: float = 12.0 / 42.0,
    rs: float = 28.0 / 20.0 / 2.0,
    rp: float = 18.0 / 20.0 / 2.0,

    # 15 / 45 bevel ratio, calculated using python script ported over to
    # GetBevelPitchRadius(double)
    # TODO(Justin): Use the function instead of computed constants
    rb1: float = 0.3805473,
    rb2: float = 1.14164,
    Js: float = 0.001,
    Gs: float = 35.0 / 468.0,
    wb: float = 0.725,
    drive_motor=KrakenFOC(),
    steer_motor=KrakenFOC(),
    robot_width: float = 24.75 * 0.0254,
    caster: float = 0.01,
    contact_patch_length: float = 0.02,
) -> CoefficientsType:
    """Build a CoefficientsType, filling in the derived entries.

    Every argument is stored unchanged under the same name, except the two
    motor objects, which are only read for `motor_inertia` and `Kt`.

    Derived fields:
        Gd2 = rs / rp
        Gd3 = rb1 / rb2
        Gd  = Gd1 * Gd2 * Gd3
        Jdm, Ktd = drive_motor.motor_inertia, drive_motor.Kt
        Jsm, Kts = steer_motor.motor_inertia, steer_motor.Kt

    Returns:
        The populated CoefficientsType.
    """
    second_stage_ratio = rs / rp
    bevel_stage_ratio = rb1 / rb2

    return CoefficientsType(
        Cx=Cx,
        Cy=Cy,
        rw=rw,
        m=m,
        J=J,
        Gd1=Gd1,
        rs=rs,
        rp=rp,
        Gd2=second_stage_ratio,
        rb1=rb1,
        rb2=rb2,
        Gd3=bevel_stage_ratio,
        # Overall drive ratio is the product of the three stages.
        Gd=Gd1 * second_stage_ratio * bevel_stage_ratio,
        Js=Js,
        Gs=Gs,
        wb=wb,
        Jdm=drive_motor.motor_inertia,
        Jsm=steer_motor.motor_inertia,
        Kts=steer_motor.Kt,
        Ktd=drive_motor.Kt,
        robot_width=robot_width,
        caster=caster,
        contact_patch_length=contact_patch_length,
    )
| |
| |
def R(theta):
    """Return the 2x2 rotation matrix [[cos, -sin], [sin, cos]] for `theta`."""
    cos_theta = jax.numpy.cos(theta)
    sin_theta = jax.numpy.sin(theta)

    first_row = [cos_theta, -sin_theta]
    second_row = [sin_theta, cos_theta]
    return jax.numpy.array([first_row, second_row])
| |
| |
def angle_cross(vector, omega):
    """Return [-vector[1] * omega, vector[0] * omega].

    This is the 2D form of crossing an angular rate about the out-of-plane
    axis with `vector`, i.e. the velocity of a point at `vector` on a body
    spinning at `omega`.
    """
    x_component = -vector[1] * omega
    y_component = vector[0] * omega
    return jax.numpy.array([x_component, y_component])
| |
| |
def force_cross(r, f):
    """Return the scalar 2D cross product r[0] * f[1] - r[1] * f[0].

    With `r` a lever arm and `f` a force this is the out-of-plane torque.
    """
    x_times_fy = r[0] * f[1]
    y_times_fx = r[1] * f[0]
    return x_times_fy - y_times_fx
| |
| |
def softsign(x, gain):
    """Smooth stand-in for sign(x): 1 - 2 * sigmoid(-gain * x).

    The result is 0 at x == 0 and approaches +/-1 as gain * x grows in
    magnitude; a larger `gain` gives a sharper transition.
    """
    mirrored_sigmoid = jax.nn.sigmoid(-gain * x)
    return 1 - 2.0 * mirrored_sigmoid
| |
| |
def soft_atan2(y, x):
    """arctan2(y, x) with `x` replaced by a smooth, strictly positive |x|.

    `x` is first passed through a smooth absolute value and then through a
    smooth maximum against 1 / kMaxLogGain.  The second arctan2 argument is
    therefore always positive and bounded away from zero, so the result is
    well behaved when `x` is near zero.
    """
    kMaxLogGain = 1.0 / 0.05
    kAbsLogGain = 1.0 / 0.01

    # Smooth |x|: x times a smooth sign(x).
    smooth_abs_x = x * softsign(x, kAbsLogGain)

    # logsumexp([1, k * |x|]) / k is a smooth version of max(1 / k, |x|).
    exponents = jax.numpy.array([1.0, smooth_abs_x * kMaxLogGain])
    floored_abs_x = jax.scipy.special.logsumexp(exponents) / kMaxLogGain

    return jax.numpy.arctan2(y, floored_abs_x)
| |
| |
def full_module_physics(coefficients: CoefficientsType, Rtheta,
                        module_index: int, mounting_location, X, U):
    """Compute one swerve module's contribution to the full state derivative.

    Args:
        coefficients: Physical constants of the robot.
        Rtheta: 2x2 rotation matrix applied to `mounting_location`;
            full_dynamics() passes R(X[STATE_THETA]).
        module_index: Which module to evaluate.  It selects the 4-element
            slice of `X` and the 2-element slice of `U` belonging to the
            module.  The output layout hard-codes four modules, so this must
            be in 0..3.
        mounting_location: 2-vector position of the module before rotation by
            `Rtheta`.
        X: State vector indexed by the STATE_* constants.
        U: Input vector holding [Is, Id] per module (presumably steer and
            drive motor currents -- they are multiplied by the steer and
            drive motor Kt respectively).

    Returns:
        A 25-element vector laid out as four 4-element module blocks,
        3 zeros, 2 linear accelerations, then [angular acceleration, 0, 0, 0].
        Only this module's block
        [steer velocity, drive velocity, steer accel, drive accel] and its
        share F / m, torque / J of the robot accelerations are non-zero.
    """
    c = coefficients

    X_module = X[module_index * 4:(module_index + 1) * 4]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Rotation for the sum of the robot angle and this module's steer angle.
    Rthetaplusthetas = R(X[STATE_THETA] + X_module[STATE_THETAS0])

    caster_vector = jax.numpy.array([-c.caster, 0.0])

    robot_velocity = X[STATE_VX:STATE_VY + 1]

    # Velocity of the contact patch: robot translation, plus rotation of the
    # mounting point about the robot, plus rotation of the caster offset about
    # the steer axis (which turns at robot rate + steer rate).
    contact_patch_velocity = (
        angle_cross(Rtheta @ mounting_location, X[STATE_OMEGA]) +
        robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector,
                    (X[STATE_OMEGA] + X_module[STATE_OMEGAS0])))

    # Same velocity expressed in the wheel frame: [0] is along the wheel,
    # [1] is across it.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # Relative difference between wheel surface speed and ground speed.  The
    # 0.02 floor on the denominator keeps this finite near standstill.
    slip_ratio = (c.rw * X_module[STATE_OMEGAD0] -
                  wheel_ground_velocity[0]) / jax.numpy.max(
                      jax.numpy.array(
                          [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))

    # Wheel-frame forces.
    Fwx = c.Cx * slip_ratio
    Fwy = c.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100)

    # Moment about the steer axis from the lateral force, acting at a lever
    # arm of the caster plus a third of the contact patch length (signed by
    # the direction of travel).
    Ms = -Fwy * ((softsign_velocity * c.contact_patch_length / 3.0) +
                 c.caster)

    # Steer acceleration: lateral-force moment, steer motor term and the
    # coupling from the drive force, over the steer-side inertia term.
    alphas = (Ms + c.Kts * Is / c.Gs +
              (-c.wb + (c.rs + c.rp) * (1 - c.rb1 / c.rp)) *
              (c.rw / c.rb2 * (-Fwx))) / (c.Jsm + (c.Js / (c.Gs * c.Gs)))

    # Then solve for alphad
    alphad = (c.rs * c.Jdm * c.Gd3 * alphas + c.rp * c.Ktd * Id * c.Gd -
              c.rw * c.rp * c.Gd * Fwx * c.Gd) / (c.rp * c.Jdm)

    # Rotate the wheel-frame force by the same robot + steer rotation.
    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    torque = force_cross(Rtheta @ mounting_location, F)

    module_block = jax.numpy.array([
        X_module[STATE_OMEGAS0],
        X_module[STATE_OMEGAD0],
        alphas,
        alphad,
    ])
    empty_block = jax.numpy.zeros((4, ))

    # Zero blocks for the other modules go before and after this module's
    # block, followed by the robot-level entries.
    segments = ((empty_block, ) * (module_index) + (module_block, ) +
                (empty_block, ) * (3 - module_index) + (
                    jax.numpy.zeros((3, )),
                    F / c.m,
                    jax.numpy.array([torque / c.J, 0, 0, 0]),
                ))

    return jax.numpy.hstack(segments)
| |
| |
@partial(jax.jit, static_argnames=['coefficients'])
def full_dynamics(coefficients: CoefficientsType, X, U):
    """Return the time derivative of the full state `X` under inputs `U`.

    The contributions of the four modules, mounted at the corners of a square
    of side `coefficients.robot_width`, are summed.  The X / Y / THETA entries
    of the derivative are then overwritten with the VX / VY / OMEGA entries of
    the state.

    `coefficients` is a static jit argument, so each distinct value triggers
    its own trace.
    """
    Rtheta = R(X[STATE_THETA])
    half_width = coefficients.robot_width / 2.0

    # Signs of the mounting location for modules 0..3, in index order.
    corner_signs = ((1.0, 1.0), (-1.0, 1.0), (-1.0, -1.0), (1.0, -1.0))

    X_dot = None
    for module_index, (sign_x, sign_y) in enumerate(corner_signs):
        mounting_location = jax.numpy.array(
            [sign_x * half_width, sign_y * half_width])
        contribution = full_module_physics(coefficients, Rtheta, module_index,
                                           mounting_location, X, U)
        X_dot = contribution if X_dot is None else X_dot + contribution

    pose_rates = jax.numpy.array([
        X[STATE_VX],
        X[STATE_VY],
        X[STATE_OMEGA],
    ])

    return X_dot.at[STATE_X:STATE_THETA + 1].set(pose_rates)
| |
| |
def velocity_module_physics(coefficients: CoefficientsType,
                            Rtheta: jax.typing.ArrayLike, module_index: int,
                            mounting_location: jax.typing.ArrayLike,
                            X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
    """Compute one module's contribution to the velocity-state derivative.

    Unlike full_module_physics(), the force along the wheel is taken directly
    from the drive input (Ktd / (Gd * rw) * Id) instead of from a slip ratio,
    and the module state has only two entries.

    Args:
        coefficients: Physical constants of the robot.
        Rtheta: 2x2 rotation matrix applied to `mounting_location`.
        module_index: Module to evaluate; selects the 2-element slice of `X`
            and the 2-element slice of `U`.  The output layout hard-codes four
            modules, so this must be in 0..3.
        mounting_location: 2-vector module position before rotation.
        X: Velocity state vector indexed by the VELOCITY_STATE_* constants.
        U: Input vector holding [Is, Id] per module.

    Returns:
        A tuple (X_dot_contribution, F, torque): the 12-element derivative
        contribution, the 2-vector force rotated by robot + steer angle, and
        the scalar torque of that force about the rotated mounting location.
    """
    c = coefficients

    X_module = X[module_index * 2:(module_index + 1) * 2]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    rotated_mounting_location = Rtheta @ mounting_location

    Rthetaplusthetas = R(X[VELOCITY_STATE_THETA] +
                         X_module[VELOCITY_STATE_THETAS0])

    caster_vector = jax.numpy.array([-c.caster, 0.0])

    robot_velocity = X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]

    # The steer axis turns at robot rate + steer rate.
    steer_axis_rate = (X[VELOCITY_STATE_OMEGA] +
                       X_module[VELOCITY_STATE_OMEGAS0])
    contact_patch_velocity = (
        angle_cross(rotated_mounting_location, X[VELOCITY_STATE_OMEGA]) +
        robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector, steer_axis_rate))

    # Velocity of the contact patch over the field projected into the direction
    # of the wheel.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    softened_angle = soft_atan2(wheel_ground_velocity[1],
                                wheel_ground_velocity[0])
    slip_angle = jax.numpy.sin(-softened_angle)

    # Wheel-frame forces.
    Fwx = (c.Ktd / (c.Gd * c.rw)) * Id
    Fwy = c.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)

    # Moment about the steer axis from the lateral force.
    Ms = -Fwy * ((softsign_velocity * c.contact_patch_length / 3.0) +
                 c.caster)

    alphas = (Ms + c.Kts * Is / c.Gs +
              (-c.wb + (c.rs + c.rp) * (1 - c.rb1 / c.rp)) *
              (c.rw / c.rb2 * (-Fwx))) / (c.Jsm + (c.Js / (c.Gs * c.Gs)))

    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    torque = force_cross(rotated_mounting_location, F)

    module_block = jax.numpy.array([
        X_module[VELOCITY_STATE_OMEGAS0],
        alphas,
    ])
    empty_block = jax.numpy.zeros((2, ))

    # Zero blocks for the other modules surround this module's block, then
    # come one zero, the linear accelerations and the angular acceleration.
    segments = ((empty_block, ) * (module_index) + (module_block, ) +
                (empty_block, ) * (3 - module_index) + (
                    jax.numpy.zeros((1, )),
                    F / c.m,
                    jax.numpy.array([torque / c.J]),
                ))
    X_dot_contribution = jax.numpy.hstack(segments)

    return X_dot_contribution, F, torque
| |
| |
@partial(jax.jit, static_argnames=['coefficients'])
def velocity_dynamics(coefficients: CoefficientsType, X: jax.typing.ArrayLike,
                      U: jax.typing.ArrayLike):
    """Return the time derivative of the velocity state `X` under inputs `U`.

    Sums the derivative contributions of the four modules (the per-module
    force and torque are discarded), then sets the THETA entry of the
    derivative to the OMEGA entry of the state.

    `coefficients` is a static jit argument, so each distinct value triggers
    its own trace.
    """
    Rtheta = R(X[VELOCITY_STATE_THETA])
    half_width = coefficients.robot_width / 2.0

    # Signs of the mounting location for modules 0..3, in index order.
    corner_signs = ((1.0, 1.0), (-1.0, 1.0), (-1.0, -1.0), (1.0, -1.0))

    X_dot = None
    for module_index, (sign_x, sign_y) in enumerate(corner_signs):
        mounting_location = jax.numpy.array(
            [sign_x * half_width, sign_y * half_width])
        contribution, _, _ = velocity_module_physics(coefficients, Rtheta,
                                                     module_index,
                                                     mounting_location, X, U)
        X_dot = contribution if X_dot is None else X_dot + contribution

    return X_dot.at[VELOCITY_STATE_THETA].set(X[VELOCITY_STATE_OMEGA])
| |
| |
def to_velocity_state(X):
    """Pick the velocity-state entries out of a full state vector `X`.

    The result holds, in order, (THETAS, OMEGAS) for modules 0 through 3
    followed by THETA, VX, VY and OMEGA -- 12 entries in total.
    """
    kept_indices = (
        STATE_THETAS0,
        STATE_OMEGAS0,
        STATE_THETAS1,
        STATE_OMEGAS1,
        STATE_THETAS2,
        STATE_OMEGAS2,
        STATE_THETAS3,
        STATE_OMEGAS3,
        STATE_THETA,
        STATE_VX,
        STATE_VY,
        STATE_OMEGA,
    )
    return jax.numpy.array([X[index] for index in kept_indices])