Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
| 2 | |
| 3 | from functools import partial |
| 4 | from collections import namedtuple |
| 5 | import jax |
| 6 | |
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 7 | from frc971.control_loops.python.control_loop import KrakenFOC |
Austin Schuh | a9550c0 | 2024-10-19 13:48:10 -0700 | [diff] [blame^] | 8 | from frc971.control_loops.swerve.dynamics_constants import * |
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 9 | |
| 10 | # Note: this physics needs to match the symengine code. We have tests that |
| 11 | # confirm they match in all the cases we care about. |
| 12 | |
# Physical parameters consumed by the swerve dynamics in this file. The field
# order below is the positional order of the namedtuple, so keep it stable.
# Being a namedtuple of plain numbers keeps it hashable, which the jitted
# dynamics rely on (it is passed as a static argument).
CoefficientsType = namedtuple(
    'CoefficientsType',
    [
        # Tire / wheel: Cx scales slip ratio and Cy scales slip angle into
        # force; rw converts wheel angular velocity to linear velocity.
        'Cx', 'Cy', 'rw',
        # Robot body mass and yaw moment of inertia.
        'm', 'J',
        # Drive gear train; Coefficients() sets Gd2 = rs / rp,
        # Gd3 = rb1 / rb2 and Gd = Gd1 * Gd2 * Gd3.
        'Gd1', 'rs', 'rp', 'Gd2', 'rb1', 'rb2', 'Gd3', 'Gd',
        # Steering inertia and ratio, plus wb (used in the steer/drive
        # coupling term of the steering acceleration).
        'Js', 'Gs', 'wb',
        # Motor inertias (drive, steer) and torque constants (steer, drive).
        'Jdm', 'Jsm', 'Kts', 'Ktd',
        # Geometry.
        'robot_width', 'caster', 'contact_patch_length',
    ])
| 38 | |
| 39 | |
def Coefficients(
    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
    Cy: float = 5 * 9.8 / 0.05 / 4.0,
    rw: float = 2 * 0.0254,

    # base is 20 kg without battery
    m: float = 25.0,
    J: float = 6.0,
    Gd1: float = 12.0 / 42.0,
    rs: float = 28.0 / 20.0 / 2.0,
    rp: float = 18.0 / 20.0 / 2.0,

    # 15 / 45 bevel ratio, calculated using python script ported over to
    # GetBevelPitchRadius(double)
    # TODO(Justin): Use the function instead of computed constants
    rb1: float = 0.3805473,
    rb2: float = 1.14164,
    Js: float = 0.001,
    Gs: float = 35.0 / 468.0,
    wb: float = 0.725,
    drive_motor=KrakenFOC(),
    steer_motor=KrakenFOC(),
    robot_width: float = 24.75 * 0.0254,
    caster: float = 0.01,
    contact_patch_length: float = 0.02,
) -> CoefficientsType:
    """Build a CoefficientsType, filling in the derived quantities.

    Every keyword argument is copied through unchanged. In addition:
      Gd2 = rs / rp
      Gd3 = rb1 / rb2
      Gd  = Gd1 * Gd2 * Gd3
      Jdm, Ktd = drive_motor.motor_inertia, drive_motor.Kt
      Jsm, Kts = steer_motor.motor_inertia, steer_motor.Kt

    The default motor objects are created once, when this function is
    defined, and are shared by every call that does not pass its own; they
    are only read here.
    """
    gear_ratio_2 = rs / rp
    gear_ratio_3 = rb1 / rb2

    return CoefficientsType(
        Cx=Cx,
        Cy=Cy,
        rw=rw,
        m=m,
        J=J,
        Gd1=Gd1,
        rs=rs,
        rp=rp,
        Gd2=gear_ratio_2,
        rb1=rb1,
        rb2=rb2,
        Gd3=gear_ratio_3,
        Gd=Gd1 * gear_ratio_2 * gear_ratio_3,
        Js=Js,
        Gs=Gs,
        wb=wb,
        Jdm=drive_motor.motor_inertia,
        Jsm=steer_motor.motor_inertia,
        Kts=steer_motor.Kt,
        Ktd=drive_motor.Kt,
        robot_width=robot_width,
        caster=caster,
        contact_patch_length=contact_patch_length,
    )
