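Fix some cases of divergence in the swerve physics model

Rotate the robot velocity by theta before using it for the module
velocities, negate the slip angle and the wheel force Fwx in the steer
and drive equations, and divide the slip ratio by the absolute value of
the wheel's ground velocity.  Also generate a Python version of the
dynamics (dynamics.py) and add dynamics_sim.py, which integrates it with
scipy and plots every state.
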
Signed-off-by: justinT21 <jjturcot@gmail.com>
Change-Id: Ib6e322af67ef66567b528802f145ba263554af88
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 4fd73d2..052b287 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -129,11 +129,13 @@
     outs = [
         "dynamics.cc",
         "dynamics.h",
+        "dynamics.py",
     ],
     args = [
         "--output_base=$(BINDIR)/",
         "--cc_output_path=$(location :dynamics.cc)",
         "--h_output_path=$(location :dynamics.h)",
+        "--py_output_path=$(location :dynamics.py)",
     ],
     tool = ":generate_physics",
 )
@@ -146,3 +148,21 @@
         "@org_tuxfamily_eigen//:eigen",
     ],
 )
+
+# Integrates the generated dynamics.py with scipy and plots every state, to
+# help track down divergence in the physics model.
+py_binary(
+    name = "dynamics_sim",
+    srcs = [
+        "dynamics.py",
+        "dynamics_sim.py",
+    ],
+    main = "dynamics_sim.py",
+    deps = [
+        "//frc971/control_loops/python:controls",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//scipy",
+    ],
+)
diff --git a/frc971/control_loops/swerve/dynamics_sim.py b/frc971/control_loops/swerve/dynamics_sim.py
new file mode 100644
index 0000000..3c597a0
--- /dev/null
+++ b/frc971/control_loops/swerve/dynamics_sim.py
@@ -0,0 +1,80 @@
+#!/usr/bin/python3
+
+import numpy
+import math
+import scipy.integrate
+
+from matplotlib import pylab
+import sys
+import gflags
+import glog
+
+from frc971.control_loops.swerve.dynamics import swerve_physics
+
+FLAGS = gflags.FLAGS
+
+# Plot label for each entry of the state vector used by swerve_physics: 4
+# states per module for each of the 4 modules, then the robot-level states.
+STATE_LABELS = [
+    f"{name}{module}" for module in range(4)
+    for name in ("thetas", "thetad", "omegas", "omegad")
+] + ["x", "y", "theta", "vx", "vy", "omega", "Fx", "Fy", "Moment"]
+NUM_STATES = len(STATE_LABELS)
+
+
+def u_func(X):
+    """Returns the 8x1 input vector (Is, Id for each module) for state X.
+
+    All inputs are currently zero; uncomment the lines below to set Id.
+    """
+    result = numpy.zeros([8, 1])
+    # result[1, 0] = 80.0
+    # result[3, 0] = 80.0
+    # result[5, 0] = -80.0
+    # result[7, 0] = -80.0
+    return result
+
+
+def main(argv):
+    """Integrates swerve_physics for 2 seconds and plots every state."""
+    x_initial = numpy.zeros([NUM_STATES, 1])
+    # Spin each drive wheel (omegad) at 1 / (2 * 0.0254) rad/s, presumably
+    # 1 m/s at the edge of a 2 inch radius wheel, and start with vx = 2.
+    # x_initial[0] = -math.pi / 2.0
+    x_initial[3] = 1.0 / (2.0 * 0.0254)
+    # x_initial[4] = math.pi / 2.0
+    x_initial[7] = 1.0 / (2.0 * 0.0254)
+    # x_initial[8] = -math.pi / 2.0
+    x_initial[11] = 1.0 / (2.0 * 0.0254)
+    # x_initial[12] = math.pi / 2.0
+    x_initial[15] = 1.0 / (2.0 * 0.0254)
+    x_initial[19] = 2.0
+    result = scipy.integrate.solve_ivp(swerve_physics, (0, 2.0),
+                                       x_initial.reshape(NUM_STATES, ),
+                                       max_step=0.01,
+                                       args=(u_func, ))
+    if not result.success:
+        # Keep going and plot the partial trajectory; it shows where the
+        # integration went wrong.
+        print(f"solve_ivp failed: {result.message}", file=sys.stderr)
+
+    cm = pylab.get_cmap('gist_rainbow')
+    fig = pylab.figure()
+    ax = fig.add_subplot(111)
+    ax.set_prop_cycle(
+        color=[cm(1. * i / NUM_STATES) for i in range(NUM_STATES)])
+    for index, label in enumerate(STATE_LABELS):
+        ax.plot(result.t, result.y[index, :], label=label, linewidth=7.0)
+    numpy.set_printoptions(threshold=numpy.inf)
+    print(result.t)
+    print(result.y)
+    pylab.legend()
+    pylab.show()
+
+    return 0 if result.success else 1
+
+
+if __name__ == '__main__':
+    argv = FLAGS(sys.argv)
+    glog.init()
+    sys.exit(main(argv))
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index 52f82f0..d8c1321 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -27,9 +27,11 @@
               "Path to strip off the front of the output paths.");
-DEFINE_string(cc_output_path, "", "Path to write generated header code to");
-DEFINE_string(h_output_path, "", "Path to write generated cc code to");
+DEFINE_string(cc_output_path, "", "Path to write generated cc code to");
+DEFINE_string(h_output_path, "", "Path to write generated header code to");
+DEFINE_string(py_output_path, "", "Path to write generated py code to");
 
 DEFINE_bool(symbolic, false, "If true, write everything out symbolically.");
 
+using SymEngine::abs;
 using SymEngine::add;
 using SymEngine::atan2;
 using SymEngine::Basic;
@@ -212,6 +214,120 @@
     VLOG(1) << "angular_accel = " << ccode(*angular_accel_);
   }
 
+  // Writes the physics out to the provided .py path.
+  //
+  // The generated module defines swerve_physics(t, X, U_func), which returns
+  // dX/dt as a flat array.  The (t, X, *args) signature lets it be passed
+  // straight to scipy.integrate.solve_ivp with args=(U_func,).
+  void WritePy(std::string_view py_path) {
+    // State layout: 4 states per module (thetas, thetad, omegas, omegad),
+    // followed by x, y, theta, vx, vy, omega, fx, fy and moment.
+    const size_t num_states = kNumModules * 4 + 9;
+    std::vector<std::string> result_py;
+
+    result_py.emplace_back("#!/usr/bin/python3");
+    result_py.emplace_back("");
+    result_py.emplace_back("import numpy");
+    result_py.emplace_back("import math");
+    result_py.emplace_back("from math import sin, cos, fabs");
+    result_py.emplace_back("");
+
+    // NOTE(review): this wrapper negates math.atan2 whenever x < 0.  If the
+    // C++ emitted by Write() uses the standard atan2, dynamics.py and
+    // dynamics.cc disagree for wheels rolling backwards -- confirm intent.
+    result_py.emplace_back("def atan2(y, x):");
+    result_py.emplace_back("    if x < 0:");
+    result_py.emplace_back("        return -math.atan2(y, x)");
+    result_py.emplace_back("    else:");
+    result_py.emplace_back("        return math.atan2(y, x)");
+    result_py.emplace_back("");
+    result_py.emplace_back("");
+
+    result_py.emplace_back("def swerve_physics(t, X, U_func):");
+    result_py.emplace_back(
+        absl::Substitute("    result = numpy.empty([$0, 1])", num_states));
+    result_py.emplace_back(
+        absl::Substitute("    X = X.reshape($0, 1)", num_states));
+    result_py.emplace_back("    U = U_func(X)");
+    result_py.emplace_back("");
+
+    // Start by writing out variables matching each of the symbol names we use
+    // so we don't have to modify the computed equations too much.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute("    thetas$0 = X[$1, 0]", m, m * 4));
+      result_py.emplace_back(
+          absl::Substitute("    omegas$0 = X[$1, 0]", m, m * 4 + 2));
+      result_py.emplace_back(
+          absl::Substitute("    omegad$0 = X[$1, 0]", m, m * 4 + 3));
+    }
+
+    result_py.emplace_back(
+        absl::Substitute("    theta = X[$0, 0]", kNumModules * 4 + 2));
+    result_py.emplace_back(
+        absl::Substitute("    vx = X[$0, 0]", kNumModules * 4 + 3));
+    result_py.emplace_back(
+        absl::Substitute("    vy = X[$0, 0]", kNumModules * 4 + 4));
+    result_py.emplace_back(
+        absl::Substitute("    omega = X[$0, 0]", kNumModules * 4 + 5));
+
+    result_py.emplace_back(
+        absl::Substitute("    fx = X[$0, 0]", kNumModules * 4 + 6));
+    result_py.emplace_back(
+        absl::Substitute("    fy = X[$0, 0]", kNumModules * 4 + 7));
+    result_py.emplace_back(
+        absl::Substitute("    moment = X[$0, 0]", kNumModules * 4 + 8));
+
+    // Now do the same for the inputs.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(absl::Substitute("    Is$0 = U[$1, 0]", m, m * 2));
+      result_py.emplace_back(
+          absl::Substitute("    Id$0 = U[$1, 0]", m, m * 2 + 1));
+    }
+
+    result_py.emplace_back("");
+
+    // And then write out the derivative of each state.
+    for (size_t m = 0; m < kNumModules; ++m) {
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = omegas$1", m * 4, m));
+      result_py.emplace_back(
+          absl::Substitute("    result[$0, 0] = omegad$1", m * 4 + 1, m));
+
+      result_py.emplace_back(absl::Substitute(
+          "    result[$0, 0] = $1", m * 4 + 2, ccode(*modules_[m].alphas_eqn)));
+      result_py.emplace_back(absl::Substitute(
+          "    result[$0, 0] = $1", m * 4 + 3, ccode(*modules_[m].alphad_eqn)));
+    }
+
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = vx", kNumModules * 4));
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = vy", kNumModules * 4 + 1));
+    result_py.emplace_back(
+        absl::Substitute("    result[$0, 0] = omega", kNumModules * 4 + 2));
+
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 4 + 3,
+                                            ccode(*accel_.get(0, 0))));
+    result_py.emplace_back(absl::Substitute("    result[$0, 0] = $1",
+                                            kNumModules * 4 + 4,
+                                            ccode(*accel_.get(1, 0))));
+    result_py.emplace_back(absl::Substitute(
+        "    result[$0, 0] = $1", kNumModules * 4 + 5, ccode(*angular_accel_)));
+
+    // fx, fy and moment have zero derivative.
+    for (size_t i = kNumModules * 4 + 6; i < num_states; ++i) {
+      result_py.emplace_back(absl::Substitute("    result[$0, 0] = 0.0", i));
+    }
+
+    result_py.emplace_back("");
+    result_py.emplace_back(
+        absl::Substitute("    return result.reshape($0,)\n", num_states));
+
+    aos::util::WriteStringToFileOrDie(py_path, absl::StrJoin(result_py, "\n"));
+  }
+
   // Writes the physics out to the provided .cc and .h path.
   void Write(std::string_view cc_path, std::string_view h_path) {
     std::vector<std::string> result_cc;
@@ -370,7 +486,8 @@
     result.alphad = symbol(absl::StrFormat("alphad%u", m));
 
     // Velocity of the module in field coordinates
-    DenseMatrix robot_velocity = DenseMatrix(2, 1, {vx_, vy_});
+    DenseMatrix robot_velocity = DenseMatrix(2, 1);
+    mul_dense_dense(R(theta_), DenseMatrix(2, 1, {vx_, vy_}), robot_velocity);
     VLOG(1) << "robot velocity: " << robot_velocity.__str__();
 
     // Velocity of the contact patch in field coordinates
@@ -398,15 +515,15 @@
     VLOG(1);
     VLOG(1) << "wheel ground velocity: " << wheel_ground_velocity.__str__();
 
-    RCP<const Basic> slip_angle =
-        atan2(wheel_ground_velocity.get(1, 0), wheel_ground_velocity.get(0, 0));
+    RCP<const Basic> slip_angle = neg(atan2(wheel_ground_velocity.get(1, 0),
+                                            wheel_ground_velocity.get(0, 0)));
 
     VLOG(1);
     VLOG(1) << "slip angle: " << slip_angle->__str__();
 
     RCP<const Basic> slip_ratio =
         div(sub(mul(r_w_, result.omegad), wheel_ground_velocity.get(0, 0)),
-            wheel_ground_velocity.get(0, 0));
+            abs(wheel_ground_velocity.get(0, 0)));
     VLOG(1);
     VLOG(1) << "Slip ratio " << slip_ratio->__str__();
 
@@ -425,7 +542,7 @@
     // alphas = ...
     RCP<const Basic> lhms =
         mul(add(neg(wb_), mul(add(rs_, rp_), sub(integer(1), div(rb1_, rp_)))),
-            mul(div(r_w_, rb2_), Fwx));
+            mul(div(r_w_, rb2_), neg(Fwx)));
     RCP<const Basic> lhs = add(add(Ms, div(mul(Jsm_, result.Is), Gs_)), lhms);
     RCP<const Basic> rhs = add(Jsm_, div(div(Js_, Gs_), Gs_));
     RCP<const Basic> accel_steer_eqn = simplify(div(lhs, rhs));
@@ -443,7 +560,7 @@
 
     RCP<const Basic> drive_eqn = sub(
         add(mul(neg(Jdm_), div(alphamd, Gd_)), mul(Ktd_, div(result.Id, Gd_))),
-        mul(Fwx, r_w_));
+        mul(neg(Fwx), r_w_));
 
     VLOG(1) << "drive_eqn: " << drive_eqn->__str__();
 
@@ -566,8 +683,14 @@
 
   frc971::control_loops::swerve::SwerveSimulation sim;
 
   if (!FLAGS_cc_output_path.empty() && !FLAGS_h_output_path.empty()) {
     sim.Write(FLAGS_cc_output_path, FLAGS_h_output_path);
   }
+
+  // The Python model is optional, so callers that only want the C++ outputs
+  // don't have to pass --py_output_path.
+  if (!FLAGS_py_output_path.empty()) {
+    sim.WritePy(FLAGS_py_output_path);
+  }
 
   return 0;