blob: dfadbdab3667d07ac502e9d33689ddcd1589f210 [file] [log] [blame]
Austin Schuh76534f32024-09-02 13:52:45 -07001#!/usr/bin/env python3
2
3from functools import partial
4from collections import namedtuple
5import jax
6
7from frc971.control_loops.swerve import dynamics
8from frc971.control_loops.python.control_loop import KrakenFOC
9
# Note: this physics must match the symengine code. Tests confirm that the two
# agree in all the cases we care about.
12
# Immutable (and hashable, so usable as a jax.jit static argument) bundle of
# the physical constants consumed by the swerve dynamics below. Build it with
# Coefficients(), which also fills in the derived terms.
CoefficientsType = namedtuple(
    'CoefficientsType',
    (
        # Tire stiffnesses (Fwx = Cx * slip_ratio, Fwy = Cy * slip_angle),
        # wheel radius, robot mass and robot moment of inertia.
        'Cx', 'Cy', 'rw', 'm', 'J',
        # Drive reduction stages; Gd = Gd1 * Gd2 * Gd3 with Gd2 = rs / rp and
        # Gd3 = rb1 / rb2.
        'Gd1', 'rs', 'rp', 'Gd2', 'rb1', 'rb2', 'Gd3', 'Gd',
        # Terms used by the steer acceleration (alphas).
        'Js', 'Gs', 'wb',
        # Motor inertias and torque constants (d = drive, s = steer).
        'Jdm', 'Jsm', 'Kts', 'Ktd',
        # Geometry.
        'robot_width', 'caster', 'contact_patch_length',
    ),
)
38
39
def Coefficients(
    Cx: float = 25.0 * 9.8 / 4.0 / 0.05,
    Cy: float = 5 * 9.8 / 0.05 / 4.0,
    rw: float = 2 * 0.0254,

    # base is 20 kg without battery
    m: float = 25.0,
    J: float = 6.0,
    Gd1: float = 12.0 / 42.0,
    rs: float = 28.0 / 20.0 / 2.0,
    rp: float = 18.0 / 20.0 / 2.0,

    # 15 / 45 bevel ratio, calculated using python script ported over to
    # GetBevelPitchRadius(double)
    # TODO(Justin): Use the function instead of computed constants
    rb1: float = 0.3805473,
    rb2: float = 1.14164,
    Js: float = 0.001,
    Gs: float = 35.0 / 468.0,
    wb: float = 0.725,
    drive_motor=KrakenFOC(),
    steer_motor=KrakenFOC(),
    robot_width: float = 24.75 * 0.0254,
    caster: float = 0.01,
    contact_patch_length: float = 0.02,
) -> CoefficientsType:
    """Returns the swerve physics constants as a CoefficientsType.

    Every scalar argument is stored unchanged. The derived fields are:
      Gd2 = rs / rp
      Gd3 = rb1 / rb2
      Gd  = Gd1 * Gd2 * Gd3
      Jdm, Ktd = drive_motor.motor_inertia, drive_motor.Kt
      Jsm, Kts = steer_motor.motor_inertia, steer_motor.Kt

    drive_motor / steer_motor only need `motor_inertia` and `Kt` attributes.
    The default KrakenFOC() instances are created once, when this function is
    defined, and are shared between calls; they are only read here.
    """
    # The overall drive ratio is the product of the three reduction stages.
    Gd2 = rs / rp
    Gd3 = rb1 / rb2
    Gd = Gd1 * Gd2 * Gd3

    return CoefficientsType(
        Cx=Cx,
        Cy=Cy,
        rw=rw,
        m=m,
        J=J,
        Gd1=Gd1,
        rs=rs,
        rp=rp,
        Gd2=Gd2,
        rb1=rb1,
        rb2=rb2,
        Gd3=Gd3,
        Gd=Gd,
        Js=Js,
        Gs=Gs,
        wb=wb,
        Jdm=drive_motor.motor_inertia,
        Jsm=steer_motor.motor_inertia,
        Kts=steer_motor.Kt,
        Ktd=drive_motor.Kt,
        robot_width=robot_width,
        caster=caster,
        contact_patch_length=contact_patch_length,
    )
101
102
def R(theta):
    """Returns the 2x2 matrix that rotates a planar vector by theta radians."""
    s = jax.numpy.sin(theta)
    c = jax.numpy.cos(theta)
    first_row = [c, -s]
    second_row = [s, c]
    return jax.numpy.array([first_row, second_row])
107
108
def angle_cross(vector, omega):
    """Returns omega (about +z) crossed with the planar vector: [-y*w, x*w]."""
    x = vector[0]
    y = vector[1]
    return jax.numpy.array([-y * omega, x * omega])
