Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
| 2 | |
| 3 | from functools import partial |
| 4 | from collections import namedtuple |
| 5 | import jax |
| 6 | |
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 7 | from frc971.control_loops.python.control_loop import KrakenFOC |
Austin Schuh | a9550c0 | 2024-10-19 13:48:10 -0700 | [diff] [blame] | 8 | from frc971.control_loops.swerve.dynamics_constants import * |
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 9 | |
| 10 | # Note: this physics needs to match the symengine code. We have tests that |
| 11 | # confirm they match in all the cases we care about. |
| 12 | |
# Bundle of physical parameters consumed by the dynamics functions below.
# The field order is the namedtuple's positional order; keep it stable.
#   Cx, Cy    : multipliers turning slip ratio / slip angle into tire force
#   rw        : wheel radius (wheel surface speed is rw * drive rate)
#   m, J      : robot mass and yaw inertia (forces / m, torques / J)
#   Gd1..Gd3  : drive gear stages; Gd is their product
#   rs, rp, rb1, rb2, wb : gear / geometry terms used in the steer-drive
#               coupling expressions
#   Js, Gs    : steer inertia term and steer gear ratio
#   Jdm, Jsm  : drive / steer motor inertias
#   Ktd, Kts  : drive / steer motor torque constants
#   robot_width, caster, contact_patch_length : module geometry
CoefficientsType = namedtuple(
    'CoefficientsType',
    ('Cx Cy rw m J Gd1 rs rp Gd2 rb1 rb2 Gd3 Gd Js Gs wb Jdm Jsm Kts Ktd '
     'robot_width caster contact_patch_length').split(),
)
| 38 | |
| 39 | |
def Coefficients(
    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
    Cy: float = 5 * 9.8 / 0.05 / 4.0,
    rw: float = 2 * 0.0254,

    # base is 20 kg without battery
    m: float = 25.0,
    J: float = 6.0,
    Gd1: float = 12.0 / 42.0,
    rs: float = 28.0 / 20.0 / 2.0,
    rp: float = 18.0 / 20.0 / 2.0,

    # 15 / 45 bevel ratio, computed with a Python port of
    # GetBevelPitchRadius(double).
    # TODO(Justin): Use the function instead of computed constants.
    rb1: float = 0.3805473,
    rb2: float = 1.14164,
    Js: float = 0.001,
    Gs: float = 35.0 / 468.0,
    wb: float = 0.725,
    drive_motor=KrakenFOC(),
    steer_motor=KrakenFOC(),
    robot_width: float = 24.75 * 0.0254,
    caster: float = 0.01,
    contact_patch_length: float = 0.02,
) -> CoefficientsType:
    """Builds a CoefficientsType from physical robot parameters.

    Derived fields:
      Gd2 = rs / rp, Gd3 = rb1 / rb2 (the bevel stage), and the overall drive
      ratio Gd = Gd1 * Gd2 * Gd3.
      Jdm / Ktd come from drive_motor (motor_inertia, Kt) and Jsm / Kts come
      from steer_motor.

    All other arguments are stored unchanged under the same field name.
    """
    stage2_ratio = rs / rp
    bevel_ratio = rb1 / rb2

    return CoefficientsType(
        Cx=Cx,
        Cy=Cy,
        rw=rw,
        m=m,
        J=J,
        Gd1=Gd1,
        rs=rs,
        rp=rp,
        Gd2=stage2_ratio,
        rb1=rb1,
        rb2=rb2,
        Gd3=bevel_ratio,
        Gd=Gd1 * stage2_ratio * bevel_ratio,
        Js=Js,
        Gs=Gs,
        wb=wb,
        Jdm=drive_motor.motor_inertia,
        Jsm=steer_motor.motor_inertia,
        Kts=steer_motor.Kt,
        Ktd=drive_motor.Kt,
        robot_width=robot_width,
        caster=caster,
        contact_patch_length=contact_patch_length,
    )
| 101 | |
| 102 | |
def R(theta):
    """Returns the 2x2 matrix that rotates a 2D column vector by theta."""
    c, s = jax.numpy.cos(theta), jax.numpy.sin(theta)
    return jax.numpy.array([
        [c, -s],
        [s, c],
    ])
