Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
| 2 | |
| 3 | from functools import partial |
| 4 | from collections import namedtuple |
| 5 | import jax |
| 6 | |
| 7 | from frc971.control_loops.swerve import dynamics |
| 8 | from frc971.control_loops.python.control_loop import KrakenFOC |
| 9 | |
| 10 | # Note: this physics needs to match the symengine code. We have tests that |
| 11 | # confirm they match in all the cases we care about. |
| 12 | |
# Immutable bundle of swerve physical parameters, normally built by
# Coefficients(). It is a namedtuple (hashable) because full_dynamics and
# velocity_dynamics pass it to jax.jit as a static argument.
CoefficientsType = namedtuple(
    'CoefficientsType',
    # Tire stiffnesses, wheel radius, robot mass and yaw inertia.
    'Cx Cy rw m J '
    # Drive/steer gearing geometry and the ratios derived from it.
    'Gd1 rs rp Gd2 rb1 rb2 Gd3 Gd '
    # Steer inertia/ratio, wb, motor inertias and torque constants.
    'Js Gs wb Jdm Jsm Kts Ktd '
    # Chassis and contact-patch geometry.
    'robot_width caster contact_patch_length')
| 38 | |
| 39 | |
def Coefficients(
    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
    Cy: float = 5 * 9.8 / 0.05 / 4.0,
    rw: float = 2 * 0.0254,

    # base is 20 kg without battery
    m: float = 25.0,
    J: float = 6.0,
    Gd1: float = 12.0 / 42.0,
    rs: float = 28.0 / 20.0 / 2.0,
    rp: float = 18.0 / 20.0 / 2.0,

    # 15 / 45 bevel ratio, calculated using python script ported over to
    # GetBevelPitchRadius(double)
    # TODO(Justin): Use the function instead of computed constants
    rb1: float = 0.3805473,
    rb2: float = 1.14164,
    Js: float = 0.001,
    Gs: float = 35.0 / 468.0,
    wb: float = 0.725,
    drive_motor=KrakenFOC(),
    steer_motor=KrakenFOC(),
    robot_width: float = 24.75 * 0.0254,
    caster: float = 0.01,
    contact_patch_length: float = 0.02,
) -> CoefficientsType:
    """Build a CoefficientsType from physical parameters.

    Every keyword argument is copied through unchanged, except that the
    following fields are derived:
      Gd2 = rs / rp
      Gd3 = rb1 / rb2
      Gd  = Gd1 * Gd2 * Gd3
      Jdm, Ktd = drive_motor.motor_inertia, drive_motor.Kt
      Jsm, Kts = steer_motor.motor_inertia, steer_motor.Kt

    The defaults appear to be SI (inch values are scaled by 0.0254).
    NOTE(review): the default motor objects are created once, at import time,
    and shared by all calls. They are only read here.
    """
    gd2 = rs / rp
    gd3 = rb1 / rb2

    return CoefficientsType(
        Cx=Cx,
        Cy=Cy,
        rw=rw,
        m=m,
        J=J,
        Gd1=Gd1,
        rs=rs,
        rp=rp,
        Gd2=gd2,
        rb1=rb1,
        rb2=rb2,
        Gd3=gd3,
        # Overall drive ratio is the product of the three stages.
        Gd=Gd1 * gd2 * gd3,
        Js=Js,
        Gs=Gs,
        wb=wb,
        Jdm=drive_motor.motor_inertia,
        Jsm=steer_motor.motor_inertia,
        Kts=steer_motor.Kt,
        Ktd=drive_motor.Kt,
        robot_width=robot_width,
        caster=caster,
        contact_patch_length=contact_patch_length,
    )
| 101 | |
| 102 | |
def R(theta):
    """Return the 2x2 counter-clockwise rotation matrix for angle theta.

    theta is in radians (it goes straight into jax.numpy.cos/sin).
    R(theta) @ v rotates v by theta; R(theta).T @ v undoes that rotation.
    """
    cos_t = jax.numpy.cos(theta)
    sin_t = jax.numpy.sin(theta)
    first_row = [cos_t, -sin_t]
    second_row = [sin_t, cos_t]
    return jax.numpy.array([first_row, second_row])
