blob: 6f23c47d59a2af57c3d0da161dcaf93922fdbac2 [file] [log] [blame]
#include <symengine/add.h>
#include <symengine/matrix.h>
#include <symengine/number.h>
#include <symengine/printers.h>
#include <symengine/real_double.h>
#include <symengine/simplify.h>
#include <symengine/solve.h>
#include <symengine/symbol.h>

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numbers>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/strip.h"
#include "absl/strings/substitute.h"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "aos/init.h"
#include "aos/util/file.h"
#include "frc971/control_loops/swerve/motors.h"

26DEFINE_string(output_base, "",
27 "Path to strip off the front of the output paths.");
28DEFINE_string(cc_output_path, "", "Path to write generated header code to");
29DEFINE_string(h_output_path, "", "Path to write generated cc code to");
justinT21942892b2024-07-02 22:33:50 -070030DEFINE_string(py_output_path, "", "Path to write generated py code to");
Austin Schuh0f881092024-06-28 15:36:48 -070031DEFINE_string(casadi_py_output_path, "",
32 "Path to write casadi generated py code to");
justinT21446e4f62024-06-16 22:36:10 -070033
34DEFINE_bool(symbolic, false, "If true, write everything out symbolically.");
35
justinT21942892b2024-07-02 22:33:50 -070036using SymEngine::abs;
justinT21446e4f62024-06-16 22:36:10 -070037using SymEngine::add;
38using SymEngine::atan2;
39using SymEngine::Basic;
40using SymEngine::ccode;
41using SymEngine::cos;
42using SymEngine::DenseMatrix;
43using SymEngine::div;
44using SymEngine::Inf;
45using SymEngine::integer;
46using SymEngine::map_basic_basic;
47using SymEngine::minus_one;
48using SymEngine::neg;
49using SymEngine::NegInf;
50using SymEngine::pow;
51using SymEngine::RCP;
52using SymEngine::real_double;
53using SymEngine::RealDouble;
54using SymEngine::Set;
55using SymEngine::simplify;
56using SymEngine::sin;
57using SymEngine::solve;
58using SymEngine::symbol;
59using SymEngine::Symbol;
60
61namespace frc971::control_loops::swerve {
62
63// State per module.
64struct Module {
65 RCP<const Symbol> Is;
66
67 RCP<const Symbol> Id;
68
69 RCP<const Symbol> thetas;
70 RCP<const Symbol> omegas;
71 RCP<const Symbol> alphas;
72 RCP<const Basic> alphas_eqn;
73
74 RCP<const Symbol> thetad;
75 RCP<const Symbol> omegad;
76 RCP<const Symbol> alphad;
77 RCP<const Basic> alphad_eqn;
78
79 // Acceleration contribution from this module.
80 DenseMatrix accel;
81 RCP<const Basic> angular_accel;
82};
83
84class SwerveSimulation {
85 public:
86 SwerveSimulation() : drive_motor_(KrakenFOC()), steer_motor_(KrakenFOC()) {
87 auto fx = symbol("fx");
88 auto fy = symbol("fy");
89 auto moment = symbol("moment");
90
91 if (FLAGS_symbolic) {
92 Cx_ = symbol("Cx");
93 Cy_ = symbol("Cy");
94
95 r_w_ = symbol("r_w_");
96
97 m_ = symbol("m");
98 J_ = symbol("J");
99
100 Gd1_ = symbol("Gd1");
101 rs_ = symbol("rs");
102 rp_ = symbol("rp");
103 Gd2_ = symbol("Gd2");
104
105 rb1_ = symbol("rb1");
106 rb2_ = symbol("rb2");
107
108 Gd2_ = symbol("Gd3");
109 Gd_ = symbol("Gd");
110
111 Js_ = symbol("Js");
112
113 Gs_ = symbol("Gs");
114 wb_ = symbol("wb");
115
116 Jdm_ = symbol("Jdm");
117 Jsm_ = symbol("Jsm");
118 Kts_ = symbol("Kts");
119 Ktd_ = symbol("Ktd");
120
121 robot_width_ = symbol("robot_width");
122
123 caster_ = symbol("caster");
124 contact_patch_length_ = symbol("Lcp");
125 } else {
126 Cx_ = real_double(5 * 9.8 / 0.05 / 4.0);
127 Cy_ = real_double(5 * 9.8 / 0.05 / 4.0);
128
129 r_w_ = real_double(2 * 0.0254);
130
131 m_ = real_double(25.0); // base is 20 kg without battery
132 J_ = real_double(6.0);
133
134 Gd1_ = real_double(12.0 / 42.0);
135 rs_ = real_double(28.0 / 20.0 / 2.0);
136 rp_ = real_double(18.0 / 20.0 / 2.0);
137 Gd2_ = div(rs_, rp_);
138
139 // 15 / 45 bevel ratio, calculated using python script ported over to
140 // GetBevelPitchRadius(double
141 // TODO(Justin): Use the function instead of computed constantss
142 rb1_ = real_double(0.3805473);
143 rb2_ = real_double(1.14164);
144
145 Gd3_ = div(rb1_, rb2_);
146 Gd_ = mul(mul(Gd1_, Gd2_), Gd3_);
147
148 Js_ = real_double(0.1);
149
150 Gs_ = real_double(35.0 / 468.0);
151 wb_ = real_double(0.725);
152
153 Jdm_ = real_double(drive_motor_.motor_inertia);
154 Jsm_ = real_double(steer_motor_.motor_inertia);
155 Kts_ = real_double(steer_motor_.Kt);
156 Ktd_ = real_double(drive_motor_.Kt);
157
158 robot_width_ = real_double(24.75 * 0.0254);
159
160 caster_ = real_double(0.01);
161 contact_patch_length_ = real_double(0.02);
162 }
163
164 x_ = symbol("x");
165 y_ = symbol("y");
166 theta_ = symbol("theta");
167
168 vx_ = symbol("vx");
169 vy_ = symbol("vy");
170 omega_ = symbol("omega");
171
172 ax_ = symbol("ax");
173 ay_ = symbol("ay");
174 atheta_ = symbol("atheta");
175
176 // Now, compute the accelerations due to the disturbance forces.
177 angular_accel_ = div(moment, J_);
178 DenseMatrix external_accel = DenseMatrix(2, 1, {div(fx, m_), div(fy, m_)});
179
180 // And compute the physics contributions from each module.
181 modules_[0] = ModulePhysics(
182 0, DenseMatrix(
183 2, 1,
184 {div(robot_width_, integer(2)), div(robot_width_, integer(2))}));
185 modules_[1] =
186 ModulePhysics(1, DenseMatrix(2, 1,
187 {div(robot_width_, integer(-2)),
188 div(robot_width_, integer(2))}));
189 modules_[2] =
190 ModulePhysics(2, DenseMatrix(2, 1,
191 {div(robot_width_, integer(-2)),
192 div(robot_width_, integer(-2))}));
193 modules_[3] =
194 ModulePhysics(3, DenseMatrix(2, 1,
195 {div(robot_width_, integer(2)),
196 div(robot_width_, integer(-2))}));
197
198 // And convert them into the overall robot contribution.
199 DenseMatrix temp0 = DenseMatrix(2, 1);
200 DenseMatrix temp1 = DenseMatrix(2, 1);
201 DenseMatrix temp2 = DenseMatrix(2, 1);
202 accel_ = DenseMatrix(2, 1);
203
204 add_dense_dense(modules_[0].accel, external_accel, temp0);
205 add_dense_dense(temp0, modules_[1].accel, temp1);
206 add_dense_dense(temp1, modules_[2].accel, temp2);
207 add_dense_dense(temp2, modules_[3].accel, accel_);
208
209 angular_accel_ = add(angular_accel_, modules_[0].angular_accel);
210 angular_accel_ = add(angular_accel_, modules_[1].angular_accel);
211 angular_accel_ = add(angular_accel_, modules_[2].angular_accel);
212 angular_accel_ = simplify(add(angular_accel_, modules_[3].angular_accel));
213
214 VLOG(1) << "accel(0, 0) = " << ccode(*accel_.get(0, 0));
215 VLOG(1) << "accel(1, 0) = " << ccode(*accel_.get(1, 0));
216 VLOG(1) << "angular_accel = " << ccode(*angular_accel_);
217 }
218
justinT21942892b2024-07-02 22:33:50 -0700219 // Writes the physics out to the provided .py path.
220 void WritePy(std::string_view py_path) {
221 std::vector<std::string> result_py;
222
223 result_py.emplace_back("#!/usr/bin/python3");
224 result_py.emplace_back("");
225 result_py.emplace_back("import numpy");
justinT21942892b2024-07-02 22:33:50 -0700226 result_py.emplace_back("");
227
justinT21942892b2024-07-02 22:33:50 -0700228 result_py.emplace_back("def swerve_physics(t, X, U_func):");
Austin Schuh0f881092024-06-28 15:36:48 -0700229 result_py.emplace_back(" def atan2(y, x):");
230 result_py.emplace_back(" if x < 0:");
231 result_py.emplace_back(" return -numpy.atan2(y, x)");
232 result_py.emplace_back(" else:");
233 result_py.emplace_back(" return numpy.atan2(y, x)");
234 result_py.emplace_back(" sin = numpy.sin");
235 result_py.emplace_back(" cos = numpy.cos");
236 result_py.emplace_back(" fabs = numpy.fabs");
237
justinT21942892b2024-07-02 22:33:50 -0700238 result_py.emplace_back(" result = numpy.empty([25, 1])");
239 result_py.emplace_back(" X = X.reshape(25, 1)");
240 result_py.emplace_back(" U = U_func(X)");
241 result_py.emplace_back("");
242
243 // Start by writing out variables matching each of the symbol names we use
244 // so we don't have to modify the computed equations too much.
245 for (size_t m = 0; m < kNumModules; ++m) {
246 result_py.emplace_back(
247 absl::Substitute(" thetas$0 = X[$1, 0]", m, m * 4));
248 result_py.emplace_back(
249 absl::Substitute(" omegas$0 = X[$1, 0]", m, m * 4 + 2));
250 result_py.emplace_back(
251 absl::Substitute(" omegad$0 = X[$1, 0]", m, m * 4 + 3));
252 }
253
254 result_py.emplace_back(
255 absl::Substitute(" theta = X[$0, 0]", kNumModules * 4 + 2));
256 result_py.emplace_back(
257 absl::Substitute(" vx = X[$0, 0]", kNumModules * 4 + 3));
258 result_py.emplace_back(
259 absl::Substitute(" vy = X[$0, 0]", kNumModules * 4 + 4));
260 result_py.emplace_back(
261 absl::Substitute(" omega = X[$0, 0]", kNumModules * 4 + 5));
262
263 result_py.emplace_back(
264 absl::Substitute(" fx = X[$0, 0]", kNumModules * 4 + 6));
265 result_py.emplace_back(
266 absl::Substitute(" fy = X[$0, 0]", kNumModules * 4 + 7));
267 result_py.emplace_back(
268 absl::Substitute(" moment = X[$0, 0]", kNumModules * 4 + 8));
269
270 // Now do the same for the inputs.
271 for (size_t m = 0; m < kNumModules; ++m) {
272 result_py.emplace_back(absl::Substitute(" Is$0 = U[$1, 0]", m, m * 2));
273 result_py.emplace_back(
274 absl::Substitute(" Id$0 = U[$1, 0]", m, m * 2 + 1));
275 }
276
277 result_py.emplace_back("");
278
279 // And then write out the derivative of each state.
280 for (size_t m = 0; m < kNumModules; ++m) {
281 result_py.emplace_back(
282 absl::Substitute(" result[$0, 0] = omegas$1", m * 4, m));
283 result_py.emplace_back(
284 absl::Substitute(" result[$0, 0] = omegad$1", m * 4 + 1, m));
285
286 result_py.emplace_back(absl::Substitute(
287 " result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
288 result_py.emplace_back(absl::Substitute(
289 " result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
290 }
291
292 result_py.emplace_back(
293 absl::Substitute(" result[$0, 0] = vx", kNumModules * 4));
294 result_py.emplace_back(
295 absl::Substitute(" result[$0, 0] = vy", kNumModules * 4 + 1));
296 result_py.emplace_back(
297 absl::Substitute(" result[$0, 0] = omega", kNumModules * 4 + 2));
298
299 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
300 kNumModules * 4 + 3,
301 ccode(*accel_.get(0, 0))));
302 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
303 kNumModules * 4 + 4,
304 ccode(*accel_.get(1, 0))));
305 result_py.emplace_back(absl::Substitute(
306 " result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
307
308 result_py.emplace_back(
309 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 6));
310 result_py.emplace_back(
311 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 7));
312 result_py.emplace_back(
313 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 8));
314
315 result_py.emplace_back("");
316 result_py.emplace_back(" return result.reshape(25,)\n");
317
318 aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
319 }
320
justinT21446e4f62024-06-16 22:36:10 -0700321 // Writes the physics out to the provided .cc and .h path.
322 void Write(std::string_view cc_path, std::string_view h_path) {
323 std::vector<std::string> result_cc;
324 std::vector<std::string> result_h;
325
Austin Schuh0f881092024-06-28 15:36:48 -0700326 std::string_view include_guard_stripped = h_path;
justinT21446e4f62024-06-16 22:36:10 -0700327 CHECK(absl::ConsumePrefix(&include_guard_stripped, FLAGS_output_base));
328 std::string include_guard =
329 absl::StrReplaceAll(absl::AsciiStrToUpper(include_guard_stripped),
330 {{"/", "_"}, {".", "_"}});
331
332 // Write out the header.
333 result_h.emplace_back(absl::Substitute("#ifndef $0_", include_guard));
334 result_h.emplace_back(absl::Substitute("#define $0_", include_guard));
335 result_h.emplace_back("");
336 result_h.emplace_back("#include <Eigen/Dense>");
337 result_h.emplace_back("");
338 result_h.emplace_back("namespace frc971::control_loops::swerve {");
339 result_h.emplace_back("");
340 result_h.emplace_back("// Returns the derivative of our state vector");
341 result_h.emplace_back("// [thetas0, thetad0, omegas0, omegad0,");
342 result_h.emplace_back("// thetas1, thetad1, omegas1, omegad1,");
343 result_h.emplace_back("// thetas2, thetad2, omegas2, omegad2,");
344 result_h.emplace_back("// thetas3, thetad3, omegas3, omegad3,");
345 result_h.emplace_back("// x, y, theta, vx, vy, omega,");
346 result_h.emplace_back("// Fx, Fy, Moment]");
347 result_h.emplace_back("Eigen::Matrix<double, 25, 1> SwervePhysics(");
348 result_h.emplace_back(
349 " Eigen::Map<const Eigen::Matrix<double, 25, 1>> X,");
350 result_h.emplace_back(
351 " Eigen::Map<const Eigen::Matrix<double, 8, 1>> U);");
352 result_h.emplace_back("");
353 result_h.emplace_back("} // namespace frc971::control_loops::swerve");
354 result_h.emplace_back("");
355 result_h.emplace_back(absl::Substitute("#endif // $0_", include_guard));
356
357 // Write out the .cc
358 result_cc.emplace_back(
359 absl::Substitute("#include \"$0\"", include_guard_stripped));
360 result_cc.emplace_back("");
361 result_cc.emplace_back("#include <cmath>");
362 result_cc.emplace_back("");
363 result_cc.emplace_back("namespace frc971::control_loops::swerve {");
364 result_cc.emplace_back("");
365 result_cc.emplace_back("Eigen::Matrix<double, 25, 1> SwervePhysics(");
366 result_cc.emplace_back(
367 " Eigen::Map<const Eigen::Matrix<double, 25, 1>> X,");
368 result_cc.emplace_back(
369 " Eigen::Map<const Eigen::Matrix<double, 8, 1>> U) {");
370 result_cc.emplace_back(" Eigen::Matrix<double, 25, 1> result;");
371
372 // Start by writing out variables matching each of the symbol names we use
373 // so we don't have to modify the computed equations too much.
374 for (size_t m = 0; m < kNumModules; ++m) {
375 result_cc.emplace_back(
376 absl::Substitute(" const double thetas$0 = X($1, 0);", m, m * 4));
377 result_cc.emplace_back(absl::Substitute(
378 " const double omegas$0 = X($1, 0);", m, m * 4 + 2));
379 result_cc.emplace_back(absl::Substitute(
380 " const double omegad$0 = X($1, 0);", m, m * 4 + 3));
381 }
382
383 result_cc.emplace_back(absl::Substitute(" const double theta = X($0, 0);",
384 kNumModules * 4 + 2));
385 result_cc.emplace_back(
386 absl::Substitute(" const double vx = X($0, 0);", kNumModules * 4 + 3));
387 result_cc.emplace_back(
388 absl::Substitute(" const double vy = X($0, 0);", kNumModules * 4 + 4));
389 result_cc.emplace_back(absl::Substitute(" const double omega = X($0, 0);",
390 kNumModules * 4 + 5));
391
392 result_cc.emplace_back(
393 absl::Substitute(" const double fx = X($0, 0);", kNumModules * 4 + 6));
394 result_cc.emplace_back(
395 absl::Substitute(" const double fy = X($0, 0);", kNumModules * 4 + 7));
396 result_cc.emplace_back(absl::Substitute(" const double moment = X($0, 0);",
397 kNumModules * 4 + 8));
398
399 // Now do the same for the inputs.
400 for (size_t m = 0; m < kNumModules; ++m) {
401 result_cc.emplace_back(
402 absl::Substitute(" const double Is$0 = U($1, 0);", m, m * 2));
403 result_cc.emplace_back(
404 absl::Substitute(" const double Id$0 = U($1, 0);", m, m * 2 + 1));
405 }
406
407 result_cc.emplace_back("");
408
409 // And then write out the derivative of each state.
410 for (size_t m = 0; m < kNumModules; ++m) {
411 result_cc.emplace_back(
412 absl::Substitute(" result($0, 0) = omegas$1;", m * 4, m));
413 result_cc.emplace_back(
414 absl::Substitute(" result($0, 0) = omegad$1;", m * 4 + 1, m));
415
416 result_cc.emplace_back(absl::Substitute(
417 " result($0, 0) = $1;", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
418 result_cc.emplace_back(absl::Substitute(
419 " result($0, 0) = $1;", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
420 }
421
422 result_cc.emplace_back(
423 absl::Substitute(" result($0, 0) = omega;", kNumModules * 4));
424 result_cc.emplace_back(
425 absl::Substitute(" result($0, 0) = vx;", kNumModules * 4 + 1));
426 result_cc.emplace_back(
427 absl::Substitute(" result($0, 0) = vy;", kNumModules * 4 + 2));
428
429 result_cc.emplace_back(absl::Substitute(
430 " result($0, 0) = $1;", kNumModules * 4 + 3, ccode(*angular_accel_)));
431 result_cc.emplace_back(absl::Substitute(" result($0, 0) = $1;",
432 kNumModules * 4 + 4,
433 ccode(*accel_.get(0, 0))));
434 result_cc.emplace_back(absl::Substitute(" result($0, 0) = $1;",
435 kNumModules * 4 + 5,
436 ccode(*accel_.get(1, 0))));
437
438 result_cc.emplace_back(
439 absl::Substitute(" result($0, 0) = 0.0;", kNumModules * 4 + 6));
440 result_cc.emplace_back(
441 absl::Substitute(" result($0, 0) = 0.0;", kNumModules * 4 + 7));
442 result_cc.emplace_back(
443 absl::Substitute(" result($0, 0) = 0.0;", kNumModules * 4 + 8));
444
445 result_cc.emplace_back("");
446 result_cc.emplace_back(" return result;");
447 result_cc.emplace_back("}");
448 result_cc.emplace_back("");
449 result_cc.emplace_back("} // namespace frc971::control_loops::swerve");
450
451 aos::util::WriteStringToFileOrDie(cc_path, absl::StrJoin(result_cc, "\n"));
452 aos::util::WriteStringToFileOrDie(h_path, absl::StrJoin(result_h, "\n"));
453 }
454
Austin Schuh0f881092024-06-28 15:36:48 -0700455 // Writes the physics out to the provided .cc and .h path.
456 void WriteCasadi(std::string_view py_path) {
457 std::vector<std::string> result_py;
458
459 // Write out the header.
460 result_py.emplace_back("#!/usr/bin/python3");
461 result_py.emplace_back("");
462 result_py.emplace_back("import casadi");
463 result_py.emplace_back("");
464 result_py.emplace_back("# Returns the derivative of our state vector");
465 result_py.emplace_back("# Returns the derivative of our state vector");
466 result_py.emplace_back("# [thetas0, thetad0, omegas0, omegad0,");
467 result_py.emplace_back("# thetas1, thetad1, omegas1, omegad1,");
468 result_py.emplace_back("# thetas2, thetad2, omegas2, omegad2,");
469 result_py.emplace_back("# thetas3, thetad3, omegas3, omegad3,");
470 result_py.emplace_back("# x, y, theta, vx, vy, omega,");
471 result_py.emplace_back("# Fx, Fy, Moment]");
472 result_py.emplace_back("def swerve_physics(X, U):");
473 result_py.emplace_back(" sin = casadi.sin");
474 result_py.emplace_back(" cos = casadi.cos");
475 result_py.emplace_back(" atan2 = casadi.atan2");
476
477 // Start by writing out variables matching each of the symbol names we use
478 // so we don't have to modify the computed equations too much.
479 for (size_t m = 0; m < kNumModules; ++m) {
480 result_py.emplace_back(
481 absl::Substitute(" thetas$0 = X[$1, 0]", m, m * 4));
482 result_py.emplace_back(
483 absl::Substitute(" omegas$0 = X[$1, 0]", m, m * 4 + 2));
484 result_py.emplace_back(
485 absl::Substitute(" omegad$0 = X[$1, 0]", m, m * 4 + 3));
486 }
487
488 result_py.emplace_back(
489 absl::Substitute(" theta = X[$0, 0]", kNumModules * 4 + 2));
490 result_py.emplace_back(
491 absl::Substitute(" vx = X[$0, 0]", kNumModules * 4 + 3));
492 result_py.emplace_back(
493 absl::Substitute(" vy = X[$0, 0]", kNumModules * 4 + 4));
494 result_py.emplace_back(
495 absl::Substitute(" omega = X[$0, 0]", kNumModules * 4 + 5));
496
497 result_py.emplace_back(
498 absl::Substitute(" fx = X[$0, 0]", kNumModules * 4 + 6));
499 result_py.emplace_back(
500 absl::Substitute(" fy = X[$0, 0]", kNumModules * 4 + 7));
501 result_py.emplace_back(
502 absl::Substitute(" moment = X[$0, 0]", kNumModules * 4 + 8));
503
504 // Now do the same for the inputs.
505 for (size_t m = 0; m < kNumModules; ++m) {
506 result_py.emplace_back(absl::Substitute(" Is$0 = U[$1, 0]", m, m * 2));
507 result_py.emplace_back(
508 absl::Substitute(" Id$0 = U[$1, 0]", m, m * 2 + 1));
509 }
510
511 result_py.emplace_back("");
512 result_py.emplace_back(" result = casadi.SX.sym('result', 25, 1)");
513 result_py.emplace_back("");
514
515 // And then write out the derivative of each state.
516 for (size_t m = 0; m < kNumModules; ++m) {
517 result_py.emplace_back(
518 absl::Substitute(" result[$0, 0] = omegas$1", m * 4, m));
519 result_py.emplace_back(
520 absl::Substitute(" result[$0, 0] = omegad$1", m * 4 + 1, m));
521
522 result_py.emplace_back(absl::Substitute(
523 " result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
524 result_py.emplace_back(absl::Substitute(
525 " result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
526 }
527
528 result_py.emplace_back(
529 absl::Substitute(" result[$0, 0] = omega", kNumModules * 4));
530 result_py.emplace_back(
531 absl::Substitute(" result[$0, 0] = vx", kNumModules * 4 + 1));
532 result_py.emplace_back(
533 absl::Substitute(" result[$0, 0] = vy", kNumModules * 4 + 2));
534
535 result_py.emplace_back(absl::Substitute(
536 " result[$0, 0] = $1", kNumModules * 4 + 3, ccode(*angular_accel_)));
537 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
538 kNumModules * 4 + 4,
539 ccode(*accel_.get(0, 0))));
540 result_py.emplace_back(absl::Substitute(" result[$0, 0] = $1",
541 kNumModules * 4 + 5,
542 ccode(*accel_.get(1, 0))));
543
544 result_py.emplace_back(
545 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 6));
546 result_py.emplace_back(
547 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 7));
548 result_py.emplace_back(
549 absl::Substitute(" result[$0, 0] = 0.0", kNumModules * 4 + 8));
550
551 result_py.emplace_back("");
552 result_py.emplace_back(
553 " return casadi.Function('xdot', [X, U], [result])");
554
555 aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
556 }
557
justinT21446e4f62024-06-16 22:36:10 -0700558 private:
559 static constexpr uint8_t kNumModules = 4;
560
561 Module ModulePhysics(const int m, DenseMatrix mounting_location) {
562 VLOG(1) << "Solving module " << m;
563
564 Module result;
565
566 result.Is = symbol(absl::StrFormat("Is%u", m));
567 result.Id = symbol(absl::StrFormat("Id%u", m));
568
569 RCP<const Symbol> thetamd = symbol(absl::StrFormat("theta_md%u", m));
570 RCP<const Symbol> omegamd = symbol(absl::StrFormat("omega_md%u", m));
571 RCP<const Symbol> alphamd = symbol(absl::StrFormat("alpha_md%u", m));
572
573 result.thetas = symbol(absl::StrFormat("thetas%u", m));
574 result.omegas = symbol(absl::StrFormat("omegas%u", m));
575 result.alphas = symbol(absl::StrFormat("alphas%u", m));
576
577 result.thetad = symbol(absl::StrFormat("thetad%u", m));
578 result.omegad = symbol(absl::StrFormat("omegad%u", m));
579 result.alphad = symbol(absl::StrFormat("alphad%u", m));
580
581 // Velocity of the module in field coordinates
justinT21942892b2024-07-02 22:33:50 -0700582 DenseMatrix robot_velocity = DenseMatrix(2, 1);
583 mul_dense_dense(R(theta_), DenseMatrix(2, 1, {vx_, vy_}), robot_velocity);
justinT21446e4f62024-06-16 22:36:10 -0700584 VLOG(1) << "robot velocity: " << robot_velocity.__str__();
585
586 // Velocity of the contact patch in field coordinates
587 DenseMatrix temp_matrix = DenseMatrix(2, 1);
588 DenseMatrix temp_matrix2 = DenseMatrix(2, 1);
589 DenseMatrix contact_patch_velocity = DenseMatrix(2, 1);
590
591 mul_dense_dense(R(theta_), mounting_location, temp_matrix);
592 add_dense_dense(angle_cross(temp_matrix, omega_), robot_velocity,
593 temp_matrix2);
594 mul_dense_dense(R(add(theta_, result.thetas)),
595 DenseMatrix(2, 1, {caster_, integer(0)}), temp_matrix);
596 add_dense_dense(temp_matrix2,
597 angle_cross(temp_matrix, add(omega_, result.omegas)),
598 contact_patch_velocity);
599
600 VLOG(1);
601 VLOG(1) << "contact patch velocity: " << contact_patch_velocity.__str__();
602
603 // Relative velocity of the surface of the wheel to the ground.
604 DenseMatrix wheel_ground_velocity = DenseMatrix(2, 1);
605 mul_dense_dense(R(neg(add(result.thetas, theta_))), contact_patch_velocity,
606 wheel_ground_velocity);
607
608 VLOG(1);
609 VLOG(1) << "wheel ground velocity: " << wheel_ground_velocity.__str__();
610
justinT21942892b2024-07-02 22:33:50 -0700611 RCP<const Basic> slip_angle = neg(atan2(wheel_ground_velocity.get(1, 0),
612 wheel_ground_velocity.get(0, 0)));
justinT21446e4f62024-06-16 22:36:10 -0700613
614 VLOG(1);
615 VLOG(1) << "slip angle: " << slip_angle->__str__();
616
617 RCP<const Basic> slip_ratio =
618 div(sub(mul(r_w_, result.omegad), wheel_ground_velocity.get(0, 0)),
justinT21942892b2024-07-02 22:33:50 -0700619 abs(wheel_ground_velocity.get(0, 0)));
justinT21446e4f62024-06-16 22:36:10 -0700620 VLOG(1);
621 VLOG(1) << "Slip ratio " << slip_ratio->__str__();
622
623 RCP<const Basic> Fwx = simplify(mul(Cx_, slip_ratio));
624 RCP<const Basic> Fwy = simplify(mul(Cy_, slip_angle));
625
626 RCP<const Basic> Ms =
627 mul(Fwy, add(div(contact_patch_length_, integer(3)), caster_));
628 VLOG(1);
629 VLOG(1) << "Ms " << Ms->__str__();
630 VLOG(1);
631 VLOG(1) << "Fwx " << Fwx->__str__();
632 VLOG(1);
633 VLOG(1) << "Fwy " << Fwy->__str__();
634
635 // alphas = ...
636 RCP<const Basic> lhms =
637 mul(add(neg(wb_), mul(add(rs_, rp_), sub(integer(1), div(rb1_, rp_)))),
justinT21942892b2024-07-02 22:33:50 -0700638 mul(div(r_w_, rb2_), neg(Fwx)));
justinT21446e4f62024-06-16 22:36:10 -0700639 RCP<const Basic> lhs = add(add(Ms, div(mul(Jsm_, result.Is), Gs_)), lhms);
640 RCP<const Basic> rhs = add(Jsm_, div(div(Js_, Gs_), Gs_));
641 RCP<const Basic> accel_steer_eqn = simplify(div(lhs, rhs));
642
643 VLOG(1);
644 VLOG(1) << result.alphas->__str__() << " = " << accel_steer_eqn->__str__();
645
646 lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.omegas),
647 mul(Gd1_, mul(Gd2_, omegamd)));
648 RCP<const Basic> dplanitary_eqn = sub(mul(Gd3_, lhs), result.omegad);
649
650 lhs = sub(mul(sub(div(add(rp_, rs_), rp_), integer(1)), result.alphas),
651 mul(Gd1_, mul(Gd2_, alphamd)));
652 RCP<const Basic> ddplanitary_eqn = sub(mul(Gd3_, lhs), result.alphad);
653
654 RCP<const Basic> drive_eqn = sub(
655 add(mul(neg(Jdm_), div(alphamd, Gd_)), mul(Ktd_, div(result.Id, Gd_))),
justinT21942892b2024-07-02 22:33:50 -0700656 mul(neg(Fwx), r_w_));
justinT21446e4f62024-06-16 22:36:10 -0700657
658 VLOG(1) << "drive_eqn: " << drive_eqn->__str__();
659
660 // Substitute in ddplanitary_eqn so we get rid of alphamd
661 map_basic_basic map;
662 RCP<const Set> reals = interval(NegInf, Inf, true, true);
663 RCP<const Set> solve_solution = solve(ddplanitary_eqn, alphamd, reals);
664 map[alphamd] = solve_solution->get_args()[1]->get_args()[0];
665 VLOG(1) << "temp: " << solve_solution->__str__();
666 RCP<const Basic> drive_eqn_subs = drive_eqn->subs(map);
667
668 map.clear();
669 map[result.alphas] = accel_steer_eqn;
670 RCP<const Basic> drive_eqn_subs2 = drive_eqn_subs->subs(map);
671 RCP<const Basic> drive_eqn_subs3 = simplify(drive_eqn_subs2);
672 VLOG(1) << "drive_eqn simplified: " << drive_eqn_subs3->__str__();
673
674 solve_solution = solve(drive_eqn_subs3, result.alphad, reals);
675
676 RCP<const Basic> drive_accel =
677 simplify(solve_solution->get_args()[1]->get_args()[0]);
678 VLOG(1) << "drive_accel: " << drive_accel->__str__();
679
680 DenseMatrix mat_output = DenseMatrix(2, 1);
681 mul_dense_dense(R(add(theta_, result.thetas)),
682 DenseMatrix(2, 1, {Fwx, Fwy}), mat_output);
683
684 // Comput the resulting force from the module.
685 DenseMatrix F = mat_output;
686
687 RCP<const Basic> torque = simplify(force_cross(mounting_location, F));
688 result.accel = DenseMatrix(2, 1);
689 mul_dense_scalar(F, pow(m_, minus_one), result.accel);
690 result.angular_accel = div(torque, J_);
691 VLOG(1);
692 VLOG(1) << "angular_accel = " << result.angular_accel->__str__();
693
694 VLOG(1);
695 VLOG(1) << "accel(0, 0) = " << result.accel.get(0, 0)->__str__();
696 VLOG(1);
697 VLOG(1) << "accel(1, 0) = " << result.accel.get(1, 0)->__str__();
698
699 result.alphad_eqn = drive_accel;
700 result.alphas_eqn = accel_steer_eqn;
701 return result;
702 }
703
704 DenseMatrix R(const RCP<const Basic> theta) {
705 return DenseMatrix(2, 2,
706 {cos(theta), neg(sin(theta)), sin(theta), cos(theta)});
707 }
708
709 DenseMatrix angle_cross(DenseMatrix a, RCP<const Basic> b) {
710 return DenseMatrix(2, 1, {mul(a.get(1, 0), b), mul(neg(a.get(0, 0)), b)});
711 }
712
713 RCP<const Basic> force_cross(DenseMatrix r, DenseMatrix f) {
714 return sub(mul(r.get(0, 0), f.get(1, 0)), mul(r.get(1, 0), f.get(0, 0)));
715 }
716
717 // z represents the number of teeth per gear, theta is the angle between
718 // shafts(in degrees), D_02 is the pitch diameter of gear 2 and b_2 is the
719 // length of the tooth of gear 2
720 // returns std::pair(r_01, r_02)
721 std::pair<double, double> GetBevelPitchRadius(double z1, double z2,
722 double theta, double D_02,
723 double b_2) {
724 double gamma_1 = std::atan2(z1, z2);
725 double gamma_2 = theta / 180.0 * std::numbers::pi - gamma_1;
726 double R_m = D_02 / 2 / std::sin(gamma_2) - b_2 / 2;
727 return std::pair(R_m * std::cos(gamma_2), R_m * std::sin(gamma_2));
728 }
729
730 Motor drive_motor_;
731 Motor steer_motor_;
732
733 RCP<const Basic> Cx_;
734 RCP<const Basic> Cy_;
735 RCP<const Basic> r_w_;
736 RCP<const Basic> m_;
737 RCP<const Basic> J_;
738 RCP<const Basic> Gd1_;
739 RCP<const Basic> rs_;
740 RCP<const Basic> rp_;
741 RCP<const Basic> Gd2_;
742 RCP<const Basic> rb1_;
743 RCP<const Basic> rb2_;
744 RCP<const Basic> Gd3_;
745 RCP<const Basic> Gd_;
746 RCP<const Basic> Js_;
747 RCP<const Basic> Gs_;
748 RCP<const Basic> wb_;
749 RCP<const Basic> Jdm_;
750 RCP<const Basic> Jsm_;
751 RCP<const Basic> Kts_;
752 RCP<const Basic> Ktd_;
753 RCP<const Basic> robot_width_;
754 RCP<const Basic> caster_;
755 RCP<const Basic> contact_patch_length_;
756 RCP<const Basic> x_;
757 RCP<const Basic> y_;
758 RCP<const Basic> theta_;
759 RCP<const Basic> vx_;
760 RCP<const Basic> vy_;
761 RCP<const Basic> omega_;
762 RCP<const Basic> ax_;
763 RCP<const Basic> ay_;
764 RCP<const Basic> atheta_;
765
766 std::array<Module, kNumModules> modules_;
767
768 DenseMatrix accel_;
769 RCP<const Basic> angular_accel_;
770};
771
772} // namespace frc971::control_loops::swerve
773
774int main(int argc, char **argv) {
775 aos::InitGoogle(&argc, &argv);
776
777 frc971::control_loops::swerve::SwerveSimulation sim;
778
Austin Schuh0f881092024-06-28 15:36:48 -0700779 if (!FLAGS_cc_output_path.empty() && !FLAGS_h_output_path.empty()) {
justinT21446e4f62024-06-16 22:36:10 -0700780 sim.Write(FLAGS_cc_output_path, FLAGS_h_output_path);
Austin Schuh0f881092024-06-28 15:36:48 -0700781 }
782 if (!FLAGS_py_output_path.empty()) {
justinT21942892b2024-07-02 22:33:50 -0700783 sim.WritePy(FLAGS_py_output_path);
justinT21446e4f62024-06-16 22:36:10 -0700784 }
Austin Schuh0f881092024-06-28 15:36:48 -0700785 if (!FLAGS_casadi_py_output_path.empty()) {
786 sim.WriteCasadi(FLAGS_casadi_py_output_path);
787 }
justinT21446e4f62024-06-16 22:36:10 -0700788
789 return 0;
790}