111
112
def force_cross(r, f):
    """Returns the z component of r x f (planar torque of force f at r)."""
    r_x, r_y = r[0], r[1]
    f_x, f_y = f[0], f[1]
    return r_x * f_y - r_y * f_x
115
116
def softsign(x, gain):
    """Smooth approximation of sign(x), with values in (-1, 1).

    Computes 1 - 2 * sigmoid(-gain * x), which equals tanh(gain * x / 2);
    a larger gain gives a sharper transition around x == 0.
    """
    negated_scaled_x = -gain * x
    return 1 - 2.0 * jax.nn.sigmoid(negated_scaled_x)
120
def soft_atan2(y, x):
    """Smooth stand-in for atan2(y, |x|).

    |x| is approximated by x * softsign(x, kAbsLogGain). That value is then
    passed through a log-sum-exp soft maximum with 1 / kMaxLogGain (= 0.05).
    As a result, the second argument handed to arctan2 is always greater than
    0.05, so the result lies in (-pi/2, pi/2) and is smooth at x == 0. For
    |x| well above 0.05 this approaches atan2(y, |x|).
    """
    kMaxLogGain = 1.0 / 0.05
    kAbsLogGain = 1.0 / 0.01

    smooth_abs_x = x * softsign(x, kAbsLogGain)

    # logsumexp([1, g * a]) / g is a soft version of max(1 / g, a).
    soft_denominator = jax.scipy.special.logsumexp(
        jax.numpy.array([1.0, smooth_abs_x * kMaxLogGain])) / kMaxLogGain

    return jax.numpy.arctan2(y, soft_denominator)
131
132
def full_module_physics(coefficients: 'CoefficientsType', Rtheta,
                        module_index: int, mounting_location, X, U):
    """Returns one swerve module's contribution to the full-model X_dot.

    The tire forces come from a slip model (longitudinal slip ratio and
    lateral slip angle). Those forces are turned into the module's steer and
    drive accelerations, and into the force and torque the module applies to
    the chassis. full_dynamics sums the four contributions and then fills in
    the pose derivatives itself.

    Args:
      coefficients: Physical constants, as returned by Coefficients().
      Rtheta: 2x2 rotation matrix for the robot heading, R(theta).
      module_index: Module number 0-3, as a Python int (it is used to repeat
        tuples below, so it must not be a traced value). The module's four
        states are X[4 * module_index:4 * module_index + 4]. Its inputs are
        U[2 * module_index] (Is, scaled by Kts) and U[2 * module_index + 1]
        (Id, scaled by Ktd).
      mounting_location: (x, y) of the module in the robot frame; it is
        rotated by Rtheta before use.
      X: Full state vector, indexed with the dynamics.STATE_* constants.
      U: Stacked [Is, Id] inputs for all four modules.

    Returns:
      A 25-element vector laid out as follows:
        - 16 module-state slots; only this module's four are non-zero:
          [omegas, omegad, alphas, alphad].
        - 3 zeros.
        - F / m (2 elements).
        - [torque / J, 0, 0, 0].
    """
    X_module = X[module_index * 4:(module_index + 1) * 4]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Rotation by robot heading plus module steer angle: maps wheel-frame
    # vectors out of the wheel frame (its transpose maps them back in).
    Rthetaplusthetas = R(X[dynamics.STATE_THETA] +
                         X_module[dynamics.STATE_THETAS0])

    # Contact patch offset from the steering axis, in the wheel frame.
    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[dynamics.STATE_VX:dynamics.STATE_VY + 1]

    # Mounting point rotated by the robot heading; reused for the torque.
    rotated_mounting_location = Rtheta @ mounting_location

    # Contact patch velocity has three parts:
    #   - chassis rotation about the mounting point,
    #   - chassis translation,
    #   - rotation of the caster arm at the chassis rate plus the steer rate.
    contact_patch_velocity = (
        angle_cross(rotated_mounting_location, X[dynamics.STATE_OMEGA]) +
        robot_velocity + angle_cross(
            Rthetaplusthetas @ caster_vector,
            (X[dynamics.STATE_OMEGA] + X_module[dynamics.STATE_OMEGAS0])))

    # Contact patch velocity expressed in the wheel frame.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # The denominator is floored at 0.02 so a (nearly) stopped wheel never
    # divides by zero.
    slip_ratio = (coefficients.rw * X_module[dynamics.STATE_OMEGAD0] -
                  wheel_ground_velocity[0]) / jax.numpy.max(
                      jax.numpy.array(
                          [0.02, jax.numpy.abs(wheel_ground_velocity[0])]))

    # Wheel-frame tire forces, proportional to the slip quantities.
    Fwx = coefficients.Cx * slip_ratio
    Fwy = coefficients.Cy * slip_angle

    # Smooth sign of the rolling direction; flips the contact patch term.
    softsign_velocity = softsign(wheel_ground_velocity[0], 100)

    # Moment of the lateral force about the steering axis. The lever arm is
    # caster plus a third of the contact patch length.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer angular acceleration. The numerator is the sum of:
    #   - the moment above,
    #   - the steer motor torque reflected through Gs,
    #   - the coupling from the drive force through the bevel / gear geometry.
    # The denominator is the effective steer inertia.
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    # Then solve for alphad (drive wheel angular acceleration), which depends
    # on alphas through the rs * Gd3 coupling term.
    alphad = (coefficients.rs * coefficients.Jdm * coefficients.Gd3 * alphas +
              coefficients.rp * coefficients.Ktd * Id * coefficients.Gd -
              coefficients.rw * coefficients.rp * coefficients.Gd * Fwx *
              coefficients.Gd) / (coefficients.rp * coefficients.Jdm)

    # Tire force rotated out of the wheel frame, and its torque about the
    # point mounting_location is measured from.
    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
    torque = force_cross(rotated_mounting_location, F)

    X_dot_contribution = jax.numpy.hstack(
        (jax.numpy.zeros((4, )), ) * (module_index) + (jax.numpy.array([
            X_module[dynamics.STATE_OMEGAS0],
            X_module[dynamics.STATE_OMEGAD0],
            alphas,
            alphad,
        ]), ) + (jax.numpy.zeros((4, )), ) * (3 - module_index) + (
            jax.numpy.zeros((3, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J, 0, 0, 0]),
        ))

    return X_dot_contribution