| 101 | |
| 102 | |
def R(theta):
    """Return the 2x2 counter-clockwise rotation matrix for angle theta.

    theta is in radians (it is passed straight to jax.numpy.cos/sin).
    """
    cos_theta = jax.numpy.cos(theta)
    sin_theta = jax.numpy.sin(theta)
    first_row = [cos_theta, -sin_theta]
    second_row = [sin_theta, cos_theta]
    return jax.numpy.array([first_row, second_row])
| 107 | |
| 108 | |
def angle_cross(vector, omega):
    """Return the planar cross product (omega * z_hat) x vector.

    This is the velocity of a point at `vector` on a body spinning at the
    scalar rate `omega` about +z: (-vector[1] * omega, vector[0] * omega).
    """
    x_component = -vector[1] * omega
    y_component = vector[0] * omega
    return jax.numpy.array([x_component, y_component])
| 111 | |
| 112 | |
def force_cross(r, f):
    """Return the z component of the planar cross product r x f.

    With r a lever arm and f a force, this is the torque about +z.
    """
    rx, ry = r[0], r[1]
    fx, fy = f[0], f[1]
    return rx * fy - ry * fx
| 115 | |
| 116 | |
def softsign(x, gain):
    """Smooth approximation of sign(x), ranging over (-1, 1).

    Computed as 1 - 2 * sigmoid(-gain * x), which equals tanh(gain * x / 2).
    Larger `gain` gives a sharper transition around x = 0.
    """
    scaled = -gain * x
    return 1 - 2.0 * jax.nn.sigmoid(scaled)
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 119 | |
| 120 | |
def soft_atan2(y, x):
    """Smooth, always-differentiable stand-in for arctan2(y, |x|).

    The denominator is a smoothed |x| that is bounded away from zero:
      1. |x| is approximated by x * softsign(x, kAbsLogGain).
      2. That value is softly floored with
         log(exp(1) + exp(kMaxLogGain * |x|)) / kMaxLogGain, which is always
         greater than 1 / kMaxLogGain (= 0.05).
    Because the denominator is positive, the result lies in (-pi/2, pi/2);
    the sign of x is intentionally discarded.

    Works elementwise on array inputs as well as on scalars.

    Note: this physics needs to match the symengine code.
    """
    kMaxLogGain = 1.0 / 0.05
    kAbsLogGain = 1.0 / 0.01

    softabs_x = x * softsign(x, kAbsLogGain)

    # logaddexp(a, b) is logsumexp([a, b]). Unlike packing the operands into
    # a length-2 array, it broadcasts, so x may be an array.
    return jax.numpy.arctan2(
        y,
        jax.numpy.logaddexp(1.0, softabs_x * kMaxLogGain) / kMaxLogGain)
