Add test comparing C++ and Python dynamics codegen

This also adds C++ enums defining the states for the physics (note that
it needs to use regular C enums to not require a static_cast to
convert to integers).

It only passes in trivial scenarios currently....

Change-Id: I97e051c896500715c1252b17723f049981f99f1d
Signed-off-by: James Kuszmaul <jabukuszmaul+collab@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 26828e4..2d06837 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -214,6 +214,7 @@
     srcs = [
         "physics_test.py",
     ],
+    data = [":cpp_dynamics.so"],
     env = {
         "JAX_PLATFORMS": "cpu",
     },
@@ -234,6 +235,7 @@
     srcs = [
         "physics_test.py",
     ],
+    data = [":cpp_dynamics.so"],
     env = {
         "JAX_PLATFORMS": "cuda",
     },
@@ -329,3 +331,18 @@
         "@pip//pygobject",
     ],
 )
+
+cc_binary(
+    name = "cpp_dynamics.so",
+    # Just use the python dynamics directly if you want them; this is just for testing.
+    testonly = True,
+    srcs = ["dynamics_python_bindings.cc"],
+    linkshared = True,
+    target_compatible_with = ["@platforms//os:linux"],
+    deps = [
+        ":eigen_dynamics",
+        "//third_party/python",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+    ],
+)
diff --git a/frc971/control_loops/swerve/dynamics_python_bindings.cc b/frc971/control_loops/swerve/dynamics_python_bindings.cc
new file mode 100644
index 0000000..e1fbe19
--- /dev/null
+++ b/frc971/control_loops/swerve/dynamics_python_bindings.cc
@@ -0,0 +1,103 @@
+#define PY_SSIZE_T_CLEAN
+// Note that Python.h needs to be included before anything else.
+#include <Python.h>
+
+#include <iostream>
+#include <optional>
+
+#include "frc971/control_loops/swerve/dynamics.h"
+
+namespace frc971::control_loops::swerve {
+namespace {
+template <int N>
+std::optional<Eigen::Matrix<double, N, 1>> ToEigen(PyObject *list) {
+  Eigen::Matrix<double, N, 1> result;
+  for (size_t index = 0; index < N; ++index) {
+    PyObject *element = PyList_GetItem(list, index);
+    if (!PyFloat_Check(element)) {
+      PyErr_SetString(PyExc_ValueError,
+                      "Input lists should be lists of floats.");
+      return std::nullopt;
+    }
+    result(index) = PyFloat_AsDouble(element);
+  }
+  return result;
+}
+PyObject *swerve_dynamics(PyObject * /*self*/, PyObject *args) {
+  PyObject *X;
+  PyObject *U;
+  if (!PyArg_ParseTuple(args, "OO", &X, &U)) {
+    PyErr_SetString(PyExc_ValueError, "Input arguments should be two lists.");
+    return nullptr;
+  }
+
+  if (!PyList_Check(X)) {
+    PyErr_SetString(PyExc_ValueError, "X should be a list.");
+    return nullptr;
+  }
+  if (!PyList_Check(U)) {
+    PyErr_SetString(PyExc_ValueError, "U should be a list.");
+    return nullptr;
+  }
+
+  if (PyList_Size(X) != kNumFullDynamicsStates) {
+    PyErr_SetString(PyExc_ValueError,
+                    "X should have kNumFullDynamicsStates elements.");
+    return nullptr;
+  }
+
+  if (PyList_Size(U) != kNumInputs) {
+    PyErr_SetString(PyExc_ValueError, "U should have kNumInputs elements.");
+    return nullptr;
+  }
+
+  std::optional<Eigen::Matrix<double, kNumFullDynamicsStates, 1>> X_eig =
+      ToEigen<kNumFullDynamicsStates>(X);
+
+  if (!X_eig.has_value()) {
+    return nullptr;
+  }
+
+  std::optional<Eigen::Matrix<double, kNumInputs, 1>> U_eig =
+      ToEigen<kNumInputs>(X);
+
+  if (!U_eig.has_value()) {
+    return nullptr;
+  }
+
+  Eigen::Matrix<double, kNumFullDynamicsStates, 1> Xdot =
+      SwervePhysics(X_eig.value(), U_eig.value());
+
+  PyObject *result = PyList_New(kNumFullDynamicsStates);
+  for (size_t index = 0; index < kNumFullDynamicsStates; ++index) {
+    if (PyList_SetItem(result, index, PyFloat_FromDouble(Xdot(index))) != 0) {
+      return nullptr;
+    }
+  }
+
+  return result;
+}
+
+static PyMethodDef methods[] = {
+    {"swerve_dynamics", swerve_dynamics, METH_VARARGS,
+     "Xdot = swerve_dynamics(X, U), all types are lists."},
+    {NULL, NULL, 0, NULL}  // Sentinel
+};
+
+static PyModuleDef cpp_dynamics_module = {
+    .m_base = PyModuleDef_HEAD_INIT,
+    .m_name = "cpp_dynamics",
+    .m_doc =
+        "Wraps the generated C++ dynamics in order to support convenient "
+        "testing.",
+    .m_size = -1,
+    .m_methods = methods,
+};
+
+PyObject *InitModule() { return PyModule_Create(&cpp_dynamics_module); }
+}  // namespace
+}  // namespace frc971::control_loops::swerve
+
+PyMODINIT_FUNC PyInit_cpp_dynamics(void) {
+  return frc971::control_loops::swerve::InitModule();
+}
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index 77dfbb0..af058f6 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -283,18 +283,107 @@
     result_h.emplace_back("");
     result_h.emplace_back("namespace frc971::control_loops::swerve {");
     result_h.emplace_back("");