207
208
@partial(jax.jit, static_argnames=['coefficients'])
def full_dynamics(coefficients: CoefficientsType, X, U):
    """Returns dX/dt for the full swerve model (four modules with drive states).

    Each module adds its own state derivatives, plus its chassis force and
    torque terms (see full_module_physics). The pose derivatives are then
    set directly from the current velocities.

    coefficients is a jit static argument, so it must be hashable (a
    CoefficientsType of plain floats is).
    """
    Rtheta = R(X[dynamics.STATE_THETA])
    half_width = coefficients.robot_width / 2.0

    # Module mounting corners (x, y) in the robot frame, in module order 0-3.
    corners = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )

    contributions = [
        full_module_physics(coefficients, Rtheta, index,
                            jax.numpy.array(corner), X, U)
        for index, corner in enumerate(corners)
    ]

    # Accumulate left to right, module 0 first.
    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    # The pose derivatives are just the current velocities.
    pose_rates = jax.numpy.array([
        X[dynamics.STATE_VX],
        X[dynamics.STATE_VY],
        X[dynamics.STATE_OMEGA],
    ])
    return X_dot.at[dynamics.STATE_X:dynamics.STATE_THETA + 1].set(pose_rates)
244
245
def velocity_module_physics(coefficients: 'CoefficientsType', Rtheta,
                            module_index: int, mounting_location, X, U):
    """Returns one swerve module's contribution to the velocity-model X_dot.

    This is the reduced model. Each module only carries steer states, and the
    longitudinal tire force is taken directly from the drive input instead of
    from a slip ratio. The lateral force still uses the slip-angle model.

    Args:
      coefficients: Physical constants, as returned by Coefficients().
      Rtheta: 2x2 rotation matrix for the robot heading, R(theta).
      module_index: Module number 0-3, as a Python int (it is used to repeat
        tuples below, so it must not be a traced value). The module's two
        states are X[2 * module_index:2 * module_index + 2]. Its inputs are
        U[2 * module_index] (Is, scaled by Kts) and U[2 * module_index + 1]
        (Id, scaled by Ktd).
      mounting_location: (x, y) of the module in the robot frame; it is
        rotated by Rtheta before use.
      X: Velocity-model state vector, indexed with the
        dynamics.VELOCITY_STATE_* constants.
      U: Stacked [Is, Id] inputs for all four modules.

    Returns:
      A 12-element vector laid out as follows:
        - 8 module-state slots; only this module's two are non-zero:
          [omegas, alphas].
        - 1 zero.
        - F / m (2 elements).
        - torque / J.
    """
    X_module = X[module_index * 2:(module_index + 1) * 2]
    Is = U[2 * module_index + 0]
    Id = U[2 * module_index + 1]

    # Rotation by robot heading plus module steer angle: maps wheel-frame
    # vectors out of the wheel frame (its transpose maps them back in).
    Rthetaplusthetas = R(X[dynamics.VELOCITY_STATE_THETA] +
                         X_module[dynamics.VELOCITY_STATE_THETAS0])

    # Contact patch offset from the steering axis, in the wheel frame.
    caster_vector = jax.numpy.array([-coefficients.caster, 0.0])

    robot_velocity = X[dynamics.VELOCITY_STATE_VX:dynamics.VELOCITY_STATE_VY +
                       1]

    # Mounting point rotated by the robot heading; reused for the torque.
    rotated_mounting_location = Rtheta @ mounting_location

    # Contact patch velocity has three parts:
    #   - chassis rotation about the mounting point,
    #   - chassis translation,
    #   - rotation of the caster arm at the chassis rate plus the steer rate.
    contact_patch_velocity = (
        angle_cross(rotated_mounting_location,
                    X[dynamics.VELOCITY_STATE_OMEGA]) + robot_velocity +
        angle_cross(Rthetaplusthetas @ caster_vector,
                    (X[dynamics.VELOCITY_STATE_OMEGA] +
                     X_module[dynamics.VELOCITY_STATE_OMEGAS0])))

    # Contact patch velocity expressed in the wheel frame.
    wheel_ground_velocity = Rthetaplusthetas.T @ contact_patch_velocity

    slip_angle = jax.numpy.sin(
        -soft_atan2(wheel_ground_velocity[1], wheel_ground_velocity[0]))

    # There is no longitudinal slip model here. The drive force is
    # Ktd * Id / (Gd * rw).
    Fwx = (coefficients.Ktd / (coefficients.Gd * coefficients.rw)) * Id
    Fwy = coefficients.Cy * slip_angle

    # Smooth sign of the rolling direction; flips the contact patch term.
    softsign_velocity = softsign(wheel_ground_velocity[0], 100.0)

    # Moment of the lateral force about the steering axis. The lever arm is
    # caster plus a third of the contact patch length.
    Ms = -Fwy * (
        (softsign_velocity * coefficients.contact_patch_length / 3.0) +
        coefficients.caster)

    # Steer angular acceleration (same expression as full_module_physics).
    alphas = (Ms + coefficients.Kts * Is / coefficients.Gs +
              (-coefficients.wb + (coefficients.rs + coefficients.rp) *
               (1 - coefficients.rb1 / coefficients.rp)) *
              (coefficients.rw / coefficients.rb2 *
               (-Fwx))) / (coefficients.Jsm +
                           (coefficients.Js /
                            (coefficients.Gs * coefficients.Gs)))

    # Tire force rotated out of the wheel frame, and its torque about the
    # point mounting_location is measured from.
    F = Rthetaplusthetas @ jax.numpy.array([Fwx, Fwy])
    torque = force_cross(rotated_mounting_location, F)

    X_dot_contribution = jax.numpy.hstack(
        (jax.numpy.zeros((2, )), ) * (module_index) + (jax.numpy.array([
            X_module[dynamics.VELOCITY_STATE_OMEGAS0],
            alphas,
        ]), ) + (jax.numpy.zeros((2, )), ) * (3 - module_index) + (
            jax.numpy.zeros((1, )),
            F / coefficients.m,
            jax.numpy.array([torque / coefficients.J]),
        ))

    return X_dot_contribution