| 131 | |
| 132 | |
def full_module_physics(coefficients: CoefficientsType,
                        Rtheta: jax.typing.ArrayLike, module_index: int,
                        mounting_location: jax.typing.ArrayLike, X, U):
    """Contribution of one swerve module to the full-state derivative.

    Args:
      coefficients: physical parameters.
      Rtheta: 2x2 rotation matrix for the robot heading, used to rotate
        `mounting_location` out of the robot frame.
      module_index: which module, expected in [0, 3]. The output layout
        reserves exactly four module slots.
      mounting_location: module position relative to the robot center, in
        the robot frame.
      X: full state. This module's 4 states are X[4 * i:4 * i + 4].
      U: currents, two per module: U[2 * i] is the steer current (scaled by
        Kts) and U[2 * i + 1] is the drive current (scaled by Ktd).

    Returns:
      A vector of length 4 * 4 + 3 + 2 + 4. It is zero except for this
      module's slot [omegas, omegad, alphas, alphad], then F / m (2 entries)
      after 3 zeros, then [torque / J, 0, 0, 0]. Summing the four modules
      gives the full derivative apart from the pose-rate entries.

    Note: this physics needs to match the symengine code.
    """
    X_module = X[module_index * 4:(module_index + 1) * 4]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Mounting location rotated by the robot heading. It is used both for the
    # contact patch velocity and for the torque about the robot center.
    rotated_mounting_location = Rtheta @ mounting_location

    Rthetaplusthetas = R(X[STATE_THETA] + X_module[STATE_THETAS0])

    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[STATE_VX:STATE_VY + 1]

    # Velocity of the contact patch: robot rotation about the mounting point,
    # robot translation, and module rotation about the caster offset.
    contact_patch_velocity = (
        angle_cross(rotated_mounting_location, X[STATE_OMEGA]) +
        robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector,
                    (X[STATE_OMEGA] + X_module[STATE_OMEGAS0])))

    # Contact patch velocity expressed in the wheel's own frame.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # The denominator is floored at 0.02 so the slip ratio stays finite when
    # the wheel is nearly stationary over the ground.
    slip_ratio = (coefficients.rw * X_module[STATE_OMEGAD0] -
                  wheel_ground_velocity[0]) / jax.numpy.max(
                      jax.numpy.array(
                          [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))

    Fwx = coefficients.Cx * slip_ratio
    Fwy = coefficients.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)

    # Aligning moment: lateral force acting at the caster offset plus a
    # velocity-signed contact_patch_length / 3 lever arm.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer acceleration from the aligning moment, the steer motor torque and
    # the drive-force coupling through the rs/rp/rb1/rb2/wb geometry, divided
    # by the effective steer inertia Jsm + Js / Gs^2.
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    # Then solve for alphad
    alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
              coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
              coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
              coefficients.Gd) / (coefficients.rp * coefficients.Jdm)

    # Tire force rotated out of the wheel frame.
    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    torque = force_cross(rotated_mounting_location, F)

    X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
        (4, )), ) * (module_index) + (jax.numpy.array([
            X_module[STATE_OMEGAS0],
            X_module[STATE_OMEGAD0],
            alphas,
            alphad,
        ]), ) + (jax.numpy.zeros((4, )), ) * (3 - module_index) + (
            jax.numpy.zeros((3, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
        ))

    return X_dot_contribution
| 206 | |
| 207 | |
@partial(jax.jit, static_argnames=['coefficients'])
def full_dynamics(coefficients: CoefficientsType, X, U):
    """Time derivative of the full swerve state X under module currents U.

    Sums the four per-module contributions from full_module_physics. It then
    overwrites the pose-rate entries STATE_X..STATE_THETA with the current
    velocities X[STATE_VX], X[STATE_VY] and X[STATE_OMEGA].

    `coefficients` is a static jit argument, so it must be hashable; passing
    a different value triggers a retrace.
    """
    Rtheta = R(X[STATE_THETA])

    half_width = coefficients.robot_width / 2.0
    # Robot-frame mounting locations, indexed by module number.
    mounting_locations = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )

    contributions = [
        full_module_physics(coefficients, Rtheta, module_index,
                            jax.numpy.array(location), X, U)
        for module_index, location in enumerate(mounting_locations)
    ]

    # Accumulate left to right: module0 + module1 + module2 + module3.
    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    pose_rates = jax.numpy.array([
        X[STATE_VX],
        X[STATE_VY],
        X[STATE_OMEGA],
    ])
    return X_dot.at[STATE_X:STATE_THETA + 1].set(pose_rates)