| 107 | |
| 108 | |
def angle_cross(vector, omega):
    """Returns the planar cross product (omega * z_hat) x vector.

    For vector = [x, y] this is [-y * omega, x * omega], the velocity of a
    point at `vector` rotating about the origin at rate omega.
    """
    x, y = vector[0], vector[1]
    return jax.numpy.array([-y * omega, x * omega])
| 111 | |
| 112 | |
def force_cross(r, f):
    """Returns the scalar z component of the 2D cross product r x f.

    With r a lever arm and f a force, this is the torque f produces about
    the origin.
    """
    rx, ry = r[0], r[1]
    fx, fy = f[0], f[1]
    return rx * fy - ry * fx
| 115 | |
| 116 | |
def softsign(x, gain):
    """Smooth, differentiable stand-in for sign(x).

    Evaluates 1 - 2 * sigmoid(-gain * x): exactly 0 at x == 0, approaching
    +1 / -1 as gain * x grows / shrinks. A larger gain sharpens the step.
    """
    scaled = -gain * x
    return 1 - 2.0 * jax.nn.sigmoid(scaled)
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 119 | |
| 120 | |
def soft_atan2(y, x):
    """atan2 variant that uses a smoothed magnitude of x.

    x is replaced by a smooth approximation of max(|x|, 0.05) before calling
    arctan2. The second argument is therefore always positive, so the result
    lies in (-pi/2, pi/2) and the function stays smooth as x crosses zero.
    """
    # Sharpness of the soft max(|x|, 1 / max_gain) floor.
    max_gain = 1.0 / 0.05
    # Sharpness of the soft |x|.
    abs_gain = 1.0 / 0.01

    magnitude = x * softsign(x, abs_gain)
    floored_magnitude = jax.scipy.special.logsumexp(
        jax.numpy.array([1.0, magnitude * max_gain])) / max_gain

    return jax.numpy.arctan2(y, floored_magnitude)