+    result_h.emplace_back("struct FullDynamicsStates {");
+    result_h.emplace_back("enum States {");
+    result_h.emplace_back("  kThetas0 = 0,");
+    result_h.emplace_back("  kThetad0 = 1,");
+    result_h.emplace_back("  kOmegas0 = 2,");
+    result_h.emplace_back("  kOmegad0 = 3,");
+    result_h.emplace_back("  kThetas1 = 4,");
+    result_h.emplace_back("  kThetad1 = 5,");
+    result_h.emplace_back("  kOmegas1 = 6,");
+    result_h.emplace_back("  kOmegad1 = 7,");
+    result_h.emplace_back("  kThetas2 = 8,");
+    result_h.emplace_back("  kThetad2 = 9,");
+    result_h.emplace_back("  kOmegas2 = 10,");
+    result_h.emplace_back("  kOmegad2 = 11,");
+    result_h.emplace_back("  kThetas3 = 12,");
+    result_h.emplace_back("  kThetad3 = 13,");
+    result_h.emplace_back("  kOmegas3 = 14,");
+    result_h.emplace_back("  kOmegad3 = 15,");
+    result_h.emplace_back("  kX = 16,");
+    result_h.emplace_back("  kY = 17,");
+    result_h.emplace_back("  kTheta = 18,");
+    result_h.emplace_back("  kVx = 19,");
+    result_h.emplace_back("  kVy = 20,");
+    result_h.emplace_back("  kOmega = 21,");
+    result_h.emplace_back("  kFx = 22,");
+    result_h.emplace_back("  kFy = 23,");
+    result_h.emplace_back("  kMoment = 24,");
+    result_h.emplace_back("  kNumStates");
+    result_h.emplace_back("};");
+    result_h.emplace_back("};");
+    result_h.emplace_back(
+        "inline constexpr size_t kNumFullDynamicsStates = "
+        "static_cast<size_t>(FullDynamicsStates::kNumStates);");
+    result_h.emplace_back("struct VelocityStates {");
+    result_h.emplace_back("enum States {");
+    result_h.emplace_back("  kThetas0 = 0,");
+    result_h.emplace_back("  kOmegas0 = 1,");
+    result_h.emplace_back("  kThetas1 = 2,");
+    result_h.emplace_back("  kOmegas1 = 3,");
+    result_h.emplace_back("  kThetas2 = 4,");
+    result_h.emplace_back("  kOmegas2 = 5,");
+    result_h.emplace_back("  kThetas3 = 6,");
+    result_h.emplace_back("  kOmegas3 = 7,");
+    result_h.emplace_back("  kTheta = 8,");
+    result_h.emplace_back("  kVx = 9,");
+    result_h.emplace_back("  kVy = 10,");
+    result_h.emplace_back("  kOmega = 11,");
+    result_h.emplace_back("  kNumStates");
+    result_h.emplace_back("};");
+    result_h.emplace_back("};");
+    result_h.emplace_back(
+        "inline constexpr size_t kNumVelocityStates = "
+        "static_cast<size_t>(VelocityStates::kNumStates);");
+    result_h.emplace_back("struct Inputs {");
+    result_h.emplace_back("enum States {");
+    result_h.emplace_back("  kIs0 = 0,");
+    result_h.emplace_back("  kId0 = 1,");
+    result_h.emplace_back("  kIs1 = 2,");
+    result_h.emplace_back("  kId1 = 3,");
+    result_h.emplace_back("  kIs2 = 4,");
+    result_h.emplace_back("  kId2 = 5,");
+    result_h.emplace_back("  kIs3 = 6,");
+    result_h.emplace_back("  kId3 = 7,");
+    result_h.emplace_back("  kNumInputs = 8");
+    result_h.emplace_back("};");
+    result_h.emplace_back("};");
+    result_h.emplace_back(
+        "inline constexpr size_t kNumInputs = "
+        "static_cast<size_t>(Inputs::kNumInputs);");
+    result_h.emplace_back("");
     result_h.emplace_back("// Returns the derivative of our state vector");