| 243 | |
| 244 | |
def velocity_module_physics(coefficients: CoefficientsType,
                            Rtheta: jax.typing.ArrayLike, module_index: int,
                            mounting_location: jax.typing.ArrayLike,
                            X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
    """Contribution of one swerve module to the velocity-state derivative.

    Unlike full_module_physics, the drive force Fwx is taken directly from
    the drive current (Ktd / (Gd * rw) * Id). There is no longitudinal slip
    model and no drive-wheel state.

    Args:
      coefficients: physical parameters.
      Rtheta: 2x2 rotation matrix for the robot heading.
      module_index: which module, expected in [0, 3]. The output layout
        reserves exactly four module slots.
      mounting_location: module position relative to the robot center, in
        the robot frame.
      X: velocity state. This module's 2 states are X[2 * i:2 * i + 2].
      U: currents, two per module: U[2 * i] is the steer current (scaled by
        Kts) and U[2 * i + 1] is the drive current (scaled by Ktd).

    Returns:
      (X_dot_contribution, F, torque):
        X_dot_contribution: vector of length 4 * 2 + 1 + 2 + 1. It is zero
          except for this module's slot [omegas, alphas], then F / m
          (2 entries) after 1 zero, then torque / J.
        F: the tire force rotated out of the wheel frame by
          Rthetaplusthetas.
        torque: the z torque of F about the robot center.

    Note: this physics needs to match the symengine code.
    """
    X_module = X[module_index * 2:(module_index + 1) * 2]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Mounting location rotated by the robot heading; reused for the torque.
    rotated_mounting_location = Rtheta @ mounting_location

    Rthetaplusthetas = R(X[VELOCITY_STATE_THETA] +
                         X_module[VELOCITY_STATE_THETAS0])

    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]

    # Robot rotation about the mounting point, robot translation, and module
    # rotation about the caster offset.
    contact_patch_velocity = (
        angle_cross(rotated_mounting_location, X[VELOCITY_STATE_OMEGA]) +
        robot_velocity + angle_cross(
            Rthetaplusthetas @ caster_vector,
            (X[VELOCITY_STATE_OMEGA] + X_module[VELOCITY_STATE_OMEGAS0])))

    # Velocity of the contact patch over the field projected into the direction
    # of the wheel.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # Drive force straight from drive motor torque through the drive ratio
    # and wheel radius (no longitudinal slip in this model).
    Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
    Fwy = coefficients.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)

    # Aligning moment: lateral force acting at the caster offset plus a
    # velocity-signed contact_patch_length / 3 lever arm.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer acceleration; same expression as in full_module_physics.
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    torque = force_cross(rotated_mounting_location, F)

    X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
        (2, )), ) * (module_index) + (jax.numpy.array([
            X_module[VELOCITY_STATE_OMEGAS0],
            alphas,
        ]), ) + (jax.numpy.zeros((2, )), ) * (3 - module_index) + (
            jax.numpy.zeros((1, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J]),
        ))

    return X_dot_contribution, F, torque
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 307 | |
| 308 | |
@partial(jax.jit, static_argnames=['coefficients'])
def velocity_dynamics(coefficients: CoefficientsType, X: jax.typing.ArrayLike,
                      U: jax.typing.ArrayLike):
    """Time derivative of the velocity state X under module currents U.

    Sums the X_dot contributions of the four modules from
    velocity_module_physics; their force and torque outputs are discarded.
    It then sets the heading derivative to the yaw rate
    X[VELOCITY_STATE_OMEGA].

    `coefficients` is a static jit argument, so it must be hashable; passing
    a different value triggers a retrace.
    """
    Rtheta = R(X[VELOCITY_STATE_THETA])

    half_width = coefficients.robot_width / 2.0
    # Robot-frame mounting locations, indexed by module number.
    mounting_locations = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )

    contributions = [
        velocity_module_physics(coefficients, Rtheta, module_index,
                                jax.numpy.array(location), X, U)[0]
        for module_index, location in enumerate(mounting_locations)
    ]

    # Accumulate left to right: module0 + module1 + module2 + module3.
    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    return X_dot.at[VELOCITY_STATE_THETA].set(X[VELOCITY_STATE_OMEGA])
| 338 | |
| 339 | |
def to_velocity_state(X):
    """Project a full state vector onto the 12-element velocity state.

    Output order: (thetas, omegas) for modules 0..3, then theta, vx, vy,
    omega. Every other entry of X is dropped.
    """
    selected = (
        STATE_THETAS0, STATE_OMEGAS0,
        STATE_THETAS1, STATE_OMEGAS1,
        STATE_THETAS2, STATE_OMEGAS2,
        STATE_THETAS3, STATE_OMEGAS3,
        STATE_THETA,
        STATE_VX,
        STATE_VY,
        STATE_OMEGA,
    )
    return jax.numpy.array([X[index] for index in selected])