| 131 | |
| 132 | |
def full_module_physics(coefficients: CoefficientsType, Rtheta,
                        module_index: int, mounting_location, X, U):
    """Computes one swerve module's contribution to the full-state derivative.

    Args:
      coefficients: Physical parameters (see Coefficients()).
      Rtheta: Robot heading rotation, R(X[STATE_THETA]).
      module_index: Module number 0-3. It selects the 4 module states
        X[4 * module_index:4 * module_index + 4] and the current pair
        U[2 * module_index] (steer) and U[2 * module_index + 1] (drive).
      mounting_location: Module position in the robot frame. It is rotated by
        Rtheta before use.
      X: Full state vector.
      U: Current vector [Is0, Id0, Is1, Id1, ...].

    Returns:
      A vector assembled as: 4 zeros per preceding module, this module's
      [OMEGAS, OMEGAD, alphas, alphad], 4 zeros per following module, 3 zeros,
      F / m (2 entries), then [torque / J, 0, 0, 0]. This assumes the STATE_*
      constants lay X out in the same order. TODO confirm against
      dynamics_constants.
    """
    X_module = X[module_index * 4:(module_index + 1) * 4]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Rotation from the wheel frame into the frame Rtheta rotates into.
    Rthetaplusthetas = R(X[STATE_THETA] + X_module[STATE_THETAS0])

    # The contact patch trails the steering axis by `caster` along the
    # wheel's -x axis.
    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[STATE_VX:STATE_VY + 1]

    # Contact patch velocity: robot rotation about its center carried out to
    # the mounting point, plus robot translation, plus the module's total
    # rotation (robot + steer) about the steering axis at the caster offset.
    contact_patch_velocity = (
        angle_cross(Rtheta @ mounting_location, X[STATE_OMEGA]) +
        robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector,
                    (X[STATE_OMEGA] + X_module[STATE_OMEGAS0])))

    # Contact patch velocity in the wheel frame: [0] along the wheel,
    # [1] across it.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # Longitudinal slip: wheel surface speed minus ground speed, normalized by
    # ground speed. The denominator is floored at 0.02 so a nearly stationary
    # wheel does not divide by ~0.
    slip_ratio = (coefficients.rw * X_module[STATE_OMEGAD0] -
                  wheel_ground_velocity[0]) / jax.numpy.max(
                      jax.numpy.array(
                          [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))

    # Tire forces, linear in slip ratio / slip angle.
    Fwx = coefficients.Cx * slip_ratio
    Fwy = coefficients.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100)

    # Moment on the steering axis from the lateral force. Its lever arm is the
    # caster plus contact_patch_length / 3, where the sign of the
    # contact_patch_length term follows the wheel's rolling direction.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer angular acceleration from the aligning moment, the steer motor
    # torque, and the drive force coupling into the steer through the gear
    # geometry. Kept term-for-term identical to the symengine derivation.
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    # Then solve for alphad
    alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
              coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
              coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
              coefficients.Gd) / (coefficients.rp * coefficients.Jdm)

    # Tire force rotated out of the wheel frame, and its torque about the
    # robot center.
    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    torque = force_cross(Rtheta @ mounting_location, F)

    module_slot = jax.numpy.array([
        X_module[STATE_OMEGAS0],
        X_module[STATE_OMEGAD0],
        alphas,
        alphad,
    ])
    X_dot_contribution = jax.numpy.hstack(
        [jax.numpy.zeros((4, ))] * module_index + [module_slot] +
        [jax.numpy.zeros((4, ))] * (3 - module_index) + [
            jax.numpy.zeros((3, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
        ])

    return X_dot_contribution
| 206 | |
| 207 | |
@partial(jax.jit, static_argnames=['coefficients'])
def full_dynamics(coefficients: CoefficientsType, X, U):
    """Returns dX/dt of the full swerve model for state X and currents U.

    The four module contributions are summed in order 0, 1, 2, 3. Then the
    X, Y and theta rates are overwritten with X[STATE_VX], X[STATE_VY] and
    X[STATE_OMEGA].
    """
    Rtheta = R(X[STATE_THETA])

    half_width = coefficients.robot_width / 2.0
    # Module mounting points in the robot frame, indexed by module number.
    mounting_points = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )

    contributions = [
        full_module_physics(coefficients, Rtheta, index,
                            jax.numpy.array(point), X, U)
        for index, point in enumerate(mounting_points)
    ]

    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    position_rates = jax.numpy.array([
        X[STATE_VX],
        X[STATE_VY],
        X[STATE_OMEGA],
    ])
    return X_dot.at[STATE_X:STATE_THETA + 1].set(position_rates)
| 243 | |
| 244 | |
def velocity_module_physics(coefficients: CoefficientsType,
                            Rtheta: jax.typing.ArrayLike, module_index: int,
                            mounting_location: jax.typing.ArrayLike,
                            X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
    """Computes one module's contribution to the velocity-model derivative.

    Each module has only 2 states here (steer angle and steer rate). Drive
    speed is not modeled: the longitudinal tire force comes straight from the
    drive current, Fwx = Ktd / (Gd * rw) * Id.

    Args:
      coefficients: Physical parameters (see Coefficients()).
      Rtheta: Robot heading rotation, R(X[VELOCITY_STATE_THETA]).
      module_index: Module number 0-3. It selects the states
        X[2 * module_index:2 * module_index + 2] and the currents
        U[2 * module_index] (steer) and U[2 * module_index + 1] (drive).
      mounting_location: Module position in the robot frame. It is rotated by
        Rtheta before use.
      X: Velocity state vector.
      U: Current vector [Is0, Id0, Is1, Id1, ...].

    Returns:
      (X_dot_contribution, F, torque):
        X_dot_contribution is a 12-entry vector. It holds 2 zeros per
          preceding module, this module's [OMEGAS, alphas], 2 zeros per
          following module, 0, F / m (2 entries), and torque / J. This
          presumably lines up with to_velocity_state's order (module pairs,
          theta, vx, vy, omega). TODO confirm against the VELOCITY_STATE_*
          constants.
        F is this module's tire force, rotated out of the wheel frame.
        torque is F's torque about the robot center.
    """
    X_module = X[module_index * 2:(module_index + 1) * 2]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    rotated_mounting_location = Rtheta @ mounting_location

    # Rotation from the wheel frame into the frame Rtheta rotates into.
    Rthetaplusthetas = R(X[VELOCITY_STATE_THETA] +
                         X_module[VELOCITY_STATE_THETAS0])

    # The contact patch trails the steering axis by `caster` along the
    # wheel's -x axis.
    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]

    # Contact patch velocity: robot rotation carried out to the mounting
    # point, plus robot translation, plus the module's total rotation
    # (robot + steer) about the steering axis at the caster offset.
    contact_patch_velocity = (
        angle_cross(rotated_mounting_location, X[VELOCITY_STATE_OMEGA]) +
        robot_velocity + angle_cross(
            Rthetaplusthetas @ caster_vector,
            (X[VELOCITY_STATE_OMEGA] + X_module[VELOCITY_STATE_OMEGAS0])))

    # Velocity of the contact patch over the ground, expressed in the wheel
    # frame: [0] is along the wheel, [1] is across it.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # Longitudinal force straight from drive current: motor torque Ktd * Id
    # through the drive ratio Gd, divided by the wheel radius rw.
    Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
    Fwy = coefficients.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)

    # Moment on the steering axis from the lateral force. Its lever arm is the
    # caster plus contact_patch_length / 3, where the sign of the
    # contact_patch_length term follows the wheel's rolling direction.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer angular acceleration. The expression is identical to the one in
    # full_module_physics and must stay in sync with the symengine derivation.
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    torque = force_cross(rotated_mounting_location, F)

    X_dot_contribution = jax.numpy.hstack((jax.numpy.zeros(
        (2, )), ) * (module_index) + (jax.numpy.array([
            X_module[VELOCITY_STATE_OMEGAS0],
            alphas,
        ]), ) + (jax.numpy.zeros((2, )), ) * (3 - module_index) + (
            jax.numpy.zeros((1, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J]),
        ))

    return X_dot_contribution, F, torque
