blob: 58d5fcf9f9f4e7f79d96fe2f230ee388332cef9b [file] [log] [blame]
Austin Schuh76534f32024-09-02 13:52:45 -07001#!/usr/bin/env python3
2
3from functools import partial
4from collections import namedtuple
5import jax
6
Austin Schuh76534f32024-09-02 13:52:45 -07007from frc971.control_loops.python.control_loop import KrakenFOC
Austin Schuha9550c02024-10-19 13:48:10 -07008from frc971.control_loops.swerve.dynamics_constants import *
Austin Schuh76534f32024-09-02 13:52:45 -07009
10# Note: this physics needs to match the symengine code. We have tests that
11# confirm they match in all the cases we care about.
12
# Immutable bundle of physical constants consumed by the dynamics functions in
# this file. Being a namedtuple it is hashable, which lets it be passed as a
# static jit argument. Instances are normally built via Coefficients().
CoefficientsType = namedtuple(
    'CoefficientsType',
    'Cx Cy rw m J Gd1 rs rp Gd2 rb1 rb2 Gd3 Gd Js Gs wb Jdm Jsm Kts Ktd '
    'robot_width caster contact_patch_length',
)
38
39
def Coefficients(
    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
    Cy: float = 5 * 9.8 / 0.05 / 4.0,
    rw: float = 2 * 0.0254,

    # base is 20 kg without battery
    m: float = 25.0,
    J: float = 6.0,
    Gd1: float = 12.0 / 42.0,
    rs: float = 28.0 / 20.0 / 2.0,
    rp: float = 18.0 / 20.0 / 2.0,

    # 15 / 45 bevel ratio, calculated using python script ported over to
    # GetBevelPitchRadius(double)
    # TODO(Justin): Use the function instead of computed constants
    rb1: float = 0.3805473,
    rb2: float = 1.14164,
    Js: float = 0.001,
    Gs: float = 35.0 / 468.0,
    wb: float = 0.725,
    drive_motor=KrakenFOC(),
    steer_motor=KrakenFOC(),
    robot_width: float = 24.75 * 0.0254,
    caster: float = 0.01,
    contact_patch_length: float = 0.02,
) -> CoefficientsType:
    """Assemble a CoefficientsType, filling in the derived quantities.

    Every keyword argument except the two motors is copied through as-is.
    The remaining fields are derived:
      * Gd2 = rs / rp, Gd3 = rb1 / rb2, and Gd = Gd1 * Gd2 * Gd3.
      * Jdm and Ktd are read from drive_motor (motor_inertia, Kt).
      * Jsm and Kts are read from steer_motor (motor_inertia, Kt).

    The default motor objects are created once, when this function is
    defined, and are shared between calls; they are only read here.
    """
    ratio_rs_rp = rs / rp
    ratio_bevel = rb1 / rb2

    return CoefficientsType(
        Cx=Cx,
        Cy=Cy,
        rw=rw,
        m=m,
        J=J,
        Gd1=Gd1,
        rs=rs,
        rp=rp,
        Gd2=ratio_rs_rp,
        rb1=rb1,
        rb2=rb2,
        Gd3=ratio_bevel,
        Gd=Gd1 * ratio_rs_rp * ratio_bevel,
        Js=Js,
        Gs=Gs,
        wb=wb,
        Jdm=drive_motor.motor_inertia,
        Jsm=steer_motor.motor_inertia,
        Kts=steer_motor.Kt,
        Ktd=drive_motor.Kt,
        robot_width=robot_width,
        caster=caster,
        contact_patch_length=contact_patch_length,
    )
101
102
def R(theta):
    """Return the 2x2 counter-clockwise rotation matrix for angle theta."""
    cos_t = jax.numpy.cos(theta)
    sin_t = jax.numpy.sin(theta)
    return jax.numpy.array([
        [cos_t, -sin_t],
        [sin_t, cos_t],
    ])