| 107 | |
| 108 | |
def angle_cross(vector, omega):
    """Return (omega * z_hat) x vector for a 2D vector.

    This is the velocity of a point at `vector` rotating at angular rate
    `omega` about the origin.
    """
    vx, vy = vector[0], vector[1]
    return jax.numpy.array([-vy * omega, vx * omega])
| 111 | |
| 112 | |
def force_cross(r, f):
    """Return the scalar 2D cross product r x f (the z component).

    With r as a lever arm and f as a force, this is the torque about the
    origin.
    """
    rx, ry = r[0], r[1]
    fx, fy = f[0], f[1]
    return rx * fy - ry * fx
| 115 | |
| 116 | |
def softsign(x, gain):
    """Smooth, differentiable approximation of sign(x).

    Returns values in (-1, 1): 0 at x == 0, approaching +/-1 as |gain * x|
    grows. Mathematically this equals tanh(gain * x / 2).
    """
    negative_share = jax.nn.sigmoid(-gain * x)
    return 1 - 2.0 * negative_share
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 119 | |
| 120 | |
def soft_atan2(y, x):
    """Smooth variant of atan2(y, x) that stays differentiable near x == 0.

    The x argument is replaced by a smooth max(0.05, |x|), so the result
    always lies in (-pi/2, pi/2):
      * |x| is smoothed as x * softsign(x, 100);
      * the max is a log-sum-exp with sharpness 1 / 0.05. The literal 1.0
        inside the logsumexp, divided by that sharpness, is the 0.05 floor.
    """
    max_sharpness = 1.0 / 0.05
    abs_sharpness = 1.0 / 0.01

    smooth_abs_x = x * softsign(x, abs_sharpness)
    smooth_x = jax.scipy.special.logsumexp(
        jax.numpy.array([1.0, smooth_abs_x * max_sharpness])) / max_sharpness

    return jax.numpy.arctan2(y, smooth_x)