-    result_h.emplace_back("// [thetas0, thetad0, omegas0, omegad0,");
-    result_h.emplace_back("//  thetas1, thetad1, omegas1, omegad1,");
-    result_h.emplace_back("//  thetas2, thetad2, omegas2, omegad2,");
-    result_h.emplace_back("//  thetas3, thetad3, omegas3, omegad3,");
-    result_h.emplace_back("//  x, y, theta, vx, vy, omega,");
-    result_h.emplace_back("//  Fx, Fy, Moment]");
-    result_h.emplace_back("Eigen::Matrix<double, 25, 1> SwervePhysics(");
     result_h.emplace_back(
-        "    Eigen::Map<const Eigen::Matrix<double, 25, 1>> X,");
+        "Eigen::Matrix<double, kNumFullDynamicsStates, 1> SwervePhysics(");
     result_h.emplace_back(
-        "    Eigen::Map<const Eigen::Matrix<double, 8, 1>> U);");
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
+        "1>> X,");
+    result_h.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumInputs, 1>> U);");
+    result_h.emplace_back("");
+    result_h.emplace_back(
+        "Eigen::Matrix<double, kNumVelocityStates, 1> ToVelocityState(");
+    result_h.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
+        "1>> X);");
+    result_h.emplace_back("");
+    result_h.emplace_back(
+        "Eigen::Matrix<double, kNumFullDynamicsStates, 1> FromVelocityState(");
+    result_h.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> "
+        "X);");
+    result_h.emplace_back("");
+    result_h.emplace_back(
+        "inline Eigen::Matrix<double, kNumVelocityStates, 1> VelocityPhysics(");
+    result_h.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> "
+        "X,");
+    result_h.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumInputs, 1>> U) {");
+    result_h.emplace_back(
+        "  return ToVelocityState(SwervePhysics(FromVelocityState(X), U));");
+    result_h.emplace_back("}");
     result_h.emplace_back("");
     result_h.emplace_back("}  // namespace frc971::control_loops::swerve");
     result_h.emplace_back("");
@@ -308,12 +397,49 @@
     result_cc.emplace_back("");
     result_cc.emplace_back("namespace frc971::control_loops::swerve {");
     result_cc.emplace_back("");
-    result_cc.emplace_back("Eigen::Matrix<double, 25, 1> SwervePhysics(");
     result_cc.emplace_back(
-        "    Eigen::Map<const Eigen::Matrix<double, 25, 1>> X,");
+        "Eigen::Matrix<double, kNumVelocityStates, 1> ToVelocityState(");
     result_cc.emplace_back(
-        "    Eigen::Map<const Eigen::Matrix<double, 8, 1>> U) {");
-    result_cc.emplace_back("  Eigen::Matrix<double, 25, 1> result;");
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
+        "1>> X) {");
+    result_cc.emplace_back(
+        "    Eigen::Matrix<double, kNumVelocityStates, 1> velocity;");
+    const std::vector<std::string_view> velocity_states = {
+        "kThetas0", "kOmegas0", "kThetas1", "kOmegas1", "kThetas2", "kOmegas2",
+        "kThetas3", "kOmegas3", "kTheta",   "kVx",      "kVy",      "kOmega"};
+    for (const std::string_view velocity_state : velocity_states) {
+      result_cc.emplace_back(absl::StrFormat(
+          "  velocity(VelocityStates::%s) = X(FullDynamicsStates::%s);",
+          velocity_state, velocity_state));
+    }
+    result_cc.emplace_back("  return velocity;");
+    result_cc.emplace_back("}");
+    result_cc.emplace_back("");
+    result_cc.emplace_back(
+        "Eigen::Matrix<double, kNumFullDynamicsStates, 1> FromVelocityState(");
+    result_cc.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> X) "
+        "{");
+    result_cc.emplace_back(
+        "    Eigen::Matrix<double, kNumFullDynamicsStates, 1> full;");
+    result_cc.emplace_back("    full.setZero();");
+    for (const std::string_view velocity_state : velocity_states) {
+      result_cc.emplace_back(absl::StrFormat(
+          "  full(FullDynamicsStates::%s) = X(VelocityStates::%s);",
+          velocity_state, velocity_state));
+    }
+    result_cc.emplace_back("  return full;");
+    result_cc.emplace_back("}");
+    result_cc.emplace_back("");
+    result_cc.emplace_back(
+        "Eigen::Matrix<double, kNumFullDynamicsStates, 1> SwervePhysics(");
+    result_cc.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
+        "1>> X,");
+    result_cc.emplace_back(
+        "    Eigen::Ref<const Eigen::Matrix<double, kNumInputs, 1>> U) {");
+    result_cc.emplace_back(
+        "  Eigen::Matrix<double, kNumFullDynamicsStates, 1> result;");
 
     // Start by writing out variables matching each of the symbol names we use
     // so we don't have to modify the computed equations too much.
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 6e7ddf4..31b96c9 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -19,6 +19,7 @@
 from frc971.control_loops.swerve import nocaster_dynamics
 from frc971.control_loops.swerve import physics_test_utils as utils
 from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve.cpp_dynamics import swerve_dynamics as cpp_dynamics
 
 
 class TestSwervePhysics(unittest.TestCase):