107
108
def angle_cross(vector, omega):
    """Return the planar cross product omega x vector for a scalar rate omega.

    The result is (-vector[1] * omega, vector[0] * omega), i.e. vector rotated
    by +90 degrees and scaled by omega.
    """
    vx = vector[0]
    vy = vector[1]
    return jax.numpy.array([-vy * omega, vx * omega])
111
112
def force_cross(r, f):
    """Return the scalar (z) component of the planar cross product r x f."""
    rx, ry = r[0], r[1]
    fx, fy = f[0], f[1]
    return rx * fy - ry * fx
115
116
def softsign(x, gain):
    """Smooth approximation of sign(x), with output in [-1, 1].

    Evaluated as 1 - 2 * sigmoid(-gain * x); a larger gain gives a sharper
    transition through zero.
    """
    scaled = -gain * x
    return 1 - 2.0 * jax.nn.sigmoid(scaled)
Austin Schuh76534f32024-09-02 13:52:45 -0700119
120
def soft_atan2(y, x):
    """Smooth variant of atan2(y, x) with a denominator kept away from zero.

    x is first replaced by a smooth |x| (x * softsign(x, 100)), then by a
    log-sum-exp soft maximum of that value and 1/20 (sharpness 20), so the
    second argument handed to arctan2 is always positive.
    """
    kMaxLogGain = 1.0 / 0.05
    kAbsLogGain = 1.0 / 0.01

    smooth_abs_x = x * softsign(x, kAbsLogGain)
    soft_max_terms = jax.numpy.array([1.0, smooth_abs_x * kMaxLogGain])
    denominator = jax.scipy.special.logsumexp(soft_max_terms) / kMaxLogGain

    return jax.numpy.arctan2(y, denominator)