Austin Schuh | 76534f3 | 2024-09-02 13:52:45 -0700 | [diff] [blame] | 307 | |
| 308 | |
@partial(jax.jit, static_argnames=['coefficients'])
def velocity_dynamics(coefficients: CoefficientsType, X: jax.typing.ArrayLike,
                      U: jax.typing.ArrayLike):
    """Returns dX/dt of the velocity model for state X and currents U.

    The per-module derivative contributions are summed in order 0, 1, 2, 3.
    The theta rate is then set to X[VELOCITY_STATE_OMEGA].
    """
    Rtheta = R(X[VELOCITY_STATE_THETA])

    half_width = coefficients.robot_width / 2.0
    # Module mounting points in the robot frame, indexed by module number.
    mounting_points = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )

    contributions = []
    for index, point in enumerate(mounting_points):
        contribution, _, _ = velocity_module_physics(
            coefficients, Rtheta, index, jax.numpy.array(point), X, U)
        contributions.append(contribution)

    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    return X_dot.at[VELOCITY_STATE_THETA].set(X[VELOCITY_STATE_OMEGA])
| 338 | |
| 339 | |
def to_velocity_state(X):
    """Extracts the velocity-model state from a full-model state vector.

    The result has 12 entries: (steer angle, steer rate) for modules 0-3,
    then theta, vx, vy and omega. Full-state entries not listed here, such
    as STATE_OMEGAD0, are dropped.
    """
    selected = (
        STATE_THETAS0, STATE_OMEGAS0,
        STATE_THETAS1, STATE_OMEGAS1,
        STATE_THETAS2, STATE_OMEGAS2,
        STATE_THETAS3, STATE_OMEGAS3,
        STATE_THETA,
        STATE_VX,
        STATE_VY,
        STATE_OMEGA,
    )
    return jax.numpy.array([X[index] for index in selected])