@@ -705,6 +706,29 @@
         self.assertAlmostEquals(Xdot[dynamics.STATE_OMEGA, 0],
                                 Xdot_rot[dynamics.STATE_OMEGA, 0])
 
+    def test_cpp_consistency(self):
+        """Tests that the C++ physics are consistent with the Python physics."""
+        # TODO(james): Currently the physics only match at X = 0 and U = 0.
+        # Fix this.
+        # Maybe due to different atan2 implementations?
+        # TODO(james): Fold this into the general comparisons for JAX versus
+        # casadi once the physics actually match.
+        for current in [0]:
+            print(f"Current: {current}")
+            steer_I = numpy.zeros((8, 1)) + current
+            for state_values in [0.0]:
+                print(f"States all set to: {state_values}")
+                X = numpy.zeros((dynamics.NUM_STATES, 1)) + state_values
+                Xdot_py = self.swerve_full_dynamics(X,
+                                                    steer_I,
+                                                    skip_compare=True)
+                Xdot_cpp = numpy.array(
+                    cpp_dynamics(X.flatten().tolist(),
+                                 steer_I.flatten().tolist())).reshape((25, 1))
+                for index in range(dynamics.NUM_STATES):
+                    self.assertAlmostEqual(Xdot_py[index, 0], Xdot_cpp[index,
+                                                                       0])
+
 
 if __name__ == "__main__":
     unittest.main()