131
132
def full_module_physics(coefficients: CoefficientsType, Rtheta,
                        module_index: int, mounting_location, X, U):
    """Compute one swerve module's contribution to the full-state derivative.

    Args:
      coefficients: physical constants (see CoefficientsType).
      Rtheta: rotation matrix for the robot heading (full_dynamics passes
        R(X[STATE_THETA])).
      module_index: Python int in 0..3. It selects the module's 4 states
        X[4 * i:4 * i + 4] and its currents U[2 * i] (Is) and U[2 * i + 1] (Id).
        It must be a concrete int because it is used for tuple repetition.
      mounting_location: 2-vector position of the module; it is rotated by
        Rtheta before use.
      X: full state vector.
      U: current vector (two entries per module).

    Returns:
      A vector laid out as: zeros for the other modules' 4-state slots, with
      [omegas, omegad, alphas, alphad] in this module's slot; then three
      zeros; then F / m (2 entries); then torque / J followed by three zeros.
    """
    X_module = X[module_index * 4:(module_index + 1) * 4]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Mounting location in the field frame. It is needed both for the
    # contact-patch velocity and for the torque about the robot center, so
    # compute it once.
    rotated_mounting_location = Rtheta @ mounting_location

    Rthetaplusthetas = R(X[STATE_THETA] + X_module[STATE_THETAS0])

    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[STATE_VX:STATE_VY + 1]

    contact_patch_velocity = (
        angle_cross(rotated_mounting_location, X[STATE_OMEGA]) +
        robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector,
                    (X[STATE_OMEGA] + X_module[STATE_OMEGAS0])))

    # Velocity of the contact patch over the field, expressed in the frame
    # rotated by the robot heading plus the module steer angle.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # The denominator is clamped to at least 0.02 so the ratio stays finite
    # when the ground speed along the wheel is near zero.
    slip_ratio = (coefficients.rw * X_module[STATE_OMEGAD0] -
                  wheel_ground_velocity[0]) / jax.numpy.max(
                      jax.numpy.array(
                          [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))

    Fwx = coefficients.Cx * slip_ratio
    Fwy = coefficients.Cy * slip_angle

    softsign_velocity = softsign(wheel_ground_velocity[0], 100)

    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    # Then solve for alphad
    alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
              coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
              coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
              coefficients.Gd) / (coefficients.rp * coefficients.Jdm)

    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])

    torque = force_cross(rotated_mounting_location, F)

    module_derivative = jax.numpy.array([
        X_module[STATE_OMEGAS0],
        X_module[STATE_OMEGAD0],
        alphas,
        alphad,
    ])

    X_dot_contribution = jax.numpy.hstack(
        (jax.numpy.zeros((4, )), ) * module_index +
        (module_derivative, ) +
        (jax.numpy.zeros((4, )), ) * (3 - module_index) +
        (
            jax.numpy.zeros((3, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
        ))

    return X_dot_contribution
206
207
@partial(jax.jit, static_argnames=['coefficients'])
def full_dynamics(coefficients: CoefficientsType, X, U):
    """Return dX/dt for the full state by summing the four module models.

    Modules 0..3 are mounted at (+h, +h), (-h, +h), (-h, -h), (+h, -h), where
    h = coefficients.robot_width / 2. After summing, the entries
    STATE_X..STATE_THETA are overwritten with X[STATE_VX], X[STATE_VY] and
    X[STATE_OMEGA].
    """
    heading_rotation = R(X[STATE_THETA])

    half = coefficients.robot_width / 2.0
    corners = (
        (half, half),
        (-half, half),
        (-half, -half),
        (half, -half),
    )

    total = None
    for index, corner in enumerate(corners):
        contribution = full_module_physics(coefficients, heading_rotation,
                                           index, jax.numpy.array(corner), X,
                                           U)
        total = contribution if total is None else total + contribution

    chassis_rates = jax.numpy.array([
        X[STATE_VX],
        X[STATE_VY],
        X[STATE_OMEGA],
    ])
    return total.at[STATE_X:STATE_THETA + 1].set(chassis_rates)
243
244
def velocity_module_physics(coefficients: CoefficientsType,
                            Rtheta: jax.typing.ArrayLike, module_index: int,
                            mounting_location: jax.typing.ArrayLike,
                            X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
    """Compute one module's contribution to the velocity-state derivative.

    Unlike full_module_physics there is no drive-wheel speed state here: the
    wheel-frame x force is taken directly from the drive current,
    Fwx = Ktd / (Gd * rw) * Id.

    Args:
      coefficients: physical constants (see CoefficientsType).
      Rtheta: rotation matrix for the robot heading.
      module_index: Python int in 0..3. It selects X[2 * i:2 * i + 2] and the
        currents U[2 * i] (Is) and U[2 * i + 1] (Id). It must be a concrete
        int because it is used for tuple repetition.
      mounting_location: 2-vector module position, rotated by Rtheta before use.
      X: velocity state vector.
      U: current vector (two entries per module).

    Returns:
      (X_dot_contribution, F, torque). F is the module force rotated by the
      heading plus steer angle, and torque = force_cross(Rtheta @
      mounting_location, F). X_dot_contribution holds
      [omegas, alphas] in this module's 2-slot (zeros in the other modules'
      slots), then a zero, then F / m, then torque / J.
    """
    base = 2 * module_index
    X_module = X[base:base + 2]
    steer_current = U[base]
    drive_current = U[base + 1]

    lever_arm = Rtheta @ mounting_location
    wheel_rotation = R(X[VELOCITY_STATE_THETA] +
                       X_module[VELOCITY_STATE_THETAS0])
    chassis_omega = X[VELOCITY_STATE_OMEGA]
    caster_offset = wheel_rotation @ jax.numpy.array(
        [-coefficients.caster, 0.0])

    contact_patch_velocity = (
        angle_cross(lever_arm, chassis_omega) +
        X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1] +
        angle_cross(caster_offset,
                    (chassis_omega + X_module[VELOCITY_STATE_OMEGAS0])))

    # Velocity of the contact patch over the field, expressed in the frame
    # rotated by the robot heading plus the module steer angle.
    wheel_ground_velocity = wheel_rotation.T @ contact_patch_velocity
    ground_x = wheel_ground_velocity[0]
    ground_y = wheel_ground_velocity[1]

    slip_angle = jax.numpy.sin(-soft_atan2(ground_y, ground_x))

    Fwx = (coefficients.Ktd /
           (coefficients.Gd * coefficients.rw)) * drive_current
    Fwy = coefficients.Cy * slip_angle

    Ms = -Fwy * ((softsign(ground_x, 100.0) *
                  coefficients.contact_patch_length / 3.0) +
                 coefficients.caster)

    # alphas = (Ms + Kts * Is / Gs + gain * (rw / rb2 * -Fwx))
    #          / (Jsm + Js / Gs**2), evaluated term by term.
    steer_current_term = coefficients.Kts * steer_current / coefficients.Gs
    fwx_gain = (-coefficients.wb + (coefficients.rs + coefficients.rp) *
                (1 - coefficients.rb1 / coefficients.rp))
    fwx_term = fwx_gain * (coefficients.rw / coefficients.rb2 * (-Fwx))
    denominator = coefficients.Jsm + (coefficients.Js /
                                      (coefficients.Gs * coefficients.Gs))
    alphas = (Ms + steer_current_term + fwx_term) / denominator

    F = wheel_rotation @ jax.numpy.array([Fwx, Fwy])
    torque = force_cross(lever_arm, F)

    contribution = jax.numpy.hstack(
        (jax.numpy.zeros((2, )), ) * module_index +
        (jax.numpy.array([X_module[VELOCITY_STATE_OMEGAS0], alphas]), ) +
        (jax.numpy.zeros((2, )), ) * (3 - module_index) +
        (
            jax.numpy.zeros((1, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J]),
        ))

    return contribution, F, torque
Austin Schuh76534f32024-09-02 13:52:45 -0700307
308
@partial(jax.jit, static_argnames=['coefficients'])
def velocity_dynamics(coefficients: CoefficientsType, X: jax.typing.ArrayLike,
                      U: jax.typing.ArrayLike):
    """Return dX/dt for the velocity state by summing the four module models.

    Modules 0..3 are mounted at (+h, +h), (-h, +h), (-h, -h), (+h, -h), where
    h = coefficients.robot_width / 2. The VELOCITY_STATE_THETA entry of the
    result is overwritten with X[VELOCITY_STATE_OMEGA].
    """
    heading_rotation = R(X[VELOCITY_STATE_THETA])

    half = coefficients.robot_width / 2.0
    corners = (
        (half, half),
        (-half, half),
        (-half, -half),
        (half, -half),
    )

    total = None
    for index, corner in enumerate(corners):
        contribution, _, _ = velocity_module_physics(coefficients,
                                                     heading_rotation, index,
                                                     jax.numpy.array(corner),
                                                     X, U)
        total = contribution if total is None else total + contribution

    return total.at[VELOCITY_STATE_THETA].set(X[VELOCITY_STATE_OMEGA])
338
339
def to_velocity_state(X):
    """Project a full state vector onto the reduced velocity state.

    The result contains, in order, X[STATE_THETASi] and X[STATE_OMEGASi] for
    modules i = 0..3, followed by X[STATE_THETA], X[STATE_VX], X[STATE_VY]
    and X[STATE_OMEGA].
    """
    picked = (
        STATE_THETAS0, STATE_OMEGAS0,
        STATE_THETAS1, STATE_OMEGAS1,
        STATE_THETAS2, STATE_OMEGAS2,
        STATE_THETAS3, STATE_OMEGAS3,
        STATE_THETA, STATE_VX, STATE_VY, STATE_OMEGA,
    )
    return jax.numpy.array([X[index] for index in picked])
Austin Schuh5dac2292024-10-19 13:56:58 -0700355
356
@jax.jit
def mpc_cost(coefficients: CoefficientsType, X, U, goal):
    """Scalar MPC stage cost for velocity state X, currents U and a goal.

    goal[0:2] is compared against X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]
    and goal[2] against X[VELOCITY_STATE_OMEGA]. The cost sums:
      * the velocity error split along the goal direction (weight 75) and
        perpendicular to it (weight 1500); when |goal[0:2]| <= 1e-4 the
        x and y axes are used instead;
      * the angular velocity error (weight 1000);
      * each module's steer rate (weight 0.1);
      * each module's deviation from an even share of the total force and
        torque (weight 0.01);
      * per-module current penalties.

    Note: unlike full_dynamics / velocity_dynamics, coefficients is not a
    static argument here, so jit traces its fields as values.
    """
    J = 0

    rnorm = jax.numpy.linalg.norm(goal[0:2])

    vnorm = jax.lax.select(rnorm > 0.0001, goal[0:2] / rnorm,
                           jax.numpy.array([1.0, 0.0]))
    vperp = jax.lax.select(rnorm > 0.0001,
                           jax.numpy.array([-vnorm[1], vnorm[0]]),
                           jax.numpy.array([0.0, 1.0]))

    velocity_error = goal[0:2] - X[VELOCITY_STATE_VX:VELOCITY_STATE_VY + 1]

    # TODO(austin): Do we want to do something more special for 0?

    J += 75 * (jax.numpy.dot(velocity_error, vnorm)**2.0)
    J += 1500 * (jax.numpy.dot(velocity_error, vperp)**2.0)
    J += 1000 * (goal[2] - X[VELOCITY_STATE_OMEGA])**2.0

    kSteerVelocityGain = 0.10
    for omegas_index in (VELOCITY_STATE_OMEGAS0, VELOCITY_STATE_OMEGAS1,
                         VELOCITY_STATE_OMEGAS2, VELOCITY_STATE_OMEGAS3):
        J += kSteerVelocityGain * (X[omegas_index])**2.0

    mounting_locations = jax.numpy.array(
        [[coefficients.robot_width / 2.0, coefficients.robot_width / 2.0],
         [-coefficients.robot_width / 2.0, coefficients.robot_width / 2.0],
         [-coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0],
         [coefficients.robot_width / 2.0, -coefficients.robot_width / 2.0]])

    Rtheta = R(X[VELOCITY_STATE_THETA])

    forces = []
    torques = []
    for i in range(4):
        _, module_force, module_torque = velocity_module_physics(
            coefficients, Rtheta, i, mounting_locations[i], X, U)
        forces.append(module_force)
        torques.append(module_torque)

    F = forces[0] + forces[1] + forces[2] + forces[3]
    torque = torques[0] + torques[1] + torques[2] + torques[3]

    # Named distinctly from the module-level force_cross(r, f), which computes
    # the torque produced by a force; this is its inverse.
    def torque_to_force(torque, r):
        """Return the force perpendicular to r that yields `torque` about 0.

        f = torque * (-r[1], r[0]) / |r|^2, so force_cross(r, f) == torque.
        """
        r_squared_norm = jax.numpy.inner(r, r)

        return jax.numpy.array(
            [-r[1] * torque / r_squared_norm, r[0] * torque / r_squared_norm])

    # TODO(austin): Are these penalties reasonable? Do they give us a decent time constant?
    for i, module_force in enumerate(forces):
        desired_force = F / 4.0 + torque_to_force(
            torque / 4.0, Rtheta @ mounting_locations[i, :])
        force_error = desired_force - module_force
        J += 0.01 * jax.numpy.inner(force_error, force_error)

    for i in range(4):
        Is = U[2 * i + 0]
        Id = U[2 * i + 1]
        # Steer
        J += ((Is + STEER_CURRENT_COUPLING_FACTOR * Id)**2.0) / 100000.0
        # Drive
        J += (Id**2.0) / 1000.0

    return J