Austin Schuh | 5dac229 | 2024-10-19 13:56:58 -0700 | [diff] [blame] | 355 | |
| 356 | |
@jax.jit
def mpc_cost(coefficients: CoefficientsType, X, U, goal):
    """Returns the scalar MPC cost of velocity state X and currents U.

    Args:
      coefficients: Physical parameters. Not marked static, so jit traces
        its fields.
      X: Velocity state vector.
      U: Current vector [Is0, Id0, Is1, Id1, ...].
      goal: goal[0:2] is the target (vx, vy), compared with
        X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]. goal[2] is the target
        omega, compared with X[VELOCITY_STATE_OMEGA].

    Cost terms, in order of accumulation:
      * velocity error along the goal direction (weight 75) and
        perpendicular to it (weight 1500),
      * angular velocity error (weight 1000),
      * steer rates (weight 0.10 each),
      * each module's force deviation from an even share of the total force
        and torque (weight 0.01),
      * steer and drive currents.
    """
    J = 0

    rnorm = jax.numpy.linalg.norm(goal[0:2])

    # Below this speed the goal direction is ill-defined, so fall back to the
    # x / y axes.
    has_direction = rnorm > 0.0001
    # Divide by a safe denominator so that no NaN is ever produced, even in
    # the branch that select discards. The result is unchanged whenever
    # has_direction is true.
    safe_rnorm = jax.numpy.where(has_direction, rnorm, 1.0)
    vnorm = jax.lax.select(has_direction, goal[0:2] / safe_rnorm,
                           jax.numpy.array([1.0, 0.0]))
    vperp = jax.lax.select(has_direction,
                           jax.numpy.array([-vnorm[1], vnorm[0]]),
                           jax.numpy.array([0.0, 1.0]))

    velocity_error = goal[0:2] - X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]

    # TODO(austin): Do we want to do something more special for 0?

    J += 75 * (jax.numpy.dot(velocity_error, vnorm)**2.0)
    J += 1500 * (jax.numpy.dot(velocity_error, vperp)**2.0)
    J += 1000 * (goal[2] - X[VELOCITY_STATE_OMEGA])**2.0

    kSteerVelocityGain = 0.10
    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS0])**2.0
    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS1])**2.0
    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS2])**2.0
    J += kSteerVelocityGain * (X[VELOCITY_STATE_OMEGAS3])**2.0

    mounting_locations = jax.numpy.array(
        [[coefficients.robot_width / 2.0, coefficients.robot_width / 2.0],
         [-coefficients.robot_width / 2.0, coefficients.robot_width / 2.0],
         [-coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0],
         [coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0]])

    Rtheta = R(X[VELOCITY_STATE_THETA])
    _, F0, torque0 = velocity_module_physics(coefficients, Rtheta, 0,
                                             mounting_locations[0], X, U)
    _, F1, torque1 = velocity_module_physics(coefficients, Rtheta, 1,
                                             mounting_locations[1], X, U)
    _, F2, torque2 = velocity_module_physics(coefficients, Rtheta, 2,
                                             mounting_locations[2], X, U)
    _, F3, torque3 = velocity_module_physics(coefficients, Rtheta, 3,
                                             mounting_locations[3], X, U)

    forces = [F0, F1, F2, F3]

    F = (F0 + F1 + F2 + F3)
    torque = (torque0 + torque1 + torque2 + torque3)

    def force_for_torque(torque, r):
        """Returns the force perpendicular to r that produces `torque` at r.

        This is the inverse of the module-level force_cross:
        force_cross(r, force_for_torque(torque, r)) == torque.
        """
        r_squared_norm = jax.numpy.inner(r, r)

        return jax.numpy.array(
            [-r[1] * torque / r_squared_norm, r[0] * torque / r_squared_norm])

    # TODO(austin): Are these penalties reasonable? Do they give us a decent time constant?
    for i, module_force in enumerate(forces):
        # An even share of the total force, plus an even share of the total
        # torque delivered at this module's rotated mounting point.
        desired_force = F / 4.0 + force_for_torque(
            torque / 4.0, Rtheta @ mounting_locations[i, :])
        force_error = desired_force - module_force
        J += 0.01 * jax.numpy.inner(force_error, force_error)

    for i in range(4):
        Is = U[2 * i + 0]
        Id = U[2 * i + 1]
        # Steer
        J += ((Is + STEER_CURRENT_COUPLING_FACTOR * Id)**2.0) / 100000.0
        # Drive
        J += (Id**2.0) / 1000.0

    return J