304
305
@partial(jax.jit, static_argnames=['coefficients'])
def velocity_dynamics(coefficients: CoefficientsType, X, U):
    """Returns dX/dt for the reduced (velocity) swerve model.

    Sums the four velocity_module_physics contributions, then sets the
    heading derivative to the current angular velocity.

    coefficients is a jit static argument, so it must be hashable (a
    CoefficientsType of plain floats is).
    """
    Rtheta = R(X[dynamics.VELOCITY_STATE_THETA])
    half_width = coefficients.robot_width / 2.0

    # Module mounting corners (x, y) in the robot frame, in module order 0-3.
    corners = (
        (half_width, half_width),
        (-half_width, half_width),
        (-half_width, -half_width),
        (half_width, -half_width),
    )

    contributions = [
        velocity_module_physics(coefficients, Rtheta, index,
                                jax.numpy.array(corner), X, U)
        for index, corner in enumerate(corners)
    ]

    # Accumulate left to right, module 0 first.
    X_dot = contributions[0]
    for contribution in contributions[1:]:
        X_dot = X_dot + contribution

    heading_rate = X[dynamics.VELOCITY_STATE_OMEGA]
    return X_dot.at[dynamics.VELOCITY_STATE_THETA].set(heading_rate)
334 X[dynamics.VELOCITY_STATE_OMEGA])