blob: 176130879eb8a276be537be6067252279b10d8aa [file] [log] [blame]
justinT21446e4f62024-06-16 22:36:10 -07001#include <symengine/add.h>
2#include <symengine/matrix.h>
3#include <symengine/number.h>
4#include <symengine/printers.h>
5#include <symengine/real_double.h>
6#include <symengine/simplify.h>
7#include <symengine/solve.h>
8#include <symengine/symbol.h>
9
10#include <array>
11#include <cmath>
12#include <numbers>
13#include <utility>
14
15#include "absl/strings/str_format.h"
16#include "absl/strings/str_join.h"
17#include "absl/strings/str_replace.h"
18#include "absl/strings/substitute.h"
19#include "gflags/gflags.h"
20#include "glog/logging.h"
21
22#include "aos/init.h"
23#include "aos/util/file.h"
24#include "frc971/control_loops/swerve/motors.h"
25
26DEFINE_string(output_base, "",
27 "Path to strip off the front of the output paths.");
28DEFINE_string(cc_output_path, "", "Path to write generated header code to");
29DEFINE_string(h_output_path, "", "Path to write generated cc code to");
justinT21942892b2024-07-02 22:33:50 -070030DEFINE_string(py_output_path, "", "Path to write generated py code to");
Austin Schuh0f881092024-06-28 15:36:48 -070031DEFINE_string(casadi_py_output_path, "",
32 "Path to write casadi generated py code to");
justinT21446e4f62024-06-16 22:36:10 -070033
34DEFINE_bool(symbolic, false, "If true, write everything out symbolically.");
35
justinT21942892b2024-07-02 22:33:50 -070036using SymEngine::abs;
justinT21446e4f62024-06-16 22:36:10 -070037using SymEngine::add;
38using SymEngine::atan2;
39using SymEngine::Basic;
40using SymEngine::ccode;
41using SymEngine::cos;
42using SymEngine::DenseMatrix;
43using SymEngine::div;
44using SymEngine::Inf;
45using SymEngine::integer;
46using SymEngine::map_basic_basic;
47using SymEngine::minus_one;
48using SymEngine::neg;
49using SymEngine::NegInf;
50using SymEngine::pow;
51using SymEngine::RCP;
52using SymEngine::real_double;
53using SymEngine::RealDouble;
54using SymEngine::Set;
55using SymEngine::simplify;
56using SymEngine::sin;
57using SymEngine::solve;
58using SymEngine::symbol;
59using SymEngine::Symbol;
60
61namespace frc971::control_loops::swerve {
62
63// State per module.
64struct Module {
65 RCP<const Symbol> Is;
66
67 RCP<const Symbol> Id;
68
69 RCP<const Symbol> thetas;
70 RCP<const Symbol> omegas;
71 RCP<const Symbol> alphas;
72 RCP<const Basic> alphas_eqn;
73
74 RCP<const Symbol> thetad;
75 RCP<const Symbol> omegad;
76 RCP<const Symbol> alphad;
77 RCP<const Basic> alphad_eqn;
78
Austin Schuhb67a38f2024-07-04 13:48:38 -070079 DenseMatrix wheel_ground_velocity;
80 RCP<const Basic> slip_angle;
81 RCP<const Basic> slip_ratio;
82
83 RCP<const Basic> Fwx;
84 RCP<const Basic> Fwy;
85
justinT21446e4f62024-06-16 22:36:10 -070086 // Acceleration contribution from this module.
87 DenseMatrix accel;
88 RCP<const Basic> angular_accel;
89};
90
91class SwerveSimulation {
92 public:
93 SwerveSimulation() : drive_motor_(KrakenFOC()), steer_motor_(KrakenFOC()) {
94 auto fx = symbol("fx");
95 auto fy = symbol("fy");
96 auto moment = symbol("moment");
97
98 if (FLAGS_symbolic) {
99 Cx_ = symbol("Cx");
100 Cy_ = symbol("Cy");
101
102 r_w_ = symbol("r_w_");
103
104 m_ = symbol("m");
105 J_ = symbol("J");
106
107 Gd1_ = symbol("Gd1");
108 rs_ = symbol("rs");
109 rp_ = symbol("rp");
110 Gd2_ = symbol("Gd2");
111
112 rb1_ = symbol("rb1");
113 rb2_ = symbol("rb2");
114
115 Gd2_ = symbol("Gd3");
116 Gd_ = symbol("Gd");
117
118 Js_ = symbol("Js");
119
120 Gs_ = symbol("Gs");
121 wb_ = symbol("wb");
122
123 Jdm_ = symbol("Jdm");
124 Jsm_ = symbol("Jsm");
125 Kts_ = symbol("Kts");
126 Ktd_ = symbol("Ktd");
127
128 robot_width_ = symbol("robot_width");
129
130 caster_ = symbol("caster");
131 contact_patch_length_ = symbol("Lcp");
132 } else {
133 Cx_ = real_double(5 * 9.8 / 0.05 / 4.0);
134 Cy_ = real_double(5 * 9.8 / 0.05 / 4.0);
135
136 r_w_ = real_double(2 * 0.0254);
137
138 m_ = real_double(25.0); // base is 20 kg without battery
139 J_ = real_double(6.0);
140
141 Gd1_ = real_double(12.0 / 42.0);
142 rs_ = real_double(28.0 / 20.0 / 2.0);
143 rp_ = real_double(18.0 / 20.0 / 2.0);
144 Gd2_ = div(rs_, rp_);
145
146 // 15 / 45 bevel ratio, calculated using python script ported over to
147 // GetBevelPitchRadius(double
148 // TODO(Justin): Use the function instead of computed constantss
149 rb1_ = real_double(0.3805473);
150 rb2_ = real_double(1.14164);
151
152 Gd3_ = div(rb1_, rb2_);
153 Gd_ = mul(mul(Gd1_, Gd2_), Gd3_);
154
155 Js_ = real_double(0.1);
156
157 Gs_ = real_double(35.0 / 468.0);
158 wb_ = real_double(0.725);
159
160 Jdm_ = real_double(drive_motor_.motor_inertia);
161 Jsm_ = real_double(steer_motor_.motor_inertia);
162 Kts_ = real_double(steer_motor_.Kt);
163 Ktd_ = real_double(drive_motor_.Kt);
164
165 robot_width_ = real_double(24.75 * 0.0254);
166
167 caster_ = real_double(0.01);
168 contact_patch_length_ = real_double(0.02);
169 }
170
171 x_ = symbol("x");
172 y_ = symbol("y");
173 theta_ = symbol("theta");
174
175 vx_ = symbol("vx");
176 vy_ = symbol("vy");
177 omega_ = symbol("omega");
178
179 ax_ = symbol("ax");
180 ay_ = symbol("ay");
181 atheta_ = symbol("atheta");
182
183 // Now, compute the accelerations due to the disturbance forces.
184 angular_accel_ = div(moment, J_);
185 DenseMatrix external_accel = DenseMatrix(2, 1, {div(fx, m_), div(fy, m_)});
186
187 // And compute the physics contributions from each module.
188 modules_[0] = ModulePhysics(
189 0, DenseMatrix(
190 2, 1,
191 {div(robot_width_, integer(2)), div(robot_width_, integer(2))}));
192 modules_[1] =
193 ModulePhysics(1, DenseMatrix(2, 1,
194 {div(robot_width_, integer(-2)),
195 div(robot_width_, integer(2))}));
196 modules_[2] =
197 ModulePhysics(2, DenseMatrix(2, 1,
198 {div(robot_width_, integer(-2)),
199 div(robot_width_, integer(-2))}));
200 modules_[3] =
201 ModulePhysics(3, DenseMatrix(2, 1,
202 {div(robot_width_, integer(2)),
203 div(robot_width_, integer(-2))}));
204
205 // And convert them into the overall robot contribution.
206 DenseMatrix temp0 = DenseMatrix(2, 1);
207 DenseMatrix temp1 = DenseMatrix(2, 1);
208 DenseMatrix temp2 = DenseMatrix(2, 1);
209 accel_ = DenseMatrix(2, 1);
210
211 add_dense_dense(modules_[0].accel, external_accel, temp0);
212 add_dense_dense(temp0, modules_[1].accel, temp1);
213 add_dense_dense(temp1, modules_[2].accel, temp2);
214 add_dense_dense(temp2, modules_[3].accel, accel_);
215
216 angular_accel_ = add(angular_accel_, modules_[0].angular_accel);
217 angular_accel_ = add(angular_accel_, modules_[1].angular_accel);
218 angular_accel_ = add(angular_accel_, modules_[2].angular_accel);
219 angular_accel_ = simplify(add(angular_accel_, modules_[3].angular_accel));
220
221 VLOG(1) << "accel(0, 0) = " << ccode(*accel_.get(0, 0));
222 VLOG(1) << "accel(1, 0) = " << ccode(*accel_.get(1, 0));
223 VLOG(1) << "angular_accel = " << ccode(*angular_accel_);
224 }
225
justinT21942892b2024-07-02 22:33:50 -0700226 // Writes the physics out to the provided .py path.
227 void WritePy(std::string_view py_path) {
228 std::vector<std::string> result_py;
229
230 result_py.emplace_back("#!/usr/bin/python3");
231 result_py.emplace_back("");
232 result_py.emplace_back("import numpy");
justinT21942892b2024-07-02 22:33:50 -0700233 result_py.emplace_back("");
234
justinT21942892b2024-07-02 22:33:50 -0700235 result_py.emplace_back("def swerve_physics(t, X, U_func):");
Austin Schuh0f881092024-06-28 15:36:48 -0700236 result_py.emplace_back(" def atan2(y, x):");
237 result_py.emplace_back(" if x < 0:");
238 result_py.emplace_back(" return -numpy.atan2(y, x)");
239 result_py.emplace_back(" else:");
240 result_py.emplace_back(" return numpy.atan2(y, x)");
241 result_py.emplace_back(" sin = numpy.sin");
242 result_py.emplace_back(" cos = numpy.cos");
243 result_py.emplace_back(" fabs = numpy.fabs");
244
justinT21942892b2024-07-02 22:33:50 -0700245 result_py.emplace_back(" result = numpy.empty([25, 1])");
246 result_py.emplace_back(" X = X.reshape(25, 1)");
247 result_py.emplace_back(" U = U_func(X)");
248 result_py.emplace_back("");
249
250 // Start by writing out variables matching each of the symbol names we use
251 // so we don't have to modify the computed equations too much.
252 for (size_t m = 0; m < kNumModules; ++m) {
253 result_py.emplace_back(
254 absl::Substitute(" thetas$0 = X[$1, 0]", m, m * 4));
255 result_py.emplace_back(
256 absl::Substitute(" omegas$0 = X[$1, 0]", m, m * 4 + 2));
257 result_py.emplace_back(
258 absl::Substitute(" omegad$0 = X[$1, 0]", m, m * 4 + 3));
259 }
260
261 result_py.emplace_back(
262 absl::Substitute(" theta = X[$0, 0]", kNumModules * 4 + 2));
263 result_py.emplace_back(
264 absl::Substitute(" vx = X[$0, 0]", kNumModules * 4 + 3));
265 result_py.emplace_back(
266 absl::Substitute(" vy = X[$0, 0]", kNumModules * 4 + 4));
267 result_py.emplace_back(
268 absl::Substitute(" omega = X[$0, 0]", kNumModules * 4 + 5));
269
270 result_py.emplace_back(
271 absl::Substitute(" fx = X[$0, 0]", kNumModules * 4 + 6));
272 result_py.emplace_back(
273 absl::Substitute(" fy = X[$0, 0]", kNumModules * 4 + 7));
274 result_py.emplace_back(
275 absl::Substitute(" moment = X[$0, 0]", kNumModules * 4 + 8));
276
277 // Now do the same for the inputs.
278 for (size_t m = 0; m < kNumModules; ++m) {
279 result_py.emplace_back(absl::Substitute(" Is$0 = U[$1, 0]", m, m * 2));
280 result_py.emplace_back(
281 absl::Substitute(" Id$0 = U[$1, 0]", m, m * 2 + 1));
282 }
283
284 result_py.emplace_back("");
285
286 // And then write out the derivative of each state.
287 for (size_t m = 0; m < kNumModules; ++m) {
288 result_py.emplace_back(
289 absl::Substitute(" result[$0, 0] = omegas$1", m * 4, m));
290 result_py.emplace_back(
291 absl::Substitute(" result[$0, 0] = omegad$1", m * 4 + 1, m));
292
293 result_py.emplace_back(absl::Substitute(
294 " result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
295 result_py.emplace_back(absl::Substitute(
296 " result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
297 }
298
299 result_py.emplace_back(
300 absl::Substitute(" result[$0, 0] = vx", kNumModules * 4));
301 result_py.emplace_back(
302 absl::Substitute(" result[$0, 0] = vy", kNumModules * 4 + 1));
303 result_py.emplace_back(
304 absl::Substitute(" result[$0, 0] = omega", kNumModules * 4 + 2));
305
306 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
307 kNumModules * 4 + 3,
308 ccode(*accel_.get(0, 0))));
309 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
310 kNumModules * 4 + 4,
311 ccode(*accel_.get(1, 0))));
312 result_py.emplace_back(absl::Substitute(
313 " result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
314
315 result_py.emplace_back(
316 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 6));
317 result_py.emplace_back(
318 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 7));
319 result_py.emplace_back(
320 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 8));
321
322 result_py.emplace_back("");
323 result_py.emplace_back(" return result.reshape(25,)\n");
324
325 aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
326 }
327
justinT21446e4f62024-06-16 22:36:10 -0700328 // Writes the physics out to the provided .cc and .h path.
329 void Write(std::string_view cc_path, std::string_view h_path) {
330 std::vector<std::string> result_cc;
331 std::vector<std::string> result_h;
332
Austin Schuh0f881092024-06-28 15:36:48 -0700333 std::string_view include_guard_stripped = h_path;
justinT21446e4f62024-06-16 22:36:10 -0700334 CHECK(absl::ConsumePrefix(&include_guard_stripped, FLAGS_output_base));
335 std::string include_guard =
336 absl::StrReplaceAll(absl::AsciiStrToUpper(include_guard_stripped),
337 {{"/", "_"}, {".", "_"}});
338
339 // Write out the header.
340 result_h.emplace_back(absl::Substitute("#ifndef $0_", include_guard));
341 result_h.emplace_back(absl::Substitute("#define $0_", include_guard));
342 result_h.emplace_back("");
343 result_h.emplace_back("#include <Eigen/Dense>");
344 result_h.emplace_back("");
345 result_h.emplace_back("namespace frc971::control_loops::swerve {");
346 result_h.emplace_back("");
347 result_h.emplace_back("// Returns the derivative of our state vector");
348 result_h.emplace_back("// [thetas0, thetad0, omegas0, omegad0,");
349 result_h.emplace_back("// thetas1, thetad1, omegas1, omegad1,");
350 result_h.emplace_back("// thetas2, thetad2, omegas2, omegad2,");
351 result_h.emplace_back("// thetas3, thetad3, omegas3, omegad3,");
352 result_h.emplace_back("// x, y, theta, vx, vy, omega,");
353 result_h.emplace_back("// Fx, Fy, Moment]");
354 result_h.emplace_back("Eigen::Matrix<double, 25, 1> SwervePhysics(");
355 result_h.emplace_back(
356 " Eigen::Map<const Eigen::Matrix<double, 25, 1>> X,");
357 result_h.emplace_back(
358 " Eigen::Map<const Eigen::Matrix<double, 8, 1>> U);");
359 result_h.emplace_back("");
360 result_h.emplace_back("} // namespace frc971::control_loops::swerve");
361 result_h.emplace_back("");
362 result_h.emplace_back(absl::Substitute("#endif // $0_", include_guard));
363
364 // Write out the .cc
365 result_cc.emplace_back(
366 absl::Substitute("#include \"$0\"", include_guard_stripped));
367 result_cc.emplace_back("");
368 result_cc.emplace_back("#include <cmath>");
369 result_cc.emplace_back("");
370 result_cc.emplace_back("namespace frc971::control_loops::swerve {");
371 result_cc.emplace_back("");
372 result_cc.emplace_back("Eigen::Matrix<double, 25, 1> SwervePhysics(");
373 result_cc.emplace_back(
374 " Eigen::Map<const Eigen::Matrix<double, 25, 1>> X,");
375 result_cc.emplace_back(
376 " Eigen::Map<const Eigen::Matrix<double, 8, 1>> U) {");
377 result_cc.emplace_back(" Eigen::Matrix<double, 25, 1> result;");
378
379 // Start by writing out variables matching each of the symbol names we use
380 // so we don't have to modify the computed equations too much.
381 for (size_t m = 0; m < kNumModules; ++m) {
382 result_cc.emplace_back(
383 absl::Substitute(" const double thetas$0 = X($1, 0);", m, m * 4));
384 result_cc.emplace_back(absl::Substitute(
385 " const double omegas$0 = X($1, 0);", m, m * 4 + 2));
386 result_cc.emplace_back(absl::Substitute(
387 " const double omegad$0 = X($1, 0);", m, m * 4 + 3));
388 }
389
390 result_cc.emplace_back(absl::Substitute(" const double theta = X($0, 0);",
391 kNumModules * 4 + 2));
392 result_cc.emplace_back(
393 absl::Substitute(" const double vx = X($0, 0);", kNumModules * 4 + 3));
394 result_cc.emplace_back(
395 absl::Substitute(" const double vy = X($0, 0);", kNumModules * 4 + 4));
396 result_cc.emplace_back(absl::Substitute(" const double omega = X($0, 0);",
397 kNumModules * 4 + 5));
398
399 result_cc.emplace_back(
400 absl::Substitute(" const double fx = X($0, 0);", kNumModules * 4 + 6));
401 result_cc.emplace_back(
402 absl::Substitute(" const double fy = X($0, 0);", kNumModules * 4 + 7));
403 result_cc.emplace_back(absl::Substitute(" const double moment = X($0, 0);",
404 kNumModules * 4 + 8));
405
406 // Now do the same for the inputs.
407 for (size_t m = 0; m < kNumModules; ++m) {
408 result_cc.emplace_back(
409 absl::Substitute(" const double Is$0 = U($1, 0);", m, m * 2));
410 result_cc.emplace_back(
411 absl::Substitute(" const double Id$0 = U($1, 0);", m, m * 2 + 1));
412 }
413
414 result_cc.emplace_back("");
415
416 // And then write out the derivative of each state.
417 for (size_t m = 0; m < kNumModules; ++m) {
418 result_cc.emplace_back(
419 absl::Substitute(" result($0, 0) = omegas$1;", m * 4, m));
420 result_cc.emplace_back(
421 absl::Substitute(" result($0, 0) = omegad$1;", m * 4 + 1, m));
422
423 result_cc.emplace_back(absl::Substitute(
424 " result($0, 0) = $1;", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
425 result_cc.emplace_back(absl::Substitute(
426 " result($0, 0) = $1;", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
427 }
428
429 result_cc.emplace_back(
430 absl::Substitute(" result($0, 0) = omega;", kNumModules * 4));
431 result_cc.emplace_back(
432 absl::Substitute(" result($0, 0) = vx;", kNumModules * 4 + 1));
433 result_cc.emplace_back(
434 absl::Substitute(" result($0, 0) = vy;", kNumModules * 4 + 2));
435
436 result_cc.emplace_back(absl::Substitute(
437 " result($0, 0) = $1;", kNumModules * 4 + 3, ccode(*angular_accel_)));
438 result_cc.emplace_back(absl::Substitute(" result($0, 0) = $1;",
439 kNumModules * 4 + 4,
440 ccode(*accel_.get(0, 0))));
441 result_cc.emplace_back(absl::Substitute(" result($0, 0) = $1;",
442 kNumModules * 4 + 5,
443 ccode(*accel_.get(1, 0))));
444
445 result_cc.emplace_back(
446 absl::Substitute(" result($0, 0) = 0.0;", kNumModules * 4 + 6));
447 result_cc.emplace_back(
448 absl::Substitute(" result($0, 0) = 0.0;", kNumModules * 4 + 7));
449 result_cc.emplace_back(
450 absl::Substitute(" result($0, 0) = 0.0;", kNumModules * 4 + 8));
451
452 result_cc.emplace_back("");
453 result_cc.emplace_back(" return result;");
454 result_cc.emplace_back("}");
455 result_cc.emplace_back("");
456 result_cc.emplace_back("} // namespace frc971::control_loops::swerve");
457
458 aos::util::WriteStringToFileOrDie(cc_path, absl::StrJoin(result_cc, "\n"));
459 aos::util::WriteStringToFileOrDie(h_path, absl::StrJoin(result_h, "\n"));
460 }
461
Austin Schuhb67a38f2024-07-04 13:48:38 -0700462 void WriteCasadiVariables(std::vector<std::string> *result_py) {
463 result_py->emplace_back(" sin = casadi.sin");
464 result_py->emplace_back(" cos = casadi.cos");
465 result_py->emplace_back(" atan2 = casadi.atan2");
466 result_py->emplace_back(" fabs = casadi.fabs");
467
468 // Start by writing out variables matching each of the symbol names we use
469 // so we don't have to modify the computed equations too much.
470 for (size_t m = 0; m < kNumModules; ++m) {
471 result_py->emplace_back(
472 absl::Substitute(" thetas$0 = X[$1, 0]", m, m * 4));
473 result_py->emplace_back(
474 absl::Substitute(" omegas$0 = X[$1, 0]", m, m * 4 + 2));
475 result_py->emplace_back(
476 absl::Substitute(" omegad$0 = X[$1, 0]", m, m * 4 + 3));
477 }
478
479 result_py->emplace_back(
480 absl::Substitute(" theta = X[$0, 0]", kNumModules * 4 + 2));
481 result_py->emplace_back(
482 absl::Substitute(" vx = X[$0, 0]", kNumModules * 4 + 3));
483 result_py->emplace_back(
484 absl::Substitute(" vy = X[$0, 0]", kNumModules * 4 + 4));
485 result_py->emplace_back(
486 absl::Substitute(" omega = X[$0, 0]", kNumModules * 4 + 5));
487
488 result_py->emplace_back(
489 absl::Substitute(" fx = X[$0, 0]", kNumModules * 4 + 6));
490 result_py->emplace_back(
491 absl::Substitute(" fy = X[$0, 0]", kNumModules * 4 + 7));
492 result_py->emplace_back(
493 absl::Substitute(" moment = X[$0, 0]", kNumModules * 4 + 8));
494
495 // Now do the same for the inputs.
496 for (size_t m = 0; m < kNumModules; ++m) {
497 result_py->emplace_back(
498 absl::Substitute(" Is$0 = U[$1, 0]", m, m * 2));
499 result_py->emplace_back(
500 absl::Substitute(" Id$0 = U[$1, 0]", m, m * 2 + 1));
501 }
502 }
503
Austin Schuh0f881092024-06-28 15:36:48 -0700504 // Writes the physics out to the provided .cc and .h path.
505 void WriteCasadi(std::string_view py_path) {
506 std::vector<std::string> result_py;
507
508 // Write out the header.
509 result_py.emplace_back("#!/usr/bin/python3");
510 result_py.emplace_back("");
511 result_py.emplace_back("import casadi");
512 result_py.emplace_back("");
513 result_py.emplace_back("# Returns the derivative of our state vector");
514 result_py.emplace_back("# Returns the derivative of our state vector");
515 result_py.emplace_back("# [thetas0, thetad0, omegas0, omegad0,");
516 result_py.emplace_back("# thetas1, thetad1, omegas1, omegad1,");
517 result_py.emplace_back("# thetas2, thetad2, omegas2, omegad2,");
518 result_py.emplace_back("# thetas3, thetad3, omegas3, omegad3,");
519 result_py.emplace_back("# x, y, theta, vx, vy, omega,");
520 result_py.emplace_back("# Fx, Fy, Moment]");
521 result_py.emplace_back("def swerve_physics(X, U):");
Austin Schuhb67a38f2024-07-04 13:48:38 -0700522 WriteCasadiVariables(&result_py);
Austin Schuh0f881092024-06-28 15:36:48 -0700523
524 result_py.emplace_back("");
525 result_py.emplace_back(" result = casadi.SX.sym('result', 25, 1)");
526 result_py.emplace_back("");
527
528 // And then write out the derivative of each state.
529 for (size_t m = 0; m < kNumModules; ++m) {
530 result_py.emplace_back(
531 absl::Substitute(" result[$0, 0] = omegas$1", m * 4, m));
532 result_py.emplace_back(
533 absl::Substitute(" result[$0, 0] = omegad$1", m * 4 + 1, m));
534
535 result_py.emplace_back(absl::Substitute(
536 " result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
537 result_py.emplace_back(absl::Substitute(
538 " result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
539 }
540
541 result_py.emplace_back(
542 absl::Substitute(" result[$0, 0] = omega", kNumModules * 4));
543 result_py.emplace_back(
544 absl::Substitute(" result[$0, 0] = vx", kNumModules * 4 + 1));
545 result_py.emplace_back(
546 absl::Substitute(" result[$0, 0] = vy", kNumModules * 4 + 2));
547
548 result_py.emplace_back(absl::Substitute(
549 " result[$0, 0] = $1", kNumModules * 4 + 3, ccode(*angular_accel_)));
550 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
551 kNumModules * 4 + 4,
552 ccode(*accel_.get(0, 0))));
553 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
554 kNumModules * 4 + 5,
555 ccode(*accel_.get(1, 0))));
556
557 result_py.emplace_back(
558 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 6));
559 result_py.emplace_back(
560 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 7));
561 result_py.emplace_back(
562 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 8));
563
564 result_py.emplace_back("");
565 result_py.emplace_back(
566 " return casadi.Function('xdot', [X, U], [result])");
Austin Schuhb67a38f2024-07-04 13:48:38 -0700567 result_py.emplace_back("");
568 result_py.emplace_back(
569 "# Returns the velocity of the wheel in steer module coordinates.");
570 result_py.emplace_back("def wheel_ground_velocity(i, X, U):");
571 WriteCasadiVariables(&result_py);
572 result_py.emplace_back(
573 " result = casadi.SX.sym('ground_wheel_velocity', 2, 1)");
574
575 for (size_t m = 0; m < kNumModules; ++m) {
576 if (m == 0) {
577 result_py.emplace_back(" if i == 0:");
578 } else {
579 result_py.emplace_back(absl::Substitute(" elif i == $0:", m));
580 }
581 for (int j = 0; j < 2; ++j) {
582 result_py.emplace_back(absl::Substitute(
583 " result[$0, 0] = $1", j,
584 ccode(*modules_[m].wheel_ground_velocity.get(0, 0))));
585 }
586 }
587 result_py.emplace_back(" else:");
588 result_py.emplace_back(
589 " raise ValueError(\"Invalid module number\")");
590 result_py.emplace_back(" return result");
591
592 result_py.emplace_back("");
593 result_py.emplace_back("# Returns the slip angle of the ith wheel.");
594 result_py.emplace_back("def slip_angle(i, X, U):");
595 WriteCasadiVariables(&result_py);
596 for (size_t m = 0; m < kNumModules; ++m) {
597 if (m == 0) {
598 result_py.emplace_back(" if i == 0:");
599 } else {
600 result_py.emplace_back(absl::Substitute(" elif i == $0:", m));
601 }
602 result_py.emplace_back(absl::Substitute(" return $0",
603 ccode(*modules_[m].slip_angle)));
604 }
605 result_py.emplace_back(" raise ValueError(\"Invalid module number\")");
606 result_py.emplace_back("");
607 result_py.emplace_back("# Returns the slip ratio of the ith wheel.");
608 result_py.emplace_back("def slip_ratio(i, X, U):");
609 WriteCasadiVariables(&result_py);
610 for (size_t m = 0; m < kNumModules; ++m) {
611 if (m == 0) {
612 result_py.emplace_back(" if i == 0:");
613 } else {
614 result_py.emplace_back(absl::Substitute(" elif i == $0:", m));
615 }
616 result_py.emplace_back(absl::Substitute(" return $0",
617 ccode(*modules_[m].slip_ratio)));
618 }
619 result_py.emplace_back(" raise ValueError(\"Invalid module number\")");
620
621 result_py.emplace_back("");
622 result_py.emplace_back(
623 "# Returns the force on the wheel in steer module coordinates.");
624 result_py.emplace_back("def wheel_force(i, X, U):");
625 WriteCasadiVariables(&result_py);
626 result_py.emplace_back(" result = casadi.SX.sym('Fw', 2, 1)");
627
628 for (size_t m = 0; m < kNumModules; ++m) {
629 if (m == 0) {
630 result_py.emplace_back(" if i == 0:");
631 } else {
632 result_py.emplace_back(absl::Substitute(" elif i == $0:", m));
633 }
634 result_py.emplace_back(absl::Substitute(" result[0, 0] = $0",
635 ccode(*modules_[m].Fwx)));
636 result_py.emplace_back(absl::Substitute(" result[1, 0] = $0",
637 ccode(*modules_[m].Fwy)));
638 }
639 result_py.emplace_back(" else:");
640 result_py.emplace_back(
641 " raise ValueError(\"Invalid module number\")");
642 result_py.emplace_back(" return result");
Austin Schuh0f881092024-06-28 15:36:48 -0700643
644 aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
645 }
646
justinT21446e4f62024-06-16 22:36:10 -0700647 private:
648 static constexpr uint8_t kNumModules = 4;
649
650 Module ModulePhysics(const int m, DenseMatrix mounting_location) {
651 VLOG(1) << "Solving module " << m;
652
653 Module result;
654
655 result.Is = symbol(absl::StrFormat("Is%u", m));
656 result.Id = symbol(absl::StrFormat("Id%u", m));
657
658 RCP<const Symbol> thetamd = symbol(absl::StrFormat("theta_md%u", m));
659 RCP<const Symbol> omegamd = symbol(absl::StrFormat("omega_md%u", m));
660 RCP<const Symbol> alphamd = symbol(absl::StrFormat("alpha_md%u", m));
661
662 result.thetas = symbol(absl::StrFormat("thetas%u", m));
663 result.omegas = symbol(absl::StrFormat("omegas%u", m));
664 result.alphas = symbol(absl::StrFormat("alphas%u", m));
665
666 result.thetad = symbol(absl::StrFormat("thetad%u", m));
667 result.omegad = symbol(absl::StrFormat("omegad%u", m));
668 result.alphad = symbol(absl::StrFormat("alphad%u", m));
669
670 // Velocity of the module in field coordinates
justinT21942892b2024-07-02 22:33:50 -0700671 DenseMatrix robot_velocity = DenseMatrix(2, 1);
672 mul_dense_dense(R(theta_), DenseMatrix(2, 1, {vx_, vy_}), robot_velocity);
justinT21446e4f62024-06-16 22:36:10 -0700673 VLOG(1) << "robot velocity: " << robot_velocity.__str__();
674
675 // Velocity of the contact patch in field coordinates
676 DenseMatrix temp_matrix = DenseMatrix(2, 1);
677 DenseMatrix temp_matrix2 = DenseMatrix(2, 1);
678 DenseMatrix contact_patch_velocity = DenseMatrix(2, 1);
679
680 mul_dense_dense(R(theta_), mounting_location, temp_matrix);
681 add_dense_dense(angle_cross(temp_matrix, omega_), robot_velocity,
682 temp_matrix2);
683 mul_dense_dense(R(add(theta_, result.thetas)),
684 DenseMatrix(2, 1, {caster_, integer(0)}), temp_matrix);
685 add_dense_dense(temp_matrix2,
686 angle_cross(temp_matrix, add(omega_, result.omegas)),
687 contact_patch_velocity);
688
689 VLOG(1);
690 VLOG(1) << "contact patch velocity: " << contact_patch_velocity.__str__();
691
692 // Relative velocity of the surface of the wheel to the ground.
Austin Schuhb67a38f2024-07-04 13:48:38 -0700693 result.wheel_ground_velocity = DenseMatrix(2, 1);
justinT21446e4f62024-06-16 22:36:10 -0700694 mul_dense_dense(R(neg(add(result.thetas, theta_))), contact_patch_velocity,
Austin Schuhb67a38f2024-07-04 13:48:38 -0700695 result.wheel_ground_velocity);
justinT21446e4f62024-06-16 22:36:10 -0700696
697 VLOG(1);
Austin Schuhb67a38f2024-07-04 13:48:38 -0700698 VLOG(1) << "wheel ground velocity: "
699 << result.wheel_ground_velocity.__str__();
justinT21446e4f62024-06-16 22:36:10 -0700700
Austin Schuhb67a38f2024-07-04 13:48:38 -0700701 result.slip_angle = neg(atan2(result.wheel_ground_velocity.get(1, 0),
702 result.wheel_ground_velocity.get(0, 0)));
justinT21446e4f62024-06-16 22:36:10 -0700703
704 VLOG(1);
Austin Schuhb67a38f2024-07-04 13:48:38 -0700705 VLOG(1) << "slip angle: " << result.slip_angle->__str__();
justinT21446e4f62024-06-16 22:36:10 -0700706
Austin Schuhb67a38f2024-07-04 13:48:38 -0700707 result.slip_ratio = div(
708 sub(mul(r_w_, result.omegad), result.wheel_ground_velocity.get(0, 0)),
709 abs(result.wheel_ground_velocity.get(0, 0)));
justinT21446e4f62024-06-16 22:36:10 -0700710 VLOG(1);
Austin Schuhb67a38f2024-07-04 13:48:38 -0700711 VLOG(1) << "Slip ratio " << result.slip_ratio->__str__();
justinT21446e4f62024-06-16 22:36:10 -0700712
Austin Schuhb67a38f2024-07-04 13:48:38 -0700713 result.Fwx = simplify(mul(Cx_, result.slip_ratio));
714 result.Fwy = simplify(mul(Cy_, result.slip_angle));
justinT21446e4f62024-06-16 22:36:10 -0700715
716 RCP<const Basic> Ms =
Austin Schuhb67a38f2024-07-04 13:48:38 -0700717 mul(result.Fwy, add(div(contact_patch_length_, integer(3)), caster_));
justinT21446e4f62024-06-16 22:36:10 -0700718 VLOG(1);
719 VLOG(1) << "Ms " << Ms->__str__();
720 VLOG(1);
Austin Schuhb67a38f2024-07-04 13:48:38 -0700721 VLOG(1) << "Fwx " << result.Fwx->__str__();
justinT21446e4f62024-06-16 22:36:10 -0700722 VLOG(1);
Austin Schuhb67a38f2024-07-04 13:48:38 -0700723 VLOG(1) << "Fwy " << result.Fwy->__str__();
justinT21446e4f62024-06-16 22:36:10 -0700724
725 // alphas = ...
726 RCP<const Basic> lhms =
727 mul(add(neg(wb_), mul(add(rs_, rp_), sub(integer(1), div(rb1_, rp_)))),
Austin Schuhb67a38f2024-07-04 13:48:38 -0700728 mul(div(r_w_, rb2_), neg(result.Fwx)));
justinT21446e4f62024-06-16 22:36:10 -0700729 RCP<const Basic> lhs = add(add(Ms, div(mul(Jsm_, result.Is), Gs_)), lhms);
730 RCP<const Basic> rhs = add(Jsm_, div(div(Js_, Gs_), Gs_));
731 RCP<const Basic> accel_steer_eqn = simplify(div(lhs, rhs));
732
733 VLOG(1);
734 VLOG(1) << result.alphas->__str__() << " = " << accel_steer_eqn->__str__();
735
736 lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.omegas),
737 mul(Gd1_, mul(Gd2_, omegamd)));
738 RCP<const Basic> dplanitary_eqn = sub(mul(Gd3_, lhs), result.omegad);
739
740 lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.alphas),
741 mul(Gd1_, mul(Gd2_, alphamd)));
742 RCP<const Basic> ddplanitary_eqn = sub(mul(Gd3_, lhs), result.alphad);
743
744 RCP<const Basic> drive_eqn = sub(
745 add(mul(neg(Jdm_), div(alphamd, Gd_)), mul(Ktd_, div(result.Id, Gd_))),
Austin Schuhb67a38f2024-07-04 13:48:38 -0700746 mul(neg(result.Fwx), r_w_));
justinT21446e4f62024-06-16 22:36:10 -0700747
748 VLOG(1) << "drive_eqn: " << drive_eqn->__str__();
749
750 // Substitute in ddplanitary_eqn so we get rid of alphamd
751 map_basic_basic map;
752 RCP<const Set> reals = interval(NegInf, Inf, true, true);
753 RCP<const Set> solve_solution = solve(ddplanitary_eqn, alphamd, reals);
754 map[alphamd] = solve_solution->get_args()[1]->get_args()[0];
755 VLOG(1) << "temp: " << solve_solution->__str__();
756 RCP<const Basic> drive_eqn_subs = drive_eqn->subs(map);
757
758 map.clear();
759 map[result.alphas] = accel_steer_eqn;
760 RCP<const Basic> drive_eqn_subs2 = drive_eqn_subs->subs(map);
761 RCP<const Basic> drive_eqn_subs3 = simplify(drive_eqn_subs2);
762 VLOG(1) << "drive_eqn simplified: " << drive_eqn_subs3->__str__();
763
764 solve_solution = solve(drive_eqn_subs3, result.alphad, reals);
765
766 RCP<const Basic> drive_accel =
767 simplify(solve_solution->get_args()[1]->get_args()[0]);
768 VLOG(1) << "drive_accel: " << drive_accel->__str__();
769
770 DenseMatrix mat_output = DenseMatrix(2, 1);
771 mul_dense_dense(R(add(theta_, result.thetas)),
Austin Schuhb67a38f2024-07-04 13:48:38 -0700772 DenseMatrix(2, 1, {result.Fwx, result.Fwy}), mat_output);
justinT21446e4f62024-06-16 22:36:10 -0700773
774 // Comput the resulting force from the module.
775 DenseMatrix F = mat_output;
776
777 RCP<const Basic> torque = simplify(force_cross(mounting_location, F));
778 result.accel = DenseMatrix(2, 1);
779 mul_dense_scalar(F, pow(m_, minus_one), result.accel);
780 result.angular_accel = div(torque, J_);
781 VLOG(1);
782 VLOG(1) << "angular_accel = " << result.angular_accel->__str__();
783
784 VLOG(1);
785 VLOG(1) << "accel(0, 0) = " << result.accel.get(0, 0)->__str__();
786 VLOG(1);
787 VLOG(1) << "accel(1, 0) = " << result.accel.get(1, 0)->__str__();
788
789 result.alphad_eqn = drive_accel;
790 result.alphas_eqn = accel_steer_eqn;
791 return result;
792 }
793
794 DenseMatrix R(const RCP<const Basic> theta) {
795 return DenseMatrix(2, 2,
796 {cos(theta), neg(sin(theta)), sin(theta), cos(theta)});
797 }
798
799 DenseMatrix angle_cross(DenseMatrix a, RCP<const Basic> b) {
800 return DenseMatrix(2, 1, {mul(a.get(1, 0), b), mul(neg(a.get(0, 0)), b)});
801 }
802
803 RCP<const Basic> force_cross(DenseMatrix r, DenseMatrix f) {
804 return sub(mul(r.get(0, 0), f.get(1, 0)), mul(r.get(1, 0), f.get(0, 0)));
805 }
806
// Computes the pitch radii of a bevel gear pair.
//
// z1, z2: number of teeth on gears 1 and 2.
// theta: angle between the shafts, in degrees.
// D_02: pitch diameter of gear 2.
// b_2: length of the tooth of gear 2.
// Returns std::pair(r_01, r_02).
std::pair<double, double> GetBevelPitchRadius(double z1, double z2,
                                              double theta, double D_02,
                                              double b_2) {
  // std::atan2(0.0, -1.0) is exactly pi under IEC 60559 (C Annex F).
  const double shaft_angle = theta / 180.0 * std::atan2(0.0, -1.0);
  const double pitch_angle_1 = std::atan2(z1, z2);
  const double pitch_angle_2 = shaft_angle - pitch_angle_1;
  const double r_m = D_02 / 2 / std::sin(pitch_angle_2) - b_2 / 2;
  return {r_m * std::cos(pitch_angle_2), r_m * std::sin(pitch_angle_2)};
}
819
820 Motor drive_motor_;
821 Motor steer_motor_;
822
823 RCP<const Basic> Cx_;
824 RCP<const Basic> Cy_;
825 RCP<const Basic> r_w_;
826 RCP<const Basic> m_;
827 RCP<const Basic> J_;
828 RCP<const Basic> Gd1_;
829 RCP<const Basic> rs_;
830 RCP<const Basic> rp_;
831 RCP<const Basic> Gd2_;
832 RCP<const Basic> rb1_;
833 RCP<const Basic> rb2_;
834 RCP<const Basic> Gd3_;
835 RCP<const Basic> Gd_;
836 RCP<const Basic> Js_;
837 RCP<const Basic> Gs_;
838 RCP<const Basic> wb_;
839 RCP<const Basic> Jdm_;
840 RCP<const Basic> Jsm_;
841 RCP<const Basic> Kts_;
842 RCP<const Basic> Ktd_;
843 RCP<const Basic> robot_width_;
844 RCP<const Basic> caster_;
845 RCP<const Basic> contact_patch_length_;
846 RCP<const Basic> x_;
847 RCP<const Basic> y_;
848 RCP<const Basic> theta_;
849 RCP<const Basic> vx_;
850 RCP<const Basic> vy_;
851 RCP<const Basic> omega_;
852 RCP<const Basic> ax_;
853 RCP<const Basic> ay_;
854 RCP<const Basic> atheta_;
855
856 std::array<Module, kNumModules> modules_;
857
858 DenseMatrix accel_;
859 RCP<const Basic> angular_accel_;
860};
861
862} // namespace frc971::control_loops::swerve
863
864int main(int argc, char **argv) {
865 aos::InitGoogle(&argc, &argv);
866
867 frc971::control_loops::swerve::SwerveSimulation sim;
868
Austin Schuh0f881092024-06-28 15:36:48 -0700869 if (!FLAGS_cc_output_path.empty() && !FLAGS_h_output_path.empty()) {
justinT21446e4f62024-06-16 22:36:10 -0700870 sim.Write(FLAGS_cc_output_path, FLAGS_h_output_path);
Austin Schuh0f881092024-06-28 15:36:48 -0700871 }
872 if (!FLAGS_py_output_path.empty()) {
justinT21942892b2024-07-02 22:33:50 -0700873 sim.WritePy(FLAGS_py_output_path);
justinT21446e4f62024-06-16 22:36:10 -0700874 }
Austin Schuh0f881092024-06-28 15:36:48 -0700875 if (!FLAGS_casadi_py_output_path.empty()) {
876 sim.WriteCasadi(FLAGS_casadi_py_output_path);
877 }
justinT21446e4f62024-06-16 22:36:10 -0700878
879 return 0;
880}