| 131 | |
| 132 | |
def full_module_physics(coefficients: 'CoefficientsType', Rtheta,
                        module_index: int, mounting_location, X, U):
    """Compute one swerve module's contribution to the full-state derivative.

    full_dynamics sums the contributions of all four modules. Entries that
    belong to other modules are returned as zeros.

    Args:
      coefficients: A CoefficientsType. It is read via attribute access, so a
        plain dict does not work.
      Rtheta: 2x2 rotation matrix for the robot heading, as computed by
        full_dynamics (R(X[dynamics.STATE_THETA])).
      module_index: Module number in [0, 3]. Selects the four module states
        X[4 * module_index:4 * module_index + 4], the steer input
        U[2 * module_index] (Is) and the drive input U[2 * module_index + 1]
        (Id).
      mounting_location: Length-2 position of the module in the robot frame.
        It is rotated by Rtheta before being used as a velocity lever arm.
      X: Full state vector, indexed with the dynamics.STATE_* constants.
      U: Input vector with two entries per module. These are presumably motor
        currents, since they are multiplied by the motor Kt values (confirm).

    Returns:
      A vector laid out as:
        * 4 zeros per module before this one;
        * this module's [omegas, omegad, alphas, alphad];
        * 4 zeros per module after this one;
        * 3 zeros, then F / m (2 entries), then [torque / J, 0, 0, 0].

    The math must stay term-for-term identical to the symengine model; tests
    compare the two.
    """
    X_module = X[module_index * 4:(module_index + 1) * 4]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Wheel orientation: robot heading plus the module's steer angle.
    Rthetaplusthetas = R(X[dynamics.STATE_THETA] +
                         X_module[dynamics.STATE_THETAS0])

    # The contact patch sits `caster` behind the steer axis along the
    # wheel's -x direction.
    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[dynamics.STATE_VX:dynamics.STATE_VY + 1]

    # Contact patch velocity is the sum of three terms:
    #   * robot yaw applied at the mounting point;
    #   * robot translation;
    #   * the caster offset swinging at the combined robot + steer rate.
    contact_patch_velocity = (
        angle_cross(Rtheta @ mounting_location, X[dynamics.STATE_OMEGA]) +
        robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector,
                    (X[dynamics.STATE_OMEGA] +
                     X_module[dynamics.STATE_OMEGAS0])))

    # The same velocity expressed in the wheel's frame (x along the wheel).
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    # Lateral slip. soft_atan2 keeps this differentiable at zero forward
    # speed.
    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # Longitudinal slip ratio. The denominator is floored at 0.02 so it never
    # divides by ~0 at low speed.
    slip_ratio = (coefficients.rw * X_module[dynamics.STATE_OMEGAD0] -
                  wheel_ground_velocity[0]) / jax.numpy.max(
                      jax.numpy.array(
                          [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))

    # Linear tire model: forces proportional to slip.
    Fwx = coefficients.Cx * slip_ratio
    Fwy = coefficients.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)

    # Moment about the steer axis from the lateral force. The lever arm is
    # the caster offset plus contact_patch_length / 3, and the sign of that
    # second term follows the wheel's forward rolling direction.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer angular acceleration. The numerator is Ms, the steer motor term
    # and a drive-force coupling term through the rb1/rb2 gearing.
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    # Then solve for alphad. It includes a rs * Gd3 / rp coupling to alphas.
    alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
              coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
              coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
              coefficients.Gd) / (coefficients.rp * coefficients.Jdm)

    # Wheel force rotated out of the wheel frame by the robot + steer angle.
    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    # NOTE(review): mounting_location is in the robot frame, while F has been
    # rotated by theta + thetas. For theta != 0 the consistent lever arm
    # would be Rtheta @ mounting_location. Behavior is kept as-is because it
    # must match the symengine model — confirm which is intended.
    torque = force_cross(mounting_location, F)

    zeros_before = (jax.numpy.zeros((4, )), ) * module_index
    this_module = (jax.numpy.array([
        X_module[dynamics.STATE_OMEGAS0],
        X_module[dynamics.STATE_OMEGAD0],
        alphas,
        alphad,
    ]), )
    zeros_after = (jax.numpy.zeros((4, )), ) * (3 - module_index)
    robot_terms = (
        jax.numpy.zeros((3, )),
        F / coefficients.m,
        jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
    )

    X_dot_contribution = jax.numpy.hstack(zeros_before + this_module +
                                          zeros_after + robot_terms)

    return X_dot_contribution
| 207 | |
| 208 | |
@partial(jax.jit, static_argnames=['coefficients'])
def full_dynamics(coefficients: CoefficientsType, X, U):
    """Return dX/dt for the full swerve state X under input U.

    Sums full_module_physics over the four modules, then overwrites the
    three entries STATE_X..STATE_THETA with [VX, VY, OMEGA]. coefficients is
    a jax.jit static argument, so it must be hashable (a CoefficientsType
    is).
    """
    half_width = coefficients.robot_width / 2.0
    # Module mounting points in the robot frame, in module-index order.
    mounting_locations = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )
    Rtheta = R(X[dynamics.STATE_THETA])

    contributions = [
        full_module_physics(coefficients, Rtheta, module_index,
                            jax.numpy.array(location), X, U)
        for module_index, location in enumerate(mounting_locations)
    ]

    # Accumulate left to right: ((m0 + m1) + m2) + m3.
    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    pose_rates = jax.numpy.array([
        X[dynamics.STATE_VX],
        X[dynamics.STATE_VY],
        X[dynamics.STATE_OMEGA],
    ])
    return X_dot.at[dynamics.STATE_X:dynamics.STATE_THETA + 1].set(pose_rates)
| 244 | |
| 245 | |
def velocity_module_physics(coefficients: 'CoefficientsType', Rtheta,
                            module_index: int, mounting_location, X, U):
    """Compute one module's contribution to the velocity-state derivative.

    This is the reduced model. Each module has two states, and the drive
    force is taken directly from the drive input,
    Fwx = Ktd / (Gd * rw) * Id, instead of from a slip-ratio model.
    velocity_dynamics sums the contributions of all four modules.

    Args:
      coefficients: A CoefficientsType. It is read via attribute access, so a
        plain dict does not work.
      Rtheta: 2x2 rotation matrix for the robot heading, as computed by
        velocity_dynamics (R(X[dynamics.VELOCITY_STATE_THETA])).
      module_index: Module number in [0, 3]. Selects the two module states
        X[2 * module_index:2 * module_index + 2], the steer input
        U[2 * module_index] (Is) and the drive input U[2 * module_index + 1]
        (Id).
      mounting_location: Length-2 position of the module in the robot frame.
      X: Velocity state vector, indexed with dynamics.VELOCITY_STATE_*.
      U: Input vector with two entries per module.

    Returns:
      A vector laid out as:
        * 2 zeros per module before this one;
        * this module's [omegas, alphas];
        * 2 zeros per module after this one;
        * 1 zero, then F / m (2 entries), then [torque / J].
    """
    X_module = X[module_index * 2:(module_index + 1) * 2]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Wheel orientation: robot heading plus the module's steer angle.
    Rthetaplusthetas = R(X[dynamics.VELOCITY_STATE_THETA] +
                         X_module[dynamics.VELOCITY_STATE_THETAS0])

    # The contact patch sits `caster` behind the steer axis.
    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[dynamics.VELOCITY_STATE_VX:dynamics.VELOCITY_STATE_VY +
                       1]

    # Contact patch velocity is the sum of three terms:
    #   * robot yaw applied at the mounting point;
    #   * robot translation;
    #   * the caster offset swinging at the combined robot + steer rate.
    contact_patch_velocity = (
        angle_cross(Rtheta @ mounting_location,
                    X[dynamics.VELOCITY_STATE_OMEGA]) + robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector,
                    (X[dynamics.VELOCITY_STATE_OMEGA] +
                     X_module[dynamics.VELOCITY_STATE_OMEGAS0])))

    # The same velocity expressed in the wheel's frame.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # Drive force follows directly from the drive input (no slip model).
    Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
    Fwy = coefficients.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)

    # Moment about the steer axis from the lateral force. The lever arm is
    # the caster offset plus a contact_patch_length / 3 term whose sign
    # follows the rolling direction.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer angular acceleration; identical form to full_module_physics.
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    # NOTE(review): mounting_location is in the robot frame, while F has been
    # rotated by theta + thetas. For theta != 0 the consistent lever arm
    # would be Rtheta @ mounting_location. Behavior is kept as-is because it
    # must match the symengine model — confirm which is intended.
    torque = force_cross(mounting_location, F)

    zeros_before = (jax.numpy.zeros((2, )), ) * module_index
    this_module = (jax.numpy.array([
        X_module[dynamics.VELOCITY_STATE_OMEGAS0],
        alphas,
    ]), )
    zeros_after = (jax.numpy.zeros((2, )), ) * (3 - module_index)
    robot_terms = (
        jax.numpy.zeros((1, )),
        F / coefficients.m,
        jax.numpy.array([torque / coefficients.J]),
    )

    X_dot_contribution = jax.numpy.hstack(zeros_before + this_module +
                                          zeros_after + robot_terms)

    return X_dot_contribution
| 304 | |
| 305 | |
@partial(jax.jit, static_argnames=['coefficients'])
def velocity_dynamics(coefficients: CoefficientsType, X, U):
    """Return dX/dt for the reduced velocity state X under input U.

    Sums velocity_module_physics over the four modules, then sets the
    VELOCITY_STATE_THETA entry to X[VELOCITY_STATE_OMEGA]. coefficients is a
    jax.jit static argument, so it must be hashable.
    """
    half_width = coefficients.robot_width / 2.0
    # Module mounting points in the robot frame, in module-index order.
    mounting_locations = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )
    Rtheta = R(X[dynamics.VELOCITY_STATE_THETA])

    contributions = [
        velocity_module_physics(coefficients, Rtheta, module_index,
                                jax.numpy.array(location), X, U)
        for module_index, location in enumerate(mounting_locations)
    ]

    # Accumulate left to right: ((m0 + m1) + m2) + m3.
    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    heading_rate = X[dynamics.VELOCITY_STATE_OMEGA]
    return X_dot.at[dynamics.VELOCITY_STATE_THETA].set(heading_rate)