Merge "Add flatbuffers to y2024_bot3" into main
diff --git a/WORKSPACE b/WORKSPACE
index ad59314..6fd635e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -702,14 +702,16 @@
 cc_library(
     name = 'api-cpp',
     visibility = ['//visibility:public'],
-    hdrs = glob(['ctre/phoenix6/**/*.hpp']),
+    hdrs = glob(['ctre/phoenix6/**/*.hpp', 'ctre/unit/**/*.h']),
     includes = ["."],
-    deps = ["@//third_party/allwpilib/wpimath"],
+    deps = ["@//third_party/allwpilib/wpimath",
+            "@ctre_phoenix6_tools_headers//:tools",
+            ],
 )
 """,
-    sha256 = "1b9295ac868d619d0263f285b372a22297798e45c8fdeb83ca99154a097c5c98",
+    sha256 = "67fdd9a4bf275c69666bbc8bf38312eea8a85bf9fda3901d02af3a18136ffb3e",
     urls = [
-        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.1.0/api-cpp-24.1.0-headers.zip",
+        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.50.0-alpha-2/api-cpp-24.50.0-alpha-2-headers.zip",
     ],
 )
 
@@ -731,9 +733,9 @@
     target_compatible_with = ['@//tools/platforms/hardware:roborio'],
 )
 """,
-    sha256 = "51f4921f8995e3e80ba347a90f6fa5c0cb7d0e438923a1a736554f263e2d5b8e",
+    sha256 = "00da14f437cfeb2c344674dee59fe70477a3a0d612eb93a1441188dc94a5136f",
     urls = [
-        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.1.0/api-cpp-24.1.0-linuxathena.zip",
+        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.50.0-alpha-2/api-cpp-24.50.0-alpha-2-linuxathena.zip",
     ],
 )
 
@@ -743,12 +745,12 @@
 cc_library(
     name = 'tools',
     visibility = ['//visibility:public'],
-    hdrs = glob(['ctre/**/*.h', 'ctre/**/*.hpp']),
+    hdrs = glob(['ctre/**/*.h', 'ctre/phoenix/**/*.hpp', 'ctre/phoenix6/**/*.hpp']),
 )
 """,
-    sha256 = "72be3e50205e2546736361e3aba9c0de5d5318b263c195600f9b558c3d92cb9e",
+    sha256 = "77624291ec19a03c9e068347ef800e435782444722793d56c9e39f6108da33a8",
     urls = [
-        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.1.0/tools-24.1.0-headers.zip",
+        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.50.0-alpha-2/tools-24.50.0-alpha-2-headers.zip",
     ],
 )
 
@@ -770,9 +772,9 @@
     target_compatible_with = ['@//tools/platforms/hardware:roborio'],
 )
 """,
-    sha256 = "76d30cb8c0988eb6accab10e77ceae492782edb64c06f7df628b8c161cf0220a",
+    sha256 = "217bcd6aecb71224fd32725f857da58e61a6ea451ca07f85fb55e49f83dbf113",
     urls = [
-        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.1.0/tools-24.1.0-linuxathena.zip",
+        "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.50.0-alpha-2/tools-24.50.0-alpha-2-linuxathena.zip",
     ],
 )
 
@@ -782,12 +784,13 @@
 cc_library(
     name = 'api-cpp',
     visibility = ['//visibility:public'],
-    hdrs = glob(['ctre/phoenix/**/*.h']),
+    hdrs = glob(['ctre/phoenix/**/*.h', 'ctre/unit/**/*.h']),
+    includes = ["."],
 )
 """,
-    sha256 = "d2790e2495a59a3b14ab4b0fd03c2e25292d6874168f5642165acb0d1bab2f99",
+    sha256 = "4fba20441ea1d61f8487897682c723a48ee2c2741946eab4062b4248434c0afc",
     urls = [
-        "https://maven.ctr-electronics.com/release/com/ctre/phoenix/api-cpp/5.33.0/api-cpp-5.33.0-headers.zip",
+        "https://maven.ctr-electronics.com/release/com/ctre/phoenix/api-cpp/5.34.0-alpha-1/api-cpp-5.34.0-alpha-1-headers.zip",
     ],
 )
 
@@ -824,9 +827,9 @@
     hdrs = glob(['ctre/phoenix/**/*.h']),
 )
 """,
-    sha256 = "56cb8272d623721560eea360730d2508f4f365e2ac1060aec5fbd546cffc0771",
+    sha256 = "cd39f32341037a7ec1074710044b332db6ca607dfd548b5f951c75f7a8506ab5",
     urls = [
-        "https://maven.ctr-electronics.com/release/com/ctre/phoenix/cci/5.33.0/cci-5.33.0-headers.zip",
+        "https://maven.ctr-electronics.com/release/com/ctre/phoenix/cci/5.34.0-alpha-1/cci-5.34.0-alpha-1-headers.zip",
     ],
 )
 
@@ -867,6 +870,7 @@
     target_compatible_with = ['@platforms//cpu:arm64'],
 )
 
+
 # TODO(max): Use cc_import once they add a defines property.
 # See: https://github.com/bazelbuild/bazel/issues/19753
 cc_library(
@@ -891,9 +895,9 @@
     ],
 )
 """,
-    sha256 = "635b19f019d1283749fb23a95ad9e13ddd9d5013cff4c838303d898e7d9556bd",
+    sha256 = "0f1312f39eacc490fb253198c2d0e61e48ae00eff6a87cfd362358b1ad36a930",
     urls = [
-        "https://software.frc971.org/Build-Dependencies/phoenix6_24.2.0_5.4.2024.tar.gz",
+        "https://software.frc971.org/Build-Dependencies/phoenix6_24.50.0-alpha-2_arm64-2024.10.26.tar.gz",
     ],
 )
 
diff --git a/aos/containers/error_list.h b/aos/containers/error_list.h
index 7ceb1fe..762d2ad 100644
--- a/aos/containers/error_list.h
+++ b/aos/containers/error_list.h
@@ -5,6 +5,7 @@
 
 #include <algorithm>
 
+#include "absl/log/check.h"
 #include "flatbuffers/buffer.h"
 #include "flatbuffers/flatbuffer_builder.h"
 #include "flatbuffers/vector.h"
@@ -128,7 +129,17 @@
       flatbuffers::FlatBufferBuilder *fbb) const {
     return fbb->CreateVector(array_.data(), array_.size());
   }
-};  // namespace aos
+
+  // Copies the contents of this list into vector_builder, a static flatbuffer
+  // vector builder, via its FromData() method. A failed copy is fatal;
+  // NOTE(review): presumably FromData() fails when the builder lacks capacity
+  // for array_.size() elements -- confirm against the builder.
+  template <typename StaticBuilder>
+  void ToStaticFlatbuffer(StaticBuilder *vector_builder) const {
+    CHECK(vector_builder->FromData(array_.data(), array_.size()))
+        << ": Failed to copy " << array_.size() << " entries into the builder.";
+  }
+};
 
 }  // namespace aos
 
diff --git a/aos/events/logging/config_remapper.cc b/aos/events/logging/config_remapper.cc
index 763cbc3..edc2326 100644
--- a/aos/events/logging/config_remapper.cc
+++ b/aos/events/logging/config_remapper.cc
@@ -566,7 +566,9 @@
 
     // Add the schema if it doesn't exist.
     if (schema_map.find(c->type()->string_view()) == schema_map.end()) {
-      CHECK(c->has_schema());
+      // A channel whose type is not yet in schema_map must carry its schema.
+      CHECK(c->has_schema())
+          << "Could not find schema for " << c->type()->string_view();
       schema_map.insert(std::make_pair(c->type()->string_view(),
                                        RecursiveCopyFlatBuffer(c->schema())));
     }
diff --git a/frc971/control_loops/dlqr.h b/frc971/control_loops/dlqr.h
index 57aa88b..d31f492 100644
--- a/frc971/control_loops/dlqr.h
+++ b/frc971/control_loops/dlqr.h
@@ -41,7 +41,9 @@
          ::Eigen::Matrix<double, kM, kN> *K,
          ::Eigen::Matrix<double, kN, kN> *S) {
   *K = ::Eigen::Matrix<double, kM, kN>::Zero();
-  *S = ::Eigen::Matrix<double, kN, kN>::Zero();
+  if (S != nullptr) {
+    *S = ::Eigen::Matrix<double, kN, kN>::Zero();
+  }
   // Discrete (not continuous)
   char DICO = 'D';
   // B and R are provided instead of G.
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 2c3ad50..6be2604 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -14,8 +14,17 @@
 )
 
 static_flatbuffer(
+    name = "swerve_drivetrain_joystick_goal_fbs",
+    srcs = ["swerve_drivetrain_joystick_goal.fbs"],
+)
+
+static_flatbuffer(
     name = "swerve_drivetrain_goal_fbs",
     srcs = ["swerve_drivetrain_goal.fbs"],
+    deps = [
+        ":swerve_drivetrain_joystick_goal_fbs",
+        "//frc971/math:matrix_fbs",
+    ],
 )
 
 static_flatbuffer(
@@ -292,6 +301,33 @@
     ],
 )
 
+py_binary(
+    name = "experience_collector",
+    srcs = [
+        "experience_collector.py",
+    ],
+    deps = [
+        ":casadi_velocity_mpc_lib",
+        ":jax_dynamics",
+        "//frc971/control_loops/swerve/velocity_controller:physics",
+        "@pip//absl_py",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//scipy",
+        "@pip//tensorflow",
+    ],
+)
+
+py_binary(
+    name = "multi_experience_collector",
+    srcs = ["multi_experience_collector.py"],
+    data = [":experience_collector"],
+    deps = [
+        "@pip//absl_py",
+    ],
+)
+
 py_library(
     name = "physics_test_utils",
     srcs = [
@@ -371,3 +407,92 @@
         "@com_google_absl//absl/log:check",
     ],
 )
+
+cc_library(
+    name = "linearization_utils",
+    hdrs = ["linearization_utils.h"],
+)
+
+cc_library(
+    name = "linearized_controller",
+    hdrs = ["linearized_controller.h"],
+    deps = [
+        ":eigen_dynamics",
+        ":linearization_utils",
+        "//frc971/control_loops:c2d",
+        "//frc971/control_loops:dlqr",
+        "//frc971/control_loops:jacobian",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+    ],
+)
+
+cc_test(
+    name = "linearized_controller_test",
+    srcs = ["linearized_controller_test.cc"],
+    deps = [
+        ":linearized_controller",
+        "//aos/testing:googletest",
+    ],
+)
+
+cc_library(
+    name = "auto_diff_jacobian",
+    hdrs = ["auto_diff_jacobian.h"],
+    target_compatible_with = ["@platforms//os:linux"],
+    deps = [
+        "@com_google_ceres_solver//:ceres",
+    ],
+)
+
+cc_test(
+    name = "auto_diff_jacobian_test",
+    srcs = ["auto_diff_jacobian_test.cc"],
+    deps = [
+        ":auto_diff_jacobian",
+        "//aos/testing:googletest",
+    ],
+)
+
+cc_library(
+    name = "simplified_dynamics",
+    hdrs = ["simplified_dynamics.h"],
+    deps = [
+        ":auto_diff_jacobian",
+        ":eigen_dynamics",
+        ":motors",
+        "//aos/util:math",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+    ],
+)
+
+cc_test(
+    name = "simplified_dynamics_test",
+    srcs = ["simplified_dynamics_test.cc"],
+    deps = [
+        ":simplified_dynamics",
+        "//aos/testing:googletest",
+        "//aos/time",
+        "//frc971/control_loops:jacobian",
+        "@com_google_absl//absl/log",
+    ],
+)
+
+cc_library(
+    name = "inverse_kinematics",
+    hdrs = ["inverse_kinematics.h"],
+    deps = [
+        ":simplified_dynamics",
+        "//aos/util:math",
+    ],
+)
+
+cc_test(
+    name = "inverse_kinematics_test",
+    srcs = ["inverse_kinematics_test.cc"],
+    deps = [
+        ":inverse_kinematics",
+        "//aos/testing:googletest",
+    ],
+)
diff --git a/frc971/control_loops/swerve/auto_diff_jacobian.h b/frc971/control_loops/swerve/auto_diff_jacobian.h
new file mode 100644
index 0000000..2e7947a
--- /dev/null
+++ b/frc971/control_loops/swerve/auto_diff_jacobian.h
@@ -0,0 +1,75 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_AUTO_DIFF_JACOBIAN_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_AUTO_DIFF_JACOBIAN_H_
+#include <cstddef>
+
+#include "include/ceres/tiny_solver.h"
+#include "include/ceres/tiny_solver_autodiff_function.h"
+
+namespace frc971::control_loops::swerve {
+// Class to conveniently scope a function that makes use of Ceres'
+// autodifferentiation methods to calculate the jacobian of the provided method.
+// Template parameters:
+// Scalar: scalar type to use (typically double or float; this is used for
+//    allowing you to control what precision you use; you cannot use this
+//    with ceres Jets because this has to use Jets internally).
+// Function: The type of the function itself. Ceres evaluates the function on
+//    its own Jet scalar type, so a Function f must have a call operator
+//    templated on the scalar type T such that
+//    Eigen::Matrix<T, kNumOutputs, 1> = f(
+//        Eigen::Map<const Eigen::Matrix<T, kNumInputs, 1>>{...});
+//    See auto_diff_jacobian_test.cc for examples.
+// kNumInputs: number of elements in the function's input vector.
+// kNumOutputs: number of elements in the function's output vector.
+template <typename Scalar, typename Function, size_t kNumInputs,
+          size_t kNumOutputs>
+class AutoDiffJacobian {
+ public:
+  // Calculates the jacobian of the provided method, function, at the provided
+  // input X. Element (i, j) of the result is the derivative of output i with
+  // respect to input j.
+  static Eigen::Matrix<Scalar, kNumOutputs, kNumInputs> Jacobian(
+      const Function &function, const Eigen::Matrix<Scalar, kNumInputs, 1> &X) {
+    AutoDiffCeresFunctor ceres_functor(function);
+    TinySolverFunctor tiny_solver(ceres_functor);
+    // residual is unused, it's just a place to store the evaluated function at
+    // the current state/input.
+    Eigen::Matrix<Scalar, kNumOutputs, 1> residual;
+    // Ceres writes the jacobian out in column-major order, which is also
+    // Eigen's default storage order, so it can be written directly in here.
+    Eigen::Matrix<Scalar, kNumOutputs, kNumInputs> jacobian;
+    tiny_solver(X.data(), residual.data(), jacobian.data());
+    return jacobian;
+  }
+
+ private:
+  // Borrow the TinySolver's auto-differentiation execution for use here to
+  // calculate the jacobian. We aren't actually doing any solving, just letting
+  // it do the jacobian calculation for us.
+  // As such, construct a "residual" function whose residuals are just the
+  // outputs of function and whose parameters are its inputs.
+  class AutoDiffCeresFunctor {
+   public:
+    explicit AutoDiffCeresFunctor(const Function &function)
+        : function_(function) {}
+    template <typename ScalarT>
+    bool operator()(const ScalarT *const parameters,
+                    ScalarT *const residuals) const {
+      const Eigen::Map<const Eigen::Matrix<ScalarT, kNumInputs, 1>>
+          eigen_parameters(parameters);
+      Eigen::Map<Eigen::Matrix<ScalarT, kNumOutputs, 1>> eigen_residuals(
+          residuals);
+      eigen_residuals = function_(eigen_parameters);
+      return true;
+    }
+
+   private:
+    // Not owned; refers to the function passed to Jacobian(), which outlives
+    // this functor.
+    const Function &function_;
+  };
+  using TinySolverFunctor =
+      ceres::TinySolverAutoDiffFunction<AutoDiffCeresFunctor, kNumOutputs,
+                                        kNumInputs, Scalar>;
+};
+}  // namespace frc971::control_loops::swerve
+#endif  // FRC971_CONTROL_LOOPS_SWERVE_AUTO_DIFF_JACOBIAN_H_
diff --git a/frc971/control_loops/swerve/auto_diff_jacobian_test.cc b/frc971/control_loops/swerve/auto_diff_jacobian_test.cc
new file mode 100644
index 0000000..d04726c
--- /dev/null
+++ b/frc971/control_loops/swerve/auto_diff_jacobian_test.cc
@@ -0,0 +1,42 @@
+#include "frc971/control_loops/swerve/auto_diff_jacobian.h"
+
+#include "gtest/gtest.h"
+
+namespace frc971::control_loops::swerve::testing {
+// Linear function; its jacobian is the constant matrix that X is multiplied
+// by.
+struct TestFunction {
+  template <typename Scalar>
+  Eigen::Matrix<Scalar, 3, 1> operator()(
+      const Eigen::Map<const Eigen::Matrix<Scalar, 2, 1>> X) const {
+    return Eigen::Matrix<double, 3, 2>{{1, 2}, {3, 4}, {5, 6}} * X;
+  }
+};
+
+// Nonlinear function f(X) = [X0 * X1, X0^2], whose jacobian
+// [[X1, X0], [2 * X0, 0]] depends on the point it is evaluated at.
+struct NonlinearTestFunction {
+  template <typename Scalar>
+  Eigen::Matrix<Scalar, 2, 1> operator()(
+      const Eigen::Map<const Eigen::Matrix<Scalar, 2, 1>> X) const {
+    Eigen::Matrix<Scalar, 2, 1> result;
+    result << X(0) * X(1), X(0) * X(0);
+    return result;
+  }
+};
+
+TEST(AutoDiffJacobianTest, EvaluatesJacobian) {
+  EXPECT_EQ((AutoDiffJacobian<double, TestFunction, 2, 3>::Jacobian(
+                TestFunction{}, Eigen::Vector2d::Zero())),
+            (Eigen::Matrix<double, 3, 2>{{1, 2}, {3, 4}, {5, 6}}));
+}
+
+// Checks that the jacobian is evaluated at the requested input rather than at
+// the origin.
+TEST(AutoDiffJacobianTest, EvaluatesNonlinearJacobianAtInput) {
+  EXPECT_EQ((AutoDiffJacobian<double, NonlinearTestFunction, 2, 2>::Jacobian(
+                NonlinearTestFunction{}, Eigen::Vector2d(2.0, 3.0))),
+            (Eigen::Matrix<double, 2, 2>{{3.0, 2.0}, {4.0, 0.0}}));
+}
+
+}  // namespace frc971::control_loops::swerve::testing
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc.py b/frc971/control_loops/swerve/casadi_velocity_mpc.py
index d62f7ed..54e33f1 100644
--- a/frc971/control_loops/swerve/casadi_velocity_mpc.py
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc.py
@@ -24,11 +24,6 @@
 flags.DEFINE_bool('pickle', False, 'Write optimization results.')
 flags.DEFINE_string('outputdir', None, 'Directory to write problem results to')
 
-# Full print level on ipopt. Piping to a file and using a filter or search method is suggested
-# grad_x prints out the gradient at each iteration in the following sequence: U0, X1, U1, etc.
-flags.DEFINE_bool('full_debug', False,
-                  'If true, turn on all the debugging in the solver.')
-
 
 class Solver(object):
 
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
index f464268..e358422 100644
--- a/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
@@ -3,11 +3,19 @@
 from frc971.control_loops.swerve import dynamics
 import casadi
 import numpy
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+# Full ipopt print level. The output is very verbose; piping it to a file and filtering or searching it is suggested.
+# grad_x prints the gradient at each iteration in the following sequence: U0, X1, U1, etc.
+flags.DEFINE_bool('full_debug', False,
+                  'If true, turn on all the debugging in the solver.')
 
 
 class MPC(object):
 
-    def __init__(self, solver='fatrop', jit=True):
+    def __init__(self, solver='fatrop', jit=True, N=200):
         self.fdot = dynamics.swerve_full_dynamics(
             casadi.SX.sym("X", dynamics.NUM_STATES, 1),
             casadi.SX.sym("U", 8, 1))
@@ -47,7 +55,7 @@
         self.next_X = self.make_physics()
         self.cost = self.make_cost()
 
-        self.N = 200
+        self.N = N
 
         # Start with an empty nonlinear program.
         self.w = []
diff --git a/frc971/control_loops/swerve/experience_collector.py b/frc971/control_loops/swerve/experience_collector.py
new file mode 100644
index 0000000..b4c17fb
--- /dev/null
+++ b/frc971/control_loops/swerve/experience_collector.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+"""Generates training experience for the swerve velocity controller.
+
+Solves the casadi velocity MPC from random initial states and saves each
+resulting trajectory as experience_<i>.pkl (plus optional SVG plots) in the
+current directory, or in --outputdir when it is set.
+"""
+import os, sys
+
+# Setup XLA first: configure XLA/JAX through the environment before jax is
+# imported below.
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach it where to find CUDA
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Expose 6 host CPU devices to XLA (currently disabled).
+    #'--xla_force_host_platform_device_count=6',
+    # Dump XLA to /tmp/foo to aid debugging
+    #'--xla_dump_to=/tmp/foo',
+    #'--xla_gpu_enable_command_buffer='
+    # Dump sharding
+    #"--xla_dump_to=/tmp/foo",
+    #"--xla_dump_hlo_pass_re=spmd|propagation"
+])
+os.environ['JAX_PLATFORMS'] = 'cpu'
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
+
+from absl import flags
+from absl import app
+from absl import logging
+import pickle
+import numpy
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve.casadi_velocity_mpc_lib import MPC
+import jax
+import tensorflow as tf
+from frc971.control_loops.swerve.velocity_controller.physics import SwerveProblem
+from frc971.control_loops.swerve import jax_dynamics
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_bool('compileonly', False,
+                  'If true, construct the MPC and exit without solving.')
+
+flags.DEFINE_float('vx', 1.0, 'Goal velocity in m/s in x')
+flags.DEFINE_float('vy', 0.0, 'Goal velocity in m/s in y')
+flags.DEFINE_float('omega', 0.0, 'Goal angular velocity in rad/s')
+flags.DEFINE_integer('seed', 0, 'Seed for random initial state.')
+
+flags.DEFINE_bool('save_plots', True,
+                  'If true, save plots for each run as well.')
+flags.DEFINE_string('outputdir', None, 'Directory to write problem results to')
+flags.DEFINE_bool('quiet', False, 'If true, print a lot less')
+
+flags.DEFINE_integer('num_solutions', 100,
+                     'Number of random problems to solve.')
+flags.DEFINE_integer('horizon', 200, 'Horizon to solve for')
+
+try:
+    from matplotlib import pylab
+except ModuleNotFoundError:
+    # Plotting is optional; save_experience() skips plots without matplotlib.
+    pylab = None
+
+
+def collect_experience(problem, mpc, rng):
+    """Solves the MPC once, starting from a random initial state.
+
+    Args:
+      problem: SwerveProblem used to sample the initial state and the rewards.
+      mpc: MPC to solve; mpc.N - 1 steps of its solution are recorded.
+      rng: jax PRNG key used to sample the initial state.
+
+    Returns:
+      A dict of arrays with mpc.N - 1 rows each: 'observations1' (the state
+      before each step), 'observations2' (the state after it), 'actions',
+      'rewards' and 'goals'.
+    """
+    X_initial = numpy.array(problem.random_states(rng,
+                                                  dimensions=1)).transpose()
+
+    R_goal = numpy.zeros((3, 1))
+    R_goal[0, 0] = FLAGS.vx
+    R_goal[1, 0] = FLAGS.vy
+    R_goal[2, 0] = FLAGS.omega
+
+    solution = mpc.solve(p=numpy.vstack((X_initial, R_goal)))
+    sys.stderr.flush()
+    sys.stdout.flush()
+
+    # Solver doesn't solve for the last state.  So we get N-1 states back.
+    experience = {
+        'observations1': numpy.zeros((mpc.N - 1, problem.num_states)),
+        'observations2': numpy.zeros((mpc.N - 1, problem.num_states)),
+        'actions': numpy.zeros((mpc.N - 1, problem.num_outputs)),
+        'rewards': numpy.zeros((mpc.N - 1, 1)),
+        'goals': numpy.zeros((mpc.N - 1, problem.num_goals)),
+    }
+
+    if not FLAGS.quiet:
+        print('x(0):', X_initial.transpose())
+
+    logging.info('Finished solving')
+    X_prior = X_initial.squeeze()
+    for j in range(mpc.N - 1):
+        if not FLAGS.quiet:
+            print(f'u({j}): ', mpc.unpack_u(solution, j))
+            print(f'x({j+1}): ', mpc.unpack_x(solution, j + 1))
+        experience['observations1'][j, :] = X_prior
+        X_prior = mpc.unpack_x(solution, j + 1)
+        experience['observations2'][j, :] = X_prior
+        experience['actions'][j, :] = mpc.unpack_u(solution, j)
+        experience['goals'][j, :] = R_goal[:, 0]
+
+    logging.info('Finished all but reward')
+    for j in range(mpc.N - 1):
+        # Score step j from the state its action was applied in, not from the
+        # final state that the loop above leaves in X_prior.
+        experience['rewards'][j, :] = problem.reward(
+            X=experience['observations1'][j, :],
+            U=mpc.unpack_u(solution, j),
+            goal=R_goal[:, 0],
+        )
+    sys.stderr.flush()
+    sys.stdout.flush()
+
+    return experience
+
+
+def save_experience(problem, mpc, experience, experience_number):
+    """Saves one experience dict from collect_experience().
+
+    Pickles it to experience_<experience_number>.pkl and, when --save_plots is
+    set and matplotlib is available, also writes state_<experience_number>.svg
+    and steer_<experience_number>.svg.
+    """
+    with open(f'experience_{experience_number}.pkl', 'wb') as f:
+        pickle.dump(experience, f)
+
+    if not FLAGS.save_plots:
+        return
+    if pylab is None:
+        logging.warning('matplotlib is not available; not saving plots.')
+        return
+
+    fig0, axs0 = pylab.subplots(3)
+    fig1, axs1 = pylab.subplots(2)
+
+    t = problem.dt * numpy.arange(mpc.N - 1)
+
+    X_plot = experience['observations1']
+    U_plot = experience['actions']
+
+    axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_VX], label="vx")
+    axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_VY], label="vy")
+    axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_OMEGA], label="omega")
+    axs0[0].legend()
+
+    axs0[1].plot(t, U_plot[:, 0], label="Is0")
+    axs0[1].plot(t, U_plot[:, 2], label="Is1")
+    axs0[1].plot(t, U_plot[:, 4], label="Is2")
+    axs0[1].plot(t, U_plot[:, 6], label="Is3")
+    axs0[1].legend()
+
+    axs0[2].plot(t, U_plot[:, 1], label="Id0")
+    axs0[2].plot(t, U_plot[:, 3], label="Id1")
+    axs0[2].plot(t, U_plot[:, 5], label="Id2")
+    axs0[2].plot(t, U_plot[:, 7], label="Id3")
+    axs0[2].legend()
+
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS0], label='steer0')
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS1], label='steer1')
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS2], label='steer2')
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS3], label='steer3')
+    axs1[0].legend()
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS0],
+                 label='steer_velocity0')
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS1],
+                 label='steer_velocity1')
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS2],
+                 label='steer_velocity2')
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS3],
+                 label='steer_velocity3')
+    axs1[1].legend()
+
+    fig0.savefig(f'state_{experience_number}.svg')
+    fig1.savefig(f'steer_{experience_number}.svg')
+
+    # Free the memory associated with the figures.
+    fig0.clf()
+    fig1.clf()
+    pylab.close(fig0)
+    pylab.close(fig1)
+
+
+def main(argv):
+    """Solves and saves --num_solutions random MPC problems."""
+    if FLAGS.outputdir:
+        os.chdir(FLAGS.outputdir)
+
+    # Hide any GPUs from TensorFlow. Otherwise it might reserve memory.
+    tf.config.experimental.set_visible_devices([], 'GPU')
+    rng = jax.random.key(FLAGS.seed)
+
+    physics_constants = jax_dynamics.Coefficients()
+    problem = SwerveProblem(physics_constants)
+    # collect_experience() records mpc.N - 1 steps, so this yields --horizon
+    # steps per problem.
+    mpc = MPC(solver='ipopt', N=(FLAGS.horizon + 1))
+
+    if FLAGS.compileonly:
+        return
+
+    for i in range(FLAGS.num_solutions):
+        rng, rng_init = jax.random.split(rng)
+        experience = collect_experience(problem, mpc, rng_init)
+        logging.info('Solved problem %d', i)
+
+        save_experience(problem, mpc, experience, i)
+        logging.info('Wrote problem %d', i)
+
+
+if __name__ == '__main__':
+    app.run(main)
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index de52d79..8a7fe16 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -339,7 +339,7 @@
     result_h.emplace_back(
         "inline constexpr size_t kNumVelocityStates = "
         "static_cast<size_t>(VelocityStates::kNumStates);");
-    result_h.emplace_back("struct Inputs {");
+    result_h.emplace_back("struct InputStates {");
     result_h.emplace_back("enum States {");
     result_h.emplace_back("  kIs0 = 0,");
     result_h.emplace_back("  kId0 = 1,");
@@ -354,7 +354,7 @@
     result_h.emplace_back("};");
     result_h.emplace_back(
         "inline constexpr size_t kNumInputs = "
-        "static_cast<size_t>(Inputs::kNumInputs);");
+        "static_cast<size_t>(InputStates::kNumInputs);");
     result_h.emplace_back("");
     result_h.emplace_back("// Returns the derivative of our state vector");
     result_h.emplace_back(
@@ -365,29 +365,6 @@
     result_h.emplace_back(
         "    Eigen::Ref<const Eigen::Matrix<double, kNumInputs, 1>> U);");
     result_h.emplace_back("");
-    result_h.emplace_back(
-        "Eigen::Matrix<double, kNumVelocityStates, 1> ToVelocityState(");
-    result_h.emplace_back(
-        "    Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
-        "1>> X);");
-    result_h.emplace_back("");
-    result_h.emplace_back(
-        "Eigen::Matrix<double, kNumFullDynamicsStates, 1> FromVelocityState(");
-    result_h.emplace_back(
-        "    Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> "
-        "X);");
-    result_h.emplace_back("");
-    result_h.emplace_back(
-        "inline Eigen::Matrix<double, kNumVelocityStates, 1> VelocityPhysics(");
-    result_h.emplace_back(
-        "    Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> "
-        "X,");
-    result_h.emplace_back(
-        "    Eigen::Ref<const Eigen::Matrix<double, kNumInputs, 1>> U) {");
-    result_h.emplace_back(
-        "  return ToVelocityState(SwervePhysics(FromVelocityState(X), U));");
-    result_h.emplace_back("}");
-    result_h.emplace_back("");
     result_h.emplace_back("}  // namespace frc971::control_loops::swerve");
     result_h.emplace_back("");
     result_h.emplace_back(absl::Substitute("#endif  // $0_", include_guard));
@@ -401,40 +378,6 @@
     result_cc.emplace_back("namespace frc971::control_loops::swerve {");
     result_cc.emplace_back("");
     result_cc.emplace_back(
-        "Eigen::Matrix<double, kNumVelocityStates, 1> ToVelocityState(");
-    result_cc.emplace_back(
-        "    Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
-        "1>> X) {");
-    result_cc.emplace_back(
-        "    Eigen::Matrix<double, kNumVelocityStates, 1> velocity;");
-    const std::vector<std::string_view> velocity_states = {
-        "kThetas0", "kOmegas0", "kThetas1", "kOmegas1", "kThetas2", "kOmegas2",
-        "kThetas3", "kOmegas3", "kTheta",   "kVx",      "kVy",      "kOmega"};
-    for (const std::string_view velocity_state : velocity_states) {
-      result_cc.emplace_back(absl::StrFormat(
-          "  velocity(VelocityStates::%s) = X(FullDynamicsStates::%s);",
-          velocity_state, velocity_state));
-    }
-    result_cc.emplace_back("  return velocity;");
-    result_cc.emplace_back("}");
-    result_cc.emplace_back("");
-    result_cc.emplace_back(
-        "Eigen::Matrix<double, kNumFullDynamicsStates, 1> FromVelocityState(");
-    result_cc.emplace_back(
-        "    Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> X) "
-        "{");
-    result_cc.emplace_back(
-        "    Eigen::Matrix<double, kNumFullDynamicsStates, 1> full;");
-    result_cc.emplace_back("    full.setZero();");
-    for (const std::string_view velocity_state : velocity_states) {
-      result_cc.emplace_back(absl::StrFormat(
-          "  full(FullDynamicsStates::%s) = X(VelocityStates::%s);",
-          velocity_state, velocity_state));
-    }
-    result_cc.emplace_back("  return full;");
-    result_cc.emplace_back("}");
-    result_cc.emplace_back("");
-    result_cc.emplace_back(
         "Eigen::Matrix<double, kNumFullDynamicsStates, 1> SwervePhysics(");
     result_cc.emplace_back(
         "    Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
diff --git a/frc971/control_loops/swerve/inverse_kinematics.h b/frc971/control_loops/swerve/inverse_kinematics.h
new file mode 100644
index 0000000..a4e43e1
--- /dev/null
+++ b/frc971/control_loops/swerve/inverse_kinematics.h
@@ -0,0 +1,100 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_INVERSE_KINEMATICS_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_INVERSE_KINEMATICS_H_
+#include <cmath>
+
+#include "aos/util/math.h"
+#include "frc971/control_loops/swerve/simplified_dynamics.h"
+namespace frc971::control_loops::swerve {
+// Class to do straightforwards inverse kinematics of a swerve drivebase. This
+// is meant largely as a sanity-check/initializer for more sophisticated
+// methods. This calculates which directions the modules must be pointed to
+// cause them to be pointed directly along the direction of motion of the
+// drivebase. Accounting for slip angles and the such must be done as part of
+// more sophisticated inverse dynamics.
+template <typename Scalar>
+class InverseKinematics {
+ public:
+  using ModuleParams = typename SimplifiedDynamics<Scalar>::ModuleParams;
+  using Parameters = typename SimplifiedDynamics<Scalar>::Parameters;
+  using States = typename SimplifiedDynamics<Scalar>::States;
+  using State =
+      typename SimplifiedDynamics<Scalar>::template VelocityState<Scalar>;
+  explicit InverseKinematics(const Parameters &params) : params_(params) {}
+
+  // Uses kVx, kVy, kTheta, and kOmega from the input goal state for the
+  // absolute kinematics. Also uses the specified theta values to bias theta
+  // output values towards the current state (i.e., if the module 0 theta is
+  // currently 0 and we are asked to drive straight backwards, this will prefer
+  // a theta of zero rather than a theta of pi).
+  // Only the module steer angles and steer velocities are solved for; every
+  // other entry of the returned state is copied from goal.
+  State Solve(const State &goal) const {
+    State result = goal;
+    for (size_t module_index = 0; module_index < params_.modules.size();
+         ++module_index) {
+      SolveModule(goal, params_.modules[module_index].position,
+                  &result(States::kThetas0 + 2 * module_index),
+                  &result(States::kOmegas0 + 2 * module_index));
+    }
+    return result;
+  }
+
+  // Solves for a single module at module_position (in the robot frame): sets
+  // *module_theta to the steer angle that points the module along its
+  // direction of travel and *module_omega to the rate of change of that
+  // angle. On input, *module_theta is the angle to bias towards: if the
+  // nominal angle is more than pi/2 away from it, it is flipped by pi.
+  // *module_omega is set to zero when the module is (nearly) stationary.
+  void SolveModule(const State &goal,
+                   const Eigen::Matrix<Scalar, 2, 1> &module_position,
+                   Scalar *module_theta, Scalar *module_omega) const {
+    // Below this squared module speed, module_omega is poorly conditioned and
+    // is reported as zero instead.
+    constexpr double kMinModuleVelocitySquared = 1e-5;
+    const Scalar vx = goal(States::kVx);
+    const Scalar vy = goal(States::kVy);
+    const Scalar omega = goal(States::kOmega);
+    // module_velocity_in_robot_frame = R(-theta) * robot_vel +
+    //     omega.cross(module_position);
+    // module_vel_x = (cos(-theta) * vx - sin(-theta) * vy) - omega * module_y
+    // module_vel_y = (sin(-theta) * vx + cos(-theta) * vy) + omega * module_x
+    // module_theta = atan2(module_vel_y, module_vel_x)
+    // module_omega = datan2(module_vel_y, module_vel_x) / dt
+    // datan2(y, x) / dt = (x * dy/dt - y * dx / dt) / (x^2 + y^2)
+    // robot accelerations are assumed to be zero.
+    // dmodule_vel_x / dt = (sin(-theta) * vx + cos(-theta) * vy) * omega
+    // dmodule_vel_y / dt = (-cos(-theta) * vx + sin(-theta) * vy) * omega
+    const Scalar ctheta = std::cos(-goal(States::kTheta));
+    const Scalar stheta = std::sin(-goal(States::kTheta));
+    const Scalar module_vel_x =
+        (ctheta * vx - stheta * vy) - omega * module_position.y();
+    const Scalar module_vel_y =
+        (stheta * vx + ctheta * vy) + omega * module_position.x();
+    const Scalar nominal_module_theta = std::atan2(module_vel_y, module_vel_x);
+    // If the current module theta is more than 90 deg from the desired theta,
+    // flip the desired theta by 180 deg.
+    if (std::abs(aos::math::DiffAngle(nominal_module_theta, *module_theta)) >
+        M_PI_2) {
+      *module_theta = aos::math::NormalizeAngle(nominal_module_theta + M_PI);
+    } else {
+      *module_theta = nominal_module_theta;
+    }
+    const Scalar module_accel_x = (stheta * vx + ctheta * vy) * omega;
+    const Scalar module_accel_y = (-ctheta * vx + stheta * vy) * omega;
+    const Scalar module_vel_norm_squared =
+        (module_vel_x * module_vel_x + module_vel_y * module_vel_y);
+    if (module_vel_norm_squared < kMinModuleVelocitySquared) {
+      // Prevent poor conditioning of module velocities at near-zero speeds.
+      *module_omega = 0.0;
+    } else {
+      *module_omega =
+          (module_vel_x * module_accel_y - module_vel_y * module_accel_x) /
+          module_vel_norm_squared;
+    }
+  }
+
+ private:
+  Parameters params_;
+};
+}  // namespace frc971::control_loops::swerve
+#endif  // FRC971_CONTROL_LOOPS_SWERVE_INVERSE_KINEMATICS_H_
diff --git a/frc971/control_loops/swerve/inverse_kinematics_test.cc b/frc971/control_loops/swerve/inverse_kinematics_test.cc
new file mode 100644
index 0000000..407d540
--- /dev/null
+++ b/frc971/control_loops/swerve/inverse_kinematics_test.cc
@@ -0,0 +1,150 @@
+#include "frc971/control_loops/swerve/inverse_kinematics.h"
+
+#include "gtest/gtest.h"
+
+namespace frc971::control_loops::swerve::testing {
+class InverseKinematicsTest : public ::testing::Test {
+ protected:
+  typedef double Scalar;
+  using State = InverseKinematics<Scalar>::State;
+  using States = InverseKinematics<Scalar>::States;
+  using ModuleParams = InverseKinematics<Scalar>::ModuleParams;
+  using Parameters = InverseKinematics<Scalar>::Parameters;
+  static ModuleParams MakeModule(const Eigen::Matrix<Scalar, 2, 1> &position) {
+    return ModuleParams{.position = position,
+                        .slip_angle_coefficient = 0.0,
+                        .slip_angle_alignment_coefficient = 0.0,
+                        .steer_motor = KrakenFOC(),
+                        .drive_motor = KrakenFOC(),
+                        .steer_ratio = 1.0,
+                        .drive_ratio = 1.0};
+  }
+  static Parameters MakeParams() {
+    return {.mass = 1.0,
+            .moment_of_inertia = 1.0,
+            .modules = {
+                MakeModule({1.0, 1.0}),
+                MakeModule({-1.0, 1.0}),
+                MakeModule({-1.0, -1.0}),
+                MakeModule({1.0, -1.0}),
+            }};
+  }
+
+  InverseKinematicsTest() : inverse_kinematics_(MakeParams()) {}
+
+  struct Goal {
+    Scalar vx;
+    Scalar vy;
+    Scalar omega;
+    Scalar theta;
+  };
+
+  void CheckState(
+      const Goal &goal, const std::array<Scalar, 4> &expected_thetas,
+      const std::optional<Eigen::Vector4d> &expected_omegas = std::nullopt) {
+    State goal_state = State::Zero();
+    goal_state(States::kVx) = goal.vx;
+    goal_state(States::kVy) = goal.vy;
+    goal_state(States::kOmega) = goal.omega;
+    goal_state(States::kTheta) = goal.theta;
+    SCOPED_TRACE(goal_state.bottomRows<4>().transpose());
+    const State nominal_state = inverse_kinematics_.Solve(goal_state);
+    // Numerically differentiate the solution with respect to time (the robot
+    // theta advances by omega * kDt) and check it against the module omegas.
+    const Scalar kDt = 1e-5;
+    const Scalar dtheta = kDt * goal.omega;
+    goal_state(States::kTheta) += dtheta / 2.0;
+    const State state_eps_pos = inverse_kinematics_.Solve(goal_state);
+    goal_state(States::kTheta) -= dtheta;
+    const State state_eps_neg = inverse_kinematics_.Solve(goal_state);
+    const State state_derivative = (state_eps_pos - state_eps_neg) / kDt;
+    for (size_t module_index = 0; module_index < 4; ++module_index) {
+      SCOPED_TRACE(module_index);
+      const int omega_idx = States::kOmegas0 + 2 * module_index;
+      const int theta_idx = States::kThetas0 + 2 * module_index;
+      EXPECT_NEAR(nominal_state(omega_idx), state_derivative(theta_idx), 1e-10);
+      EXPECT_NEAR(nominal_state(theta_idx), expected_thetas[module_index],
+                  1e-10);
+      if (expected_omegas.has_value()) {
+        EXPECT_NEAR(nominal_state(omega_idx),
+                    expected_omegas.value()(module_index), 1e-10);
+      }
+    }
+  }
+
+  InverseKinematics<Scalar> inverse_kinematics_;
+};
+
+// Tests that we get sane kinematics when driving straight with the robot at a
+// yaw of zero.
+TEST_F(InverseKinematicsTest, StraightDrivingNoYaw) {
+  // Sanity-check zero-speed operation.
+  CheckState({.vx = 0.0, .vy = 0.0, .omega = 0.0, .theta = 0.0},
+             {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+
+  CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = 0.0},
+             {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+  // Reverse should prefer to bias the modules towards [-pi/2, pi/2] due to
+  // hysteresis from starting the modules at thetas of 0.
+  CheckState({.vx = -1.0, .vy = 0.0, .omega = 0.0, .theta = 0.0},
+             {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+
+  CheckState({.vx = 0.0, .vy = 1.0, .omega = 0.0, .theta = 0.0},
+             {M_PI_2, M_PI_2, M_PI_2, M_PI_2}, Eigen::Vector4d::Zero());
+  // Hysteresis corner case: exactly 90 deg from the current value, so either
+  // direction would be acceptable; this pins the current (unflipped) result.
+  CheckState({.vx = 0.0, .vy = -1.0, .omega = 0.0, .theta = 0.0},
+             {-M_PI_2, -M_PI_2, -M_PI_2, -M_PI_2}, Eigen::Vector4d::Zero());
+
+  CheckState({.vx = 1.0, .vy = 1.0, .omega = 0.0, .theta = 0.0},
+             {M_PI_4, M_PI_4, M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+  // Reverse should prefer to bias the modules towards [-pi/2, pi/2] due to
+  // hysteresis from starting the modules at thetas of 0.
+  CheckState({.vx = -1.0, .vy = -1.0, .omega = 0.0, .theta = 0.0},
+             {M_PI_4, M_PI_4, M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+}
+
+// Tests that we get sane kinematics when driving straight with the robot at a
+// non-zero yaw; the module thetas are relative to the robot.
+TEST_F(InverseKinematicsTest, StraightDrivingYawed) {
+  CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = 0.1},
+             {-0.1, -0.1, -0.1, -0.1}, Eigen::Vector4d::Zero());
+
+  CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = M_PI_2},
+             {-M_PI_2, -M_PI_2, -M_PI_2, -M_PI_2}, Eigen::Vector4d::Zero());
+  CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = M_PI},
+             {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+
+  CheckState({.vx = 0.0, .vy = 1.0, .omega = 0.0, .theta = M_PI_2},
+             {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+  // Reverse should prefer to bias the modules towards [-pi/2, pi/2] due to
+  // hysteresis from starting the modules at thetas of 0.
+  CheckState({.vx = 0.0, .vy = -1.0, .omega = 0.0, .theta = M_PI_2},
+             {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+}
+
+// Tests that we can spin in place (modules tangent to their radius).
+TEST_F(InverseKinematicsTest, SpinInPlace) {
+  CheckState({.vx = 0.0, .vy = 0.0, .omega = 1.0, .theta = 0.0},
+             {-M_PI_4, M_PI_4, -M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+  // And changing the current robot theta should not matter.
+  CheckState({.vx = 0.0, .vy = 0.0, .omega = 1.0, .theta = 1.0},
+             {-M_PI_4, M_PI_4, -M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+}
+
+// Tests that we correctly calculate module yaw rates when spinning while
+// translating.
+TEST_F(InverseKinematicsTest, SpinWhileMoving) {
+  // Set up a situation where we are driving straight forwards, with a
+  // yaw rate of 1 rad / sec.
+  // The modules are all at radii of sqrt(2), so the contribution from both
+  // the translational and rotational velocities should be equal; in this case
+  // each module will have an angle that is an equal combination of straight
+  // forwards (0 deg) and some 45 deg offset.
+  CheckState({.vx = std::sqrt(static_cast<Scalar>(2)),
+              .vy = 0.0,
+              .omega = 1.0,
+              .theta = 0.0},
+             {3 * M_PI_4 / 2, -3 * M_PI_4 / 2, -M_PI_4 / 2, M_PI_4 / 2});
+}
+}  // namespace frc971::control_loops::swerve::testing
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 19a48de..58d5fcf 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -354,6 +354,7 @@
     ])
 
 
+@jax.jit
 def mpc_cost(coefficients: CoefficientsType, X, U, goal):
     J = 0
 
diff --git a/frc971/control_loops/swerve/linearization_utils.h b/frc971/control_loops/swerve/linearization_utils.h
new file mode 100644
index 0000000..ee1275e
--- /dev/null
+++ b/frc971/control_loops/swerve/linearization_utils.h
@@ -0,0 +1,13 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_LINEARIZATION_UTILS_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_LINEARIZATION_UTILS_H_
+
+namespace frc971::control_loops::swerve {
+// Abstract system model: implementations compute the state derivative f(X, U).
+template <typename State, typename Input>
+struct DynamicsInterface {
+  virtual ~DynamicsInterface() = default;
+  // Returns the time derivative of the state X under the control input U.
+  virtual State operator()(const State &X, const Input &U) const = 0;
+};
+}  // namespace frc971::control_loops::swerve
+#endif  // FRC971_CONTROL_LOOPS_SWERVE_LINEARIZATION_UTILS_H_
diff --git a/frc971/control_loops/swerve/linearized_controller.h b/frc971/control_loops/swerve/linearized_controller.h
new file mode 100644
index 0000000..1df9640
--- /dev/null
+++ b/frc971/control_loops/swerve/linearized_controller.h
@@ -0,0 +1,146 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_LINEARIZED_CONTROLLER_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_LINEARIZED_CONTROLLER_H_
+
+#include <chrono>
+#include <memory>
+#include <utility>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include <Eigen/Dense>
+
+#include "frc971/control_loops/c2d.h"
+#include "frc971/control_loops/dlqr.h"
+#include "frc971/control_loops/jacobian.h"
+#include "frc971/control_loops/swerve/dynamics.h"
+#include "frc971/control_loops/swerve/linearization_utils.h"
+
+namespace frc971::control_loops::swerve {
+
+// Provides a simple LQR controller that takes a non-linear system, linearizes
+// the dynamics at each timepoint, recalculates the LQR gains for those
+// dynamics, and calculates the relevant feedback inputs to provide.
+//
+// NStates is the size of the state vector; the input vector always has
+// kNumInputs elements.
+template <int NStates, typename Scalar = double>
+class LinearizedController {
+ public:
+  typedef Eigen::Matrix<Scalar, NStates, 1> State;
+  typedef Eigen::Matrix<Scalar, NStates, NStates> StateSquare;
+  typedef Eigen::Matrix<Scalar, kNumInputs, 1> Input;
+  typedef Eigen::Matrix<Scalar, kNumInputs, kNumInputs> InputSquare;
+  typedef Eigen::Matrix<Scalar, NStates, kNumInputs> BMatrix;
+  typedef DynamicsInterface<State, Input> Dynamics;
+
+  struct Parameters {
+    // State cost matrix.
+    StateSquare Q;
+    // Input cost matrix.
+    InputSquare R;
+    // Period at which the controller is called; the linearized dynamics are
+    // discretized at this period.
+    std::chrono::nanoseconds dt;
+    // The dynamics to use. Must be non-null.
+    // TODO(james): I wrote this before creating the auto-differentiation
+    // functions; we should swap to the auto-differentiation, since the
+    // numerical linearization is one of the bigger timesinks in this controller
+    // right now.
+    std::unique_ptr<Dynamics> dynamics;
+  };
+
+  // Represents the linearized dynamics of the system: A is the state matrix
+  // and B is the input matrix.
+  struct LinearDynamics {
+    StateSquare A;
+    BMatrix B;
+  };
+
+  // Debug information for a given cycle of the controller.
+  struct ControllerDebug {
+    // Feedforward input which we provided.
+    Input U_ff;
+    // Calculated feedback input to provide.
+    Input U_feedback;
+    // Per-state breakdown of U_feedback: column i is the feedback due to the
+    // error in state i, so the columns sum to U_feedback (up to rounding).
+    Eigen::Matrix<Scalar, kNumInputs, NStates> feedback_contributions;
+  };
+
+  struct ControllerResult {
+    // Control input to provide to the robot.
+    Input U;
+    ControllerDebug debug;
+  };
+
+  LinearizedController(Parameters params) : params_(std::move(params)) {
+    // Fail fast here rather than on the first (numerical) linearization.
+    CHECK(params_.dynamics != nullptr)
+        << "LinearizedController requires a dynamics model.";
+  }
+
+  // Runs the controller for a given iteration, relinearizing the dynamics about
+  // the provided current state X, attempting to control the robot to the
+  // desired goal state.
+  // Returns U = U_ff + K * (goal - X), where K is the LQR gain for the
+  // dynamics linearized about (X, U_ff) and discretized at params.dt.
+  ControllerResult RunController(const State &X, const State &goal,
+                                 Input U_ff) const {
+    auto start_time = aos::monotonic_clock::now();
+    // TODO(james): Swap this to the auto-diff methods; this is currently about
+    // a third of the total time spent in this method when run on the roborio.
+    const LinearDynamics continuous_dynamics = LinearizeDynamics(X, U_ff);
+    auto linearization_time = aos::monotonic_clock::now();
+    LinearDynamics discrete_dynamics;
+    frc971::controls::C2D(continuous_dynamics.A, continuous_dynamics.B,
+                          params_.dt, &discrete_dynamics.A,
+                          &discrete_dynamics.B);
+    auto c2d_time = aos::monotonic_clock::now();
+    VLOG(2) << "Controllability of dynamics (ideally should be " << NStates
+            << "): "
+            << frc971::controls::Controllability(discrete_dynamics.A,
+                                                 discrete_dynamics.B);
+    Eigen::Matrix<Scalar, kNumInputs, NStates> K;
+    StateSquare S;
+    // TODO(james): Swap this to a cheaper DARE solver; we should probably just
+    // do something like we do in Trajectory::CalculatePathGains for the tank
+    // spline controller where we approximate the infinite-horizon DARE solution
+    // by doing a finite-horizon LQR.
+    // Currently the dlqr call represents ~60% of the time spent in the
+    // RunController() method.
+    frc971::controls::dlqr(discrete_dynamics.A, discrete_dynamics.B, params_.Q,
+                           params_.R, &K, &S);
+    auto dlqr_time = aos::monotonic_clock::now();
+    // The state error feeds both the total feedback and its per-state
+    // breakdown, so compute it once.
+    const State error = goal - X;
+    const Input U_feedback = K * error;
+    const Input U = U_ff + U_feedback;
+    // Scaling column i of K by error(i) is the same as K * diag(error).
+    const Eigen::Matrix<Scalar, kNumInputs, NStates> feedback_contributions =
+        K * error.asDiagonal();
+    VLOG(2) << "linearization time "
+            << aos::time::DurationInSeconds(linearization_time - start_time)
+            << " c2d time "
+            << aos::time::DurationInSeconds(c2d_time - linearization_time)
+            << " dlqr time "
+            << aos::time::DurationInSeconds(dlqr_time - c2d_time);
+    return {.U = U,
+            .debug = {.U_ff = U_ff,
+                      .U_feedback = U_feedback,
+                      .feedback_contributions = feedback_contributions}};
+  }
+
+  // Numerically linearizes params.dynamics about (X, U).
+  LinearDynamics LinearizeDynamics(const State &X, const Input &U) const {
+    return {.A = NumericalJacobianX(*params_.dynamics, X, U),
+            .B = NumericalJacobianU(*params_.dynamics, X, U)};
+  }
+
+ private:
+  const Parameters params_;
+};
+
+}  // namespace frc971::control_loops::swerve
+
+#endif  // FRC971_CONTROL_LOOPS_SWERVE_LINEARIZED_CONTROLLER_H_
diff --git a/frc971/control_loops/swerve/linearized_controller_test.cc b/frc971/control_loops/swerve/linearized_controller_test.cc
new file mode 100644
index 0000000..d5e7aee
--- /dev/null
+++ b/frc971/control_loops/swerve/linearized_controller_test.cc
@@ -0,0 +1,95 @@
+#include "frc971/control_loops/swerve/linearized_controller.h"
+
+#include "gtest/gtest.h"
+
+namespace frc971::control_loops::swerve::test {
+
+// Exercises LinearizedController on a simple, already-linear two-state plant
+// so that the linearization and the resulting gains are easy to reason about.
+class LinearizedControllerTest : public ::testing::Test {
+ protected:
+  using Controller = LinearizedController<2>;
+  using State = Controller::State;
+  using Input = Controller::Input;
+
+  // xdot = [0 1; 0 -0.01] * X + B * U, where B maps only the first input
+  // onto the second state (a lightly damped double integrator).
+  struct LinearDynamics : public Controller::Dynamics {
+    State operator()(const State &X, const Input &U) const override {
+      const Eigen::Matrix2d A{{0.0, 1.0}, {0.0, -0.01}};
+      Eigen::Matrix<double, 2, kNumInputs> B;
+      B.setZero();
+      B(1, 0) = 1.0;
+      return A * X + B * U;
+    }
+  };
+
+  // Unit cost on every input.
+  static Eigen::Matrix<double, kNumInputs, kNumInputs> MakeR() {
+    return Eigen::Matrix<double, kNumInputs, kNumInputs>::Identity();
+  }
+
+  LinearizedControllerTest()
+      : controller_({.Q = Eigen::Matrix2d::Identity(),
+                     .R = MakeR(),
+                     .dt = std::chrono::milliseconds(10),
+                     .dynamics = std::make_unique<LinearDynamics>()}) {}
+
+  Controller controller_;
+};
+
+// Sanity check that the dynamics linearization is working correctly.
+TEST_F(LinearizedControllerTest, LinearizedDynamics) {
+  const auto linearized =
+      controller_.LinearizeDynamics(State::Zero(), Input::Zero());
+  // A should reproduce the plant's state matrix.
+  EXPECT_EQ(0.0, linearized.A(0, 0));
+  EXPECT_EQ(0.0, linearized.A(1, 0));
+  EXPECT_EQ(1.0, linearized.A(0, 1));
+  EXPECT_EQ(-0.01, linearized.A(1, 1));
+  // B(1, 0) should be 1 and every other element of B exactly 0, which
+  // together imply a norm of exactly 1.
+  EXPECT_EQ(1.0, linearized.B(1, 0));
+  EXPECT_EQ(1.0, linearized.B.norm());
+}
+
+// Confirm that the generated LQR controller produces no input at all when the
+// state, the goal, and the feedforward are all zero.
+TEST_F(LinearizedControllerTest, ControllerResultAtZero) {
+  const auto result =
+      controller_.RunController(State::Zero(), State::Zero(), Input::Zero());
+  EXPECT_EQ(0.0, result.U.norm());
+  EXPECT_EQ(0.0, result.debug.U_ff.norm());
+  EXPECT_EQ(0.0, result.debug.U_feedback.norm());
+}
+
+// Confirm that the generated LQR controller is able to generate correct
+// inputs when state is zero and the goal is non-zero.
+TEST_F(LinearizedControllerTest, ControlToZero) {
+  const State goal(1.0, 0.0);
+  const auto result =
+      controller_.RunController(State::Zero(), goal, Input::Zero());
+  // Only the first input drives the plant, so it is the only non-zero one.
+  EXPECT_LT(0.0, result.U(0, 0));
+  EXPECT_EQ(0.0, result.U.bottomRows<7>().norm());
+  EXPECT_EQ(0.0, result.debug.U_ff.norm());
+  // The total input must equal the sum of its reported components.
+  EXPECT_EQ(0.0,
+            (result.U - (result.debug.U_ff + result.debug.U_feedback)).norm());
+}
+
+// Confirm that the generated LQR controller is able to pass through the
+// feedforwards when we have no difference between the goal and the current
+// state.
+TEST_F(LinearizedControllerTest, ControlToNonzeroState) {
+  const State state(1.0, 1.0);
+  const Input feedforward = Input::Unit(0);
+  const auto result = controller_.RunController(state, state, feedforward);
+  EXPECT_EQ(1.0, result.U(0, 0));
+  // All other U inputs should be zero, as should the feedback.
+  EXPECT_EQ(0.0, result.U.bottomRows<7>().norm());
+  EXPECT_EQ(0.0, result.debug.U_feedback.norm());
+  EXPECT_EQ(0.0,
+            (result.U - (result.debug.U_ff + result.debug.U_feedback)).norm());
+}
+}  // namespace frc971::control_loops::swerve::test
diff --git a/frc971/control_loops/swerve/multi_experience_collector.py b/frc971/control_loops/swerve/multi_experience_collector.py
new file mode 100644
index 0000000..d0551e1
--- /dev/null
+++ b/frc971/control_loops/swerve/multi_experience_collector.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+"""Runs several experience_collector processes in parallel.
+
+Each actor gets its own seed and writes its output and log into its own
+subdirectory of --outdir.
+"""
+from absl import app
+from absl import flags
+import sys
+from multiprocessing.pool import ThreadPool
+import pathlib
+import subprocess
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('outdir', '/tmp/swerve', "Directory to write results to.")
+flags.DEFINE_integer('num_actors',
+                     20,
+                     'Number of actors to run in parallel.',
+                     lower_bound=1)
+flags.DEFINE_integer('num_solutions', 100,
+                     'Number of random problems to solve.')
+
+# Collector script, relative to the current working directory.
+_COLLECTOR = "frc971/control_loops/swerve/experience_collector"
+
+
+def collect_experience(agent_number):
+    """Runs one experience_collector seeded with agent_number.
+
+    Results go to --outdir/<agent_number> (or <agent_number> relative to the
+    working directory when --outdir is empty); the collector's stdout and
+    stderr are written to a `log` file in that directory.
+
+    Raises:
+        subprocess.CalledProcessError: if the collector exits non-zero.
+    """
+    filename = f'{agent_number}'
+    if FLAGS.outdir:
+        subdir = pathlib.Path(FLAGS.outdir) / filename
+    else:
+        subdir = pathlib.Path(filename)
+    subdir.mkdir(parents=True, exist_ok=True)
+
+    with open(f'{subdir.resolve()}/log', 'w') as output:
+        subprocess.check_call(
+            args=[
+                sys.executable,
+                _COLLECTOR,
+                f"--seed={agent_number}",
+                f"--outputdir={subdir.resolve()}",
+                "--quiet",
+                f"--num_solutions={FLAGS.num_solutions}",
+            ],
+            stdout=output,
+            stderr=output,
+        )
+
+
+def main(argv):
+    del argv  # Unused.
+    # Load a simple problem first so we compile with less system load.  This
+    # makes it faster on a processor with frequency boosting.
+    subprocess.check_call(args=[
+        sys.executable,
+        _COLLECTOR,
+        "--compileonly",
+    ])
+
+    # Try a bunch of goals now, running one collector per actor in parallel.
+    with ThreadPool(FLAGS.num_actors) as pool:
+        pool.map(collect_experience, range(FLAGS.num_actors))
+
+
+if __name__ == '__main__':
+    app.run(main)
diff --git a/frc971/control_loops/swerve/simplified_dynamics.h b/frc971/control_loops/swerve/simplified_dynamics.h
new file mode 100644
index 0000000..9afe636
--- /dev/null
+++ b/frc971/control_loops/swerve/simplified_dynamics.h
@@ -0,0 +1,372 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_SIMPLIFIED_DYNAMICS_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_SIMPLIFIED_DYNAMICS_H_
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+
+#include "aos/util/math.h"
+#include "frc971/control_loops/swerve/auto_diff_jacobian.h"
+#include "frc971/control_loops/swerve/dynamics.h"
+#include "frc971/control_loops/swerve/motors.h"
+
+namespace frc971::control_loops::swerve {
+
+// Provides a simplified set of physics representing a swerve drivetrain.
+// Broadly speaking, these dynamics model:
+// * Standard motors on the drive and steer axes, with no coupling between
+//   the motors.
+// * Assume that the steer direction of each module is only influenced by
+//   the inertia of the motor rotor plus some small aligning force.
+// * Assume that torque from the drive motor is transferred directly to the
+//   carpet.
+// * Assume that a lateral force on the wheel is generated proportional to the
+//   slip angle of the wheel.
+//
+// This class is templated on a Scalar that is used to determine whether you
+// want to use double or single-precision floats for the calculations here.
+//
+// Several individual methods on this class are also templated on a LocalScalar
+// type. This is provided to allow those methods to be called with ceres Jets to
+// do autodifferentiation within various solvers/jacobian calculators.
+template <typename Scalar = double>
+class SimplifiedDynamics {
+ public:
+  struct ModuleParams {
+    // Module position relative to the center of mass of the robot.
+    Eigen::Matrix<Scalar, 2, 1> position;
+    // Coefficient dictating how much sideways force is generated at a given
+    // slip angle. Units are effectively Newtons / radian of slip.
+    Scalar slip_angle_coefficient;
+    // Coefficient dictating how much the steer wheel is forced into alignment
+    // by the motion of the wheel over the ground (i.e., we are assuming that
+    // if you push the robot along it will cause the wheels to eventually
+    // align with the direction of motion).
+    // In radians / sec^2 / radians of slip angle.
+    Scalar slip_angle_alignment_coefficient;
+    // Parameters for the steer and drive motors.
+    Motor steer_motor;
+    Motor drive_motor;
+    // radians of module steering = steer_ratio * radians of motor shaft
+    Scalar steer_ratio;
+    // meters of driving = drive_ratio * radians of motor shaft
+    Scalar drive_ratio;
+  };
+  struct Parameters {
+    // Mass of the robot, in kg.
+    Scalar mass;
+    // Moment of inertia of the robot about the yaw axis, in kg * m^2.
+    Scalar moment_of_inertia;
+    // Note: While this technically would support an arbitrary number of
+    // modules, the statically-sized state vectors do limit us to 4 modules
+    // currently, and it should not be counted on that other pieces of code will
+    // be able to support non-4-module swerves.
+    std::vector<ModuleParams> modules;
+  };
+  enum States {
+    // Thetas* and Omegas* are the yaw and yaw rate of the indicated modules.
+    // (note that we do not actually need to track drive speed per module,
+    // as with current control we have the ability to directly command torque
+    // to those motors; however, if we wished to account for the saturation
+    // limits on the motor, then we would need to have access to those states,
+    // although they can be fully derived from the robot vx, vy, theta, and
+    // omega).
+    kThetas0 = 0,
+    kOmegas0 = 1,
+    kThetas1 = 2,
+    kOmegas1 = 3,
+    kThetas2 = 4,
+    kOmegas2 = 5,
+    kThetas3 = 6,
+    kOmegas3 = 7,
+    // Robot yaw, in radians.
+    kTheta = 8,
+    // Robot speed in the global frame, in meters / sec.
+    kVx = 9,
+    kVy = 10,
+    // Robot yaw rate, in radians / sec.
+    kOmega = 11,
+    kNumVelocityStates = 12,
+    // Augmented states for doing position control.
+    // Robot X position in the global frame.
+    kX = 12,
+    // Robot Y position in the global frame.
+    kY = 13,
+    kNumPositionStates = 14,
+  };
+  using Inputs = InputStates::States;
+
+  template <typename ScalarT = Scalar>
+  using VelocityState = Eigen::Matrix<ScalarT, kNumVelocityStates, 1>;
+  template <typename ScalarT = Scalar>
+  using PositionState = Eigen::Matrix<ScalarT, kNumPositionStates, 1>;
+  template <typename ScalarT = Scalar>
+  using VelocityStateSquare =
+      Eigen::Matrix<ScalarT, kNumVelocityStates, kNumVelocityStates>;
+  template <typename ScalarT = Scalar>
+  using PositionStateSquare =
+      Eigen::Matrix<ScalarT, kNumPositionStates, kNumPositionStates>;
+  template <typename ScalarT = Scalar>
+  using PositionBMatrix =
+      Eigen::Matrix<ScalarT, kNumPositionStates, kNumInputs>;
+  template <typename ScalarT = Scalar>
+  using Input = Eigen::Matrix<ScalarT, kNumInputs, 1>;
+
+  SimplifiedDynamics(const Parameters &params) : params_(params) {
+    for (size_t module_index = 0; module_index < params_.modules.size();
+         ++module_index) {
+      module_dynamics_.emplace_back(params_, module_index);
+    }
+  }
+
+  // Returns the derivative of state for the given state and input.
+  template <typename LocalScalar>
+  PositionState<LocalScalar> Dynamics(const PositionState<LocalScalar> &state,
+                                      const Input<LocalScalar> &input) const {
+    PositionState<LocalScalar> Xdot = PositionState<LocalScalar>::Zero();
+
+    for (const ModuleDynamics &module : module_dynamics_) {
+      Xdot += module.PartialDynamics(state, input);
+    }
+
+    // And finally catch the global states:
+    Xdot(kX) = state(kVx);
+    Xdot(kY) = state(kVy);
+    Xdot(kTheta) = state(kOmega);
+
+    return Xdot;
+  }
+
+  template <typename LocalScalar>
+  VelocityState<LocalScalar> VelocityDynamics(
+      const VelocityState<LocalScalar> &state,
+      const Input<LocalScalar> &input) const {
+    PositionState<LocalScalar> input_state = PositionState<LocalScalar>::Zero();
+    input_state.template topRows<kNumVelocityStates>() = state;
+    return Dynamics(input_state, input).template topRows<kNumVelocityStates>();
+  }
+
+  std::pair<PositionStateSquare<>, PositionBMatrix<>> LinearizedDynamics(
+      const PositionState<> &state, const Input<> &input) {
+    DynamicsFunctor functor(*this);
+    Eigen::Matrix<Scalar, kNumPositionStates + kNumInputs, 1> parameters;
+    parameters.template topRows<kNumPositionStates>() = state;
+    parameters.template bottomRows<kNumInputs>() = input;
+    const Eigen::Matrix<Scalar, kNumPositionStates,
+                        kNumPositionStates + kNumInputs>
+        jacobian =
+            AutoDiffJacobian<Scalar, DynamicsFunctor,
+                             kNumPositionStates + kNumInputs,
+                             kNumPositionStates>::Jacobian(functor, parameters);
+    return {
+        jacobian.template block<kNumPositionStates, kNumPositionStates>(0, 0),
+        jacobian.template block<kNumPositionStates, kNumInputs>(
+            0, kNumPositionStates)};
+  }
+
+ private:
+  // Wrapper to provide an operator() for the dynamics class that allows it to
+  // be used by the auto-differentiation code.
+  class DynamicsFunctor {
+   public:
+    DynamicsFunctor(const SimplifiedDynamics &dynamics) : dynamics_(dynamics) {}
+
+    template <typename LocalScalar>
+    Eigen::Matrix<LocalScalar, kNumPositionStates, 1> operator()(
+        const Eigen::Map<const Eigen::Matrix<
+            LocalScalar, kNumPositionStates + kNumInputs, 1>>
+            input) const {
+      return dynamics_.Dynamics(
+          PositionState<LocalScalar>(
+              input.template topRows<kNumPositionStates>()),
+          Input<LocalScalar>(input.template bottomRows<kNumInputs>()));
+    }
+
+   private:
+    const SimplifiedDynamics &dynamics_;
+  };
+
+  // Represents the dynamics of an individual module.
+  class ModuleDynamics {
+   public:
+    ModuleDynamics(const Parameters &robot_params, const size_t module_index)
+        : robot_params_(robot_params), module_index_(module_index) {
+      CHECK_LT(module_index_, robot_params_.modules.size());
+    }
+
+    // This returns the portions of the derivative of state that are due to the
+    // individual module. The result from this function should be able to be
+    // naively summed with the dynamics for each other module plus some global
+    // dynamics (which take care of e.g. xdot = vx) and give you the
+    // overall dynamics of the system.
+    template <typename LocalScalar>
+    PositionState<LocalScalar> PartialDynamics(
+        const PositionState<LocalScalar> &state,
+        const Input<LocalScalar> &input) const {
+      PositionState<LocalScalar> Xdot = PositionState<LocalScalar>::Zero();
+
+      Xdot(ThetasIdx()) = state(OmegasIdx());
+
+      // Steering dynamics for an individual module assume ~zero friction,
+      // and thus ~the only inertia is from the motor rotor itself.
+      // torque_motor = stall_torque / stall_current * current
+      // accel_motor = torque_motor / motor_inertia
+      // accel_steer = accel_motor * steer_ratio
+      const Motor &steer_motor = module_params().steer_motor;
+      const LocalScalar steer_motor_accel =
+          input(IsIdx()) *
+          static_cast<Scalar>(
+              module_params().steer_ratio * steer_motor.stall_torque /
+              (steer_motor.stall_current * steer_motor.motor_inertia));
+
+      // For the impacts of the modules on the overall robot
+      // dynamics (X, Y, and theta acceleration), we calculate the forces
+      // generated by the module and then apply them. These forces come from
+      // two effects in this model:
+      // 1. Slip angle of the module (dependent on the current robot velocity &
+      //    module steer angle).
+      // 2. Drive torque from the module (dependent on the current drive
+      //    current and module steer angle).
+      // We assume no torque is generated from e.g. the wheel resisting the
+      // steering motion.
+      //
+      // clang-format off
+       //
+       // For slip angle we have:
+       // wheel_velocity = R(-theta - theta_steer) * (
+       //    robot_vel + omega.cross(R(theta) * module_position))
+       // slip_angle = -atan2(wheel_velocity)
+       // slip_force = slip_angle_coefficient * slip_angle
+       // slip_force_direction = theta + theta_steer + pi / 2
+       // force_x = slip_force * cos(slip_force_direction)
+       // force_y = slip_force * sin(slip_force_direction)
+       // accel_* = force_* / mass
+       // # And now calculate torque from slip angle.
+       // torque_vec = module_position.cross([slip_force * cos(theta_steer + pi / 2),
+       //                                     slip_force * sin(theta_steer + pi / 2),
+       //                                     0.0])
+       // torque_vec = module_position.cross([slip_force * -sin(theta_steer),
+       //                                     slip_force * cos(theta_steer),
+       //                                     0.0])
+       // robot_torque = torque_vec.z()
+       //
+       // For drive torque we have:
+       // drive_force = (drive_current * stall_torque / stall_current) / drive_ratio
+       // drive_force_direction = theta + theta_steer
+       // force_x = drive_force * cos(drive_force_direction)
+       // force_y = drive_force * sin(drive_force_direction)
+       // torque_vec = drive_force * module_position.cross([cos(theta_steer),
+       //                                                   sin(theta_steer),
+       //                                                   0.0])
+       // torque = torque_vec.z()
+       //
+      // clang-format on
+
+      const Eigen::Matrix<Scalar, 3, 1> module_position{
+          {module_params().position.x()},
+          {module_params().position.y()},
+          {0.0}};
+      const LocalScalar theta = state(kTheta);
+      const LocalScalar theta_steer = state(ThetasIdx());
+      const Eigen::Matrix<LocalScalar, 3, 1> wheel_velocity_in_global_frame =
+          Eigen::Matrix<LocalScalar, 3, 1>(state(kVx), state(kVy),
+                                           static_cast<LocalScalar>(0.0)) +
+          (Eigen::Matrix<LocalScalar, 3, 1>(static_cast<LocalScalar>(0.0),
+                                            static_cast<LocalScalar>(0.0),
+                                            state(kOmega))
+               .cross(Eigen::AngleAxis<LocalScalar>(
+                          theta, Eigen::Matrix<LocalScalar, 3, 1>::UnitZ()) *
+                      module_position));
+      const Eigen::Matrix<LocalScalar, 3, 1> wheel_velocity_in_wheel_frame =
+          Eigen::AngleAxis<LocalScalar>(
+              -theta - theta_steer, Eigen::Matrix<LocalScalar, 3, 1>::UnitZ()) *
+          wheel_velocity_in_global_frame;
+      // The complicated dynamics use some obnoxious-looking functions to
+      // try to approximate how the slip angle behaves at low speeds to better
+      // condition the dynamics. Because I couldn't be bothered to copy those
+      // dynamics, instead just bias the slip angle to zero at low speeds.
+      const LocalScalar wheel_speed = wheel_velocity_in_wheel_frame.norm();
+      const Scalar start_speed = 0.1;
+      const LocalScalar heading_truth_proportion = -expm1(
+          /*arbitrary large number=*/static_cast<Scalar>(-100.0) *
+          (wheel_speed - start_speed));
+      const LocalScalar wheel_heading =
+          (wheel_speed < start_speed)
+              ? static_cast<LocalScalar>(0.0)
+              : heading_truth_proportion *
+                    atan2(wheel_velocity_in_wheel_frame.y(),
+                          wheel_velocity_in_wheel_frame.x());
+
+      // We wrap slip_angle with a sin() not because there is actually a sin()
+      // in the real math but rather because we need to smoothly and correctly
+      // handle slip angles between pi / 2 and 3 * pi / 2.
+      const LocalScalar slip_angle = sin(-wheel_heading);
+      const LocalScalar slip_force =
+          module_params().slip_angle_coefficient * slip_angle;
+      const LocalScalar slip_force_direction =
+          theta + theta_steer + static_cast<Scalar>(M_PI_2);
+      const Eigen::Matrix<LocalScalar, 3, 1> slip_force_vec =
+          slip_force * UnitYawVector<LocalScalar>(slip_force_direction);
+      const LocalScalar slip_torque =
+          module_position
+              .cross(slip_force *
+                     UnitYawVector<LocalScalar>(theta_steer +
+                                                static_cast<Scalar>(M_PI_2)))
+              .z();
+
+      // drive torque calculations
+      const Motor &drive_motor = module_params().drive_motor;
+      const LocalScalar drive_force =
+          input(IdIdx()) * static_cast<Scalar>(drive_motor.stall_torque /
+                                               drive_motor.stall_current /
+                                               module_params().drive_ratio);
+      const Eigen::Matrix<LocalScalar, 3, 1> drive_force_vec =
+          drive_force * UnitYawVector<LocalScalar>(theta + theta_steer);
+      const LocalScalar drive_torque =
+          drive_force *
+          module_position.cross(UnitYawVector<LocalScalar>(theta_steer)).z();
+      // We add in an aligning force on the wheels primarily to help provide a
+      // bit of impetus to the controllers/solvers to discourage aggressive
+      // slip angles. If we do not include this, then the dynamics make it look
+      // like there are no losses to using extremely aggressive slip angles.
+      const LocalScalar wheel_alignment_accel =
+          -module_params().slip_angle_alignment_coefficient * slip_angle;
+
+      Xdot(OmegasIdx()) = steer_motor_accel + wheel_alignment_accel;
+      // Sum up all the forces.
+      Xdot(kVx) =
+          (slip_force_vec.x() + drive_force_vec.x()) / robot_params_.mass;
+      Xdot(kVy) =
+          (slip_force_vec.y() + drive_force_vec.y()) / robot_params_.mass;
+      Xdot(kOmega) =
+          (slip_torque + drive_torque) / robot_params_.moment_of_inertia;
+
+      return Xdot;
+    }
+
+   private:
+    template <typename LocalScalar>
+    Eigen::Matrix<LocalScalar, 3, 1> UnitYawVector(LocalScalar yaw) const {
+      const LocalScalar along_x = static_cast<LocalScalar>(cos(yaw));
+      const LocalScalar along_y = static_cast<LocalScalar>(sin(yaw));
+      const LocalScalar along_z = static_cast<LocalScalar>(0.0);  // x-y plane.
+      return Eigen::Matrix<LocalScalar, 3, 1>(along_x, along_y, along_z);
+    }
+    size_t ThetasIdx() const { return 2 * module_index_ + kThetas0; }
+    size_t OmegasIdx() const { return 2 * module_index_ + kOmegas0; }
+    size_t IsIdx() const { return 2 * module_index_ + Inputs::kIs0; }
+    size_t IdIdx() const { return 2 * module_index_ + Inputs::kId0; }
+
+    const ModuleParams &module_params() const {
+      return this->robot_params_.modules[this->module_index_];
+    }
+
+    const Parameters robot_params_;
+
+    const size_t module_index_;
+  };
+
+  Parameters params_;
+  std::vector<ModuleDynamics> module_dynamics_;
+};
+
+}  // namespace frc971::control_loops::swerve
+#endif  // FRC971_CONTROL_LOOPS_SWERVE_SIMPLIFIED_DYNAMICS_H_
diff --git a/frc971/control_loops/swerve/simplified_dynamics_test.cc b/frc971/control_loops/swerve/simplified_dynamics_test.cc
new file mode 100644
index 0000000..894a85b
--- /dev/null
+++ b/frc971/control_loops/swerve/simplified_dynamics_test.cc
@@ -0,0 +1,244 @@
+#include "frc971/control_loops/swerve/simplified_dynamics.h"
+
+#include <functional>
+
+#include "absl/log/log.h"
+#include "gtest/gtest.h"
+
+#include "aos/time/time.h"
+#include "frc971/control_loops/jacobian.h"
+
+namespace frc971::control_loops::swerve::testing {
+class SimplifiedDynamicsTest : public ::testing::Test {
+ protected:
+  using Dynamics = SimplifiedDynamics<double>;
+  using PositionState = Dynamics::PositionState<double>;
+  using States = Dynamics::States;
+  using Inputs = Dynamics::Inputs;
+  using Input = Dynamics::Input<double>;
+  using ModuleParams = Dynamics::ModuleParams;
+  using Parameters = Dynamics::Parameters;
+  // Module at `position`; wheel_alignment toggles the slip-alignment term.
+  static ModuleParams MakeModule(const Eigen::Vector2d &position,
+                                 bool wheel_alignment) {
+    return ModuleParams{
+        .position = position,
+        .slip_angle_coefficient = 200.0,
+        .slip_angle_alignment_coefficient = wheel_alignment ? 1.0 : 0.0,
+        .steer_motor = KrakenFOC(),
+        .drive_motor = KrakenFOC(),
+        .steer_ratio = 0.1,
+        .drive_ratio = 0.01};
+  }
+  // Square robot with one identical module at each corner, (+/-1, +/-1).
+  static Parameters MakeParams(bool wheel_alignment) {
+    return {.mass = 60,
+            .moment_of_inertia = 2,
+            .modules = {
+                MakeModule({1.0, 1.0}, wheel_alignment),
+                MakeModule({-1.0, 1.0}, wheel_alignment),
+                MakeModule({-1.0, -1.0}, wheel_alignment),
+                MakeModule({1.0, -1.0}, wheel_alignment),
+            }};
+  }
+  SimplifiedDynamicsTest() : dynamics_(MakeParams(false)) {}
+
+  // Returns Xdot at (state, input) after checking invariants and Jacobians.
+  PositionState ValidateDynamics(const PositionState &state,
+                                 const Input &input) {
+    const PositionState Xdot = dynamics_.Dynamics(state, input);
+    // Sanity check simple invariants:
+    EXPECT_EQ(Xdot(Dynamics::kX), state(Dynamics::kVx));
+    EXPECT_EQ(Xdot(Dynamics::kY), state(Dynamics::kVy));
+    EXPECT_EQ(Xdot(Dynamics::kTheta), state(Dynamics::kOmega));
+
+    // Check that the dynamics linearization produces numbers that match numeric
+    // differentiation of the dynamics.
+    aos::monotonic_clock::time_point start_time = aos::monotonic_clock::now();
+    const auto linearized_dynamics = dynamics_.LinearizedDynamics(state, input);
+    const aos::monotonic_clock::duration auto_diff_time =
+        aos::monotonic_clock::now() - start_time;
+
+    const auto dynamics_fn = [this](const PositionState &X, const Input &U) {
+      return dynamics_.Dynamics(X, U);
+    };
+    start_time = aos::monotonic_clock::now();
+    const auto numerical_A = NumericalJacobianX(dynamics_fn, state, input);
+    const auto numerical_B = NumericalJacobianU(dynamics_fn, state, input);
+    const aos::monotonic_clock::duration numerical_time =
+        aos::monotonic_clock::now() - start_time;
+    VLOG(1) << "Autodifferentiation took " << auto_diff_time
+            << " while numerical differentiation took " << numerical_time;
+    EXPECT_LT((numerical_A - linearized_dynamics.first).norm(), 1e-6)
+        << "Numerical result:\n"
+        << numerical_A << "\nAuto-diff result:\n"
+        << linearized_dynamics.first;
+    EXPECT_LT((numerical_B - linearized_dynamics.second).norm(), 1e-6)
+        << "Numerical result:\n"
+        << numerical_B << "\nAuto-diff result:\n"
+        << linearized_dynamics.second;
+    return Xdot;
+  }
+
+  Dynamics dynamics_;
+};
+
+// With every state and input at zero, the robot should not move at all.
+TEST_F(SimplifiedDynamicsTest, ZeroIsZero) {
+  const PositionState zero = PositionState::Zero();
+  EXPECT_EQ(zero, ValidateDynamics(zero, Input::Zero()));
+}
+
+// Driving straight forwards with zero inputs should just coast: the only
+// non-zero derivative is x-dot, which must equal the forward velocity.
+TEST_F(SimplifiedDynamicsTest, CoastForwards) {
+  PositionState initial = PositionState::Zero();
+  PositionState expected_xdot = PositionState::Zero();
+  // Moving at 1 in +x means x changes at that same rate.
+  initial(States::kVx) = expected_xdot(States::kX) = 1.0;
+  EXPECT_EQ(expected_xdot, ValidateDynamics(initial, Input::Zero()));
+}
+
+// Tests that we can accelerate the robot.
+TEST_F(SimplifiedDynamicsTest, AccelerateStraight) {
+  // Starting at rest with every wheel pointed straight ahead, driving each
+  // module's drive-current input (the odd entries of `input`) should push the
+  // robot purely forwards: kVx must be the only non-zero derivative.
+  const Input input{{0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}};
+  const PositionState state = PositionState::Zero();
+  const PositionState result = ValidateDynamics(state, input);
+  EXPECT_EQ(result.norm(), result(States::kVx));
+  EXPECT_LT(0.1, result(States::kVx));
+}
+
+// Sliding straight sideways while every wheel points forwards (i.e. the
+// wheels sit at 90 degrees to the direction of travel) should produce a
+// slip force that decelerates the robot.
+TEST_F(SimplifiedDynamicsTest, ForceWheelsSideways) {
+  PositionState sideways = PositionState::Zero();
+  sideways(States::kVy) = 1.0;
+  const PositionState xdot = ValidateDynamics(sideways, Input::Zero());
+  EXPECT_LT(xdot.topRows<States::kVy>().norm(), 1e-10)
+      << ": All derivatives prior to the vy state should be ~zero.\n"
+      << xdot;
+  EXPECT_LT(xdot(States::kVy), -1.0) << ": expected non-trivial deceleration.";
+  EXPECT_EQ(xdot(States::kOmega), 0.0);
+}
+
+// Tests that we can spin the robot in place with every wheel tangent (no slip).
+TEST_F(SimplifiedDynamicsTest, SpinInPlaceNoSlip) {
+  PositionState state = PositionState::Zero();
+  state(States::kThetas0) = 3.0 * M_PI / 4.0;
+  state(States::kThetas1) = 5.0 * M_PI / 4.0;
+  state(States::kThetas2) = 7.0 * M_PI / 4.0;
+  state(States::kThetas3) = 1.0 * M_PI / 4.0;
+  state(States::kOmega) = 1.0;
+  PositionState result = ValidateDynamics(state, Input::Zero());
+  EXPECT_NEAR(result.norm(), 1.0, 1e-10)
+      << ": Only non-zero state should be kTheta, which should be exactly 1.0.";
+  EXPECT_EQ(result(States::kTheta), 1.0);
+
+  // Sanity check that applying drive torque to the (still tangent) wheels
+  // spins the robot up faster.
+  Input input{{0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}};
+  result = ValidateDynamics(state, input);
+  EXPECT_EQ(result(States::kTheta), 1.0);
+  EXPECT_LT(1.0, result(States::kOmega));
+  // Everything else should be ~zero.
+  result(States::kTheta) = 0.0;
+  result(States::kOmega) = 0.0;
+  EXPECT_LT(result.norm(), 1e-10);
+}
+
+// Tests that we can spin in place when skid-steering (i.e., all wheels stay
+// pointed straight, but we still attempt to spin the robot).
+TEST_F(SimplifiedDynamicsTest, SpinInPlaceSkidSteer) {
+  PositionState state = PositionState::Zero();
+  state(States::kThetas0) = -M_PI;
+  state(States::kThetas1) = -M_PI;
+  state(States::kThetas2) = 0.0;
+  state(States::kThetas3) = 0.0;
+  state(States::kOmega) = 1.0;
+  PositionState result = ValidateDynamics(state, Input::Zero());
+  EXPECT_EQ(result(States::kTheta), 1.0);
+  EXPECT_LT(result(States::kOmega), -1.0)
+      << "We should be aggressively decelerating when slipping wheels.";
+  // Everything else should be ~zero.
+  result(States::kTheta) = 0.0;
+  result(States::kOmega) = 0.0;
+  EXPECT_LT(result.norm(), 1e-10);
+
+  // Sanity check that applying enough drive torque to the wheels overcomes
+  // the slip deceleration and spins the robot up instead.
+  Input input{{0.0}, {100.0}, {0.0}, {100.0}, {0.0}, {100.0}, {0.0}, {100.0}};
+  result = ValidateDynamics(state, input);
+  EXPECT_EQ(result(States::kTheta), 1.0);
+  EXPECT_LT(1.0, result(States::kOmega));
+  // Everything else should be ~zero.
+  result(States::kTheta) = 0.0;
+  result(States::kOmega) = 0.0;
+  EXPECT_LT(result.norm(), 1e-10);
+}
+
+// Tests that we can spin in place when skid-steering backwards (ensures that
+// slip angle calculations and the like handle the sign changes correctly).
+TEST_F(SimplifiedDynamicsTest, SpinInPlaceSkidSteerBackwards) {
+  PositionState state = PositionState::Zero();
+  state(States::kThetas0) = 0.0;
+  state(States::kThetas1) = 0.0;
+  state(States::kThetas2) = M_PI;
+  state(States::kThetas3) = M_PI;
+  state(States::kOmega) = 1.0;
+  PositionState result = ValidateDynamics(state, Input::Zero());
+  EXPECT_EQ(result(States::kTheta), 1.0);
+  EXPECT_LT(result(States::kOmega), -1.0)
+      << "We should be aggressively decelerating when slipping wheels.";
+  // Everything else should be ~zero.
+  result(States::kTheta) = 0.0;
+  result(States::kOmega) = 0.0;
+  EXPECT_LT(result.norm(), 1e-10);
+
+  // Sanity check that applying enough (negative) drive torque to the wheels
+  // overcomes the slip deceleration and spins the robot up instead.
+  Input input{{0.0}, {-100.0}, {0.0}, {-100.0},
+              {0.0}, {-100.0}, {0.0}, {-100.0}};
+  result = ValidateDynamics(state, input);
+  EXPECT_EQ(result(States::kTheta), 1.0);
+  EXPECT_LT(1.0, result(States::kOmega));
+  // Everything else should be ~zero.
+  result(States::kTheta) = 0.0;
+  result(States::kOmega) = 0.0;
+  EXPECT_LT(result.norm(), 1e-10);
+}
+
+// With wheel alignment enabled, coasting forwards with every wheel steered
+// slightly off-straight should accelerate each wheel back towards straight.
+TEST_F(SimplifiedDynamicsTest, WheelAlignmentForces) {
+  dynamics_ = Dynamics(MakeParams(true));
+  PositionState state = PositionState::Zero();
+  state(States::kVx) = 1.0;
+  for (const auto thetas_index : {States::kThetas0, States::kThetas1,
+                                  States::kThetas2, States::kThetas3}) {
+    state(thetas_index) = -0.1;
+  }
+  PositionState result = ValidateDynamics(state, Input::Zero());
+  EXPECT_LT(1e-2, result(States::kOmegas0));
+  EXPECT_LT(1e-2, result(States::kOmegas1));
+  EXPECT_LT(1e-2, result(States::kOmegas2));
+  EXPECT_LT(1e-2, result(States::kOmegas3));
+}
+
+// Fuzz the Jacobian calculations by sweeping each state, one at a time,
+// through {-1, 0, 1} with zero input.
+TEST_F(SimplifiedDynamicsTest, Fuzz) {
+  for (size_t i = 0; i < States::kNumPositionStates; ++i) {
+    SCOPED_TRACE(i);
+    PositionState state = PositionState::Zero();
+    for (const double value : {-1.0, 0.0, 1.0}) {
+      SCOPED_TRACE(value);
+      state(i) = value;
+      ValidateDynamics(state, Input::Zero());
+    }
+  }
+}
+}  // namespace frc971::control_loops::swerve::testing
diff --git a/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs b/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs
index eab0d0d..7cbb11c 100644
--- a/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs
+++ b/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs
@@ -1,3 +1,6 @@
+include "frc971/math/matrix.fbs";
+include "frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs";
+
 namespace frc971.control_loops.swerve;
 
 // States what translation control type goal we will care about
@@ -22,11 +25,20 @@
   translation_speed:double (id: 3);
 }
 
+attribute "static_length";
+
+table LinearVelocityGoal {
+  state:frc971.fbs.Matrix (id: 0);
+  input:frc971.fbs.Matrix (id: 1);
+}
+
 table Goal {
     front_left_goal:SwerveModuleGoal (id: 0);
     front_right_goal:SwerveModuleGoal (id: 1);
     back_left_goal:SwerveModuleGoal (id: 2);
     back_right_goal:SwerveModuleGoal (id: 3);
+    linear_velocity_goal:LinearVelocityGoal (id: 4);
+    joystick_goal:JoystickGoal (id: 5);
 }
 
 root_type Goal;
diff --git a/frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs b/frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs
new file mode 100644
index 0000000..bf44cdd
--- /dev/null
+++ b/frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs
@@ -0,0 +1,9 @@
+namespace frc971.control_loops.swerve;
+
+table JoystickGoal {
+  vx:float (id: 0);
+  vy:float (id: 1);
+  omega:float (id: 2);
+}
+
+root_type JoystickGoal;
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
index 7d376bd..a0078b6 100644
--- a/frc971/control_loops/swerve/velocity_controller/BUILD
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -44,6 +44,8 @@
         "@pip//matplotlib",
         "@pip//numpy",
         "@pip//tensorflow",
+        "@pip//tensorflow_probability",
+        "@pip//tf_keras",
     ],
 )
 
@@ -59,6 +61,30 @@
 )
 
 py_binary(
+    name = "plot",
+    srcs = [
+        "model.py",
+        "plot.py",
+    ],
+    deps = [
+        ":experience_buffer",
+        ":physics",
+        "//frc971/control_loops/swerve:jax_dynamics",
+        "@pip//absl_py",
+        "@pip//flashbax",
+        "@pip//flax",
+        "@pip//jax",
+        "@pip//jaxtyping",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//tensorflow",
+        "@pip//tensorflow_probability",
+        "@pip//tf_keras",
+    ],
+)
+
+py_binary(
     name = "lqr_plot",
     srcs = [
         "lqr_plot.py",
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index bd6674c..a63e100 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -47,9 +47,26 @@
 
 # Container for the data.
 Data = collections.namedtuple('Data', [
-    't', 'X', 'X_lqr', 'U', 'U_lqr', 'cost', 'cost_lqr', 'q1_grid', 'q2_grid',
-    'q_grid', 'target_q_grid', 'lqr_grid', 'pi_grid_U', 'lqr_grid_U', 'grid_X',
-    'grid_Y', 'reward', 'reward_lqr', 'step'
+    't',
+    'X',
+    'X_lqr',
+    'U',
+    'U_lqr',
+    'cost',
+    'cost_lqr',
+    'q1_grid',
+    'q2_grid',
+    'q_grid',
+    'target_q_grid',
+    'lqr_grid',
+    'pi_grid_U',
+    'lqr_grid_U',
+    'grid_X',
+    'grid_Y',
+    'reward',
+    'reward_lqr',
+    'step',
+    'stdev',
 ])
 
 FLAGS = absl.flags.FLAGS
@@ -87,8 +104,13 @@
     rng = jax.random.key(0)
     rng, init_rng = jax.random.split(rng)
 
-    state = create_train_state(init_rng, problem, FLAGS.q_learning_rate,
-                               FLAGS.pi_learning_rate)
+    state = create_train_state(
+        init_rng,
+        problem,
+        q_learning_rate=FLAGS.q_learning_rate,
+        pi_learning_rate=FLAGS.pi_learning_rate,
+        alpha_learning_rate=FLAGS.alpha_learning_rate,
+    )
 
     state = restore_checkpoint(state, FLAGS.workdir)
     if step is not None and state.step == step:
@@ -156,12 +178,23 @@
 
     def compute_pi_U(X, Y):
         x = jax.numpy.array([X, Y])
-        U, _, _, _ = state.pi_apply(rng,
-                                    state.params,
-                                    observation=state.problem.unwrap_angles(x),
-                                    R=goal,
-                                    deterministic=True)
-        return U[0]
+        U, _, _ = state.pi_apply(rng,
+                                 state.params,
+                                 observation=state.problem.unwrap_angles(x),
+                                 R=goal,
+                                 deterministic=True)
+        return U[0] * problem.action_limit
+
+    def compute_pi_stdev(X, Y):
+        x = jax.numpy.array([X, Y])
+        _, _, std = state.pi_apply(rng,
+                                   state.params,
+                                   observation=state.problem.unwrap_angles(x),
+                                   R=goal,
+                                   deterministic=True)
+        return std[0]
+
+    std_grid = jax.vmap(jax.vmap(compute_pi_stdev))(grid_X, grid_Y)
 
     lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
     pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
@@ -173,22 +206,25 @@
         X, X_lqr, data, params = val
         t = data.t.at[i].set(i * problem.dt)
 
-        U, _, _, _ = state.pi_apply(rng,
-                                    params,
-                                    observation=state.problem.unwrap_angles(X),
-                                    R=goal,
-                                    deterministic=True)
+        normalized_U, _, _ = state.pi_apply(
+            rng,
+            params,
+            observation=state.problem.unwrap_angles(X),
+            R=goal,
+            deterministic=True)
         U_lqr = problem.F @ (goal - X_lqr)
 
         cost = jax.numpy.minimum(
             state.q1_apply(params,
                            observation=state.problem.unwrap_angles(X),
                            R=goal,
-                           action=U),
+                           action=normalized_U),
             state.q2_apply(params,
                            observation=state.problem.unwrap_angles(X),
                            R=goal,
-                           action=U))
+                           action=normalized_U))
+
+        U = normalized_U * problem.action_limit
 
         U_plot = data.U.at[i, :].set(U)
         U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
@@ -200,8 +236,9 @@
         X = problem.A @ X + problem.B @ U
         X_lqr = problem.A @ X_lqr + problem.B @ U_lqr
 
-        reward = data.reward - state.problem.cost(X, U, goal)
-        reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
+        reward = data.reward + state.problem.reward(X, normalized_U, goal)
+        reward_lqr = data.reward_lqr + state.problem.reward(
+            X_lqr, U_lqr / problem.action_limit, goal)
 
         return X, X_lqr, data._replace(
             t=t,
@@ -242,6 +279,7 @@
             reward=0.0,
             reward_lqr=0.0,
             step=state.step,
+            stdev=std_grid,
         ), X, X_lqr, state.params)
 
     logging.info('Finished integrating, reward of %f, lqr reward of %f',
@@ -265,6 +303,7 @@
         lqr_grid_U=numpy.array(data.lqr_grid_U),
         grid_X=numpy.array(data.grid_X),
         grid_Y=numpy.array(data.grid_Y),
+        stdev=numpy.array(data.stdev),
         reward=float(data.reward),
         reward_lqr=float(data.reward_lqr),
         step=data.step,
@@ -317,9 +356,10 @@
 
         self.Ufig = pyplot.figure(figsize=pyplot.figaspect(0.5))
         self.Uax = [
-            self.Ufig.add_subplot(1, 3, 1, projection='3d'),
-            self.Ufig.add_subplot(1, 3, 2, projection='3d'),
-            self.Ufig.add_subplot(1, 3, 3, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 1, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 2, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 3, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 4, projection='3d'),
         ]
 
         self.last_trajectory_step = 0
@@ -331,6 +371,7 @@
             return
         self.last_trajectory_step = data.step
         logging.info('Updating trajectory plots')
+        self.fig0.suptitle(f'Step {data.step}')
 
         # Put data in the trajectory plots.
         self.x.set_data(data.t, data.X[:, 0])
@@ -362,6 +403,7 @@
             return
         logging.info('Updating cost plots')
         self.last_cost_step = data.step
+        self.costfig.suptitle(f'Step {data.step}')
         # Put data in the cost plots.
         if hasattr(self, 'costsurf'):
             for surf in self.costsurf:
@@ -398,6 +440,7 @@
             return
         self.last_U_step = data.step
         logging.info('Updating U plots')
+        self.Ufig.suptitle(f'Step {data.step}')
         # Put data in the controller plots.
         if hasattr(self, 'Usurf'):
             for surf in self.Usurf:
@@ -407,6 +450,7 @@
             (data.lqr_grid_U, 'lqr'),
             (data.pi_grid_U, 'pi'),
             ((data.lqr_grid_U - data.pi_grid_U), 'error'),
+            (data.stdev, 'stdev'),
         ]
 
         self.Usurf = [
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
index a0d248d..391b0fe 100644
--- a/frc971/control_loops/swerve/velocity_controller/main.py
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -39,6 +39,9 @@
 
 flags.DEFINE_string('workdir', None, 'Directory to store model data.')
 
+flags.DEFINE_bool('swerve', True,
+                  'If true, train the swerve model, otherwise do the turret.')
+
 
 def main(argv):
     if len(argv) > 1:
@@ -64,7 +67,12 @@
         FLAGS.workdir,
     )
 
-    problem = physics.TurretProblem()
+    if FLAGS.swerve:
+        physics_constants = jax_dynamics.Coefficients()
+        problem = physics.SwerveProblem(physics_constants)
+    else:
+        problem = physics.TurretProblem()
+
     state = train.train(FLAGS.workdir, problem)
 
 
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index 1394463..7f01eef 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -14,9 +14,12 @@
 from jax.experimental import mesh_utils
 from jax.sharding import Mesh, PartitionSpec, NamedSharding
 from frc971.control_loops.swerve import jax_dynamics
-from frc971.control_loops.swerve import dynamics
 from frc971.control_loops.swerve.velocity_controller import physics
 from frc971.control_loops.swerve.velocity_controller import experience_buffer
+from tensorflow_probability.substrates import jax as tfp
+
+tfd = tfp.distributions
+tfb = tfp.bijectors
 
 from flax.typing import PRNGKey
 
@@ -29,6 +32,12 @@
 )
 
 absl.flags.DEFINE_float(
+    'alpha_learning_rate',
+    default=0.004,
+    help='Training learning rate for entropy.',
+)
+
+absl.flags.DEFINE_float(
     'q_learning_rate',
     default=0.002,
     help='Training learning rate.',
@@ -52,6 +61,13 @@
     help='Fraction of --pi_learning_rate to reduce by by the end.',
 )
 
+absl.flags.DEFINE_float(
+    'target_entropy_scalar',
+    default=1.0,
+    help=
+    'Target entropy scalar for use when using automatic temperature adjustment.',
+)
+
 absl.flags.DEFINE_integer(
     'replay_size',
     default=2000000,
@@ -64,6 +80,30 @@
     help='Batch size for learning Q and pi',
 )
 
+absl.flags.DEFINE_boolean(
+    'skip_layer',
+    default=False,
+    help='If true, add skip layer connections to the Q network.',
+)
+
+absl.flags.DEFINE_boolean(
+    'rmsnorm',
+    default=False,
+    help='If true, use rmsnorm instead of layer norm.',
+)
+
+absl.flags.DEFINE_boolean(
+    'dreamer_solver',
+    default=False,
+    help='If true, use the solver from dreamer v3 instead of adam.',
+)
+
+absl.flags.DEFINE_float(
+    'initial_logalpha',
+    default=0.0,
+    help='The initial value to set logalpha to.',
+)
+
 HIDDEN_WEIGHTS = 256
 
 LOG_STD_MIN = -20
@@ -127,36 +167,25 @@
         if rng is None:
             rng = self.make_rng('pi')
 
-        # Grab a random sample
-        random_sample = jax.random.normal(rng, shape=std.shape)
+        pi_distribution = tfd.TransformedDistribution(
+            distribution=tfd.Normal(loc=mu, scale=std),
+            bijector=tfb.Tanh(),
+        )
 
         if deterministic:
             # We are testing the optimal policy, just use the mean.
-            pi_action = mu
+            pi_action = flax.linen.activation.tanh(mu)
         else:
-            # Use the reparameterization trick.  Adjust the unit gausian with
-            # something we can solve for to get the desired noise.
-            pi_action = random_sample * std + mu
+            pi_action = pi_distribution.sample(shape=(1, ), seed=rng)
 
-        logp_pi = gaussian_likelihood(random_sample, log_std)
-        # Adjustment to log prob
-        # NOTE: This formula is a little bit magic. To get an understanding of where it
-        # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
-        # appendix C. This is a more numerically-stable equivalent to Eq 21.
-        delta = (2.0 * (jax.numpy.log(2.0) - pi_action -
-                        flax.linen.softplus(-2.0 * pi_action)))
+        logp_pi = pi_distribution.log_prob(pi_action)
 
-        if len(delta.shape) > 1:
-            delta = jax.numpy.sum(delta, keepdims=True, axis=1)
+        if len(logp_pi.shape) > 1:
+            logp_pi = jax.numpy.sum(logp_pi, keepdims=True, axis=1)
         else:
-            delta = jax.numpy.sum(delta, keepdims=True)
+            logp_pi = jax.numpy.sum(logp_pi, keepdims=True)
 
-        logp_pi = logp_pi - delta
-
-        # Now, saturate the action to the limit using tanh
-        pi_action = self.action_limit * flax.linen.activation.tanh(pi_action)
-
-        return pi_action, logp_pi, self.action_limit * std, random_sample
+        return pi_action, logp_pi, self.action_limit * std
 
 
 class MLPQFunction(nn.Module):
@@ -173,10 +202,23 @@
         # Estimate Q with a simple multi layer dense network.
         x = jax.numpy.hstack((observation, R, action))
         for i, hidden_size in enumerate(self.hidden_sizes):
+            # Add d2rl skip layer connections if requested.
+            # Idea from D2RL: https://arxiv.org/pdf/2010.09163.
+            if FLAGS.skip_layer and i != 0:
+                x = jax.numpy.hstack((x, observation, R, action))
+
             x = nn.Dense(
                 name=f'denselayer{i}',
                 features=hidden_size,
             )(x)
+
+            if FLAGS.rmsnorm:
+                # Idea from Dreamerv3: https://arxiv.org/pdf/2301.04104v2.
+                x = nn.RMSNorm(name=f'rmsnorm{i}')(x)
+            else:
+                # Layernorm also improves stability.
+                # Idea from RLPD: https://arxiv.org/pdf/2302.02948.
+                x = nn.LayerNorm(name=f'layernorm{i}')(x)
             x = self.activation(x)
 
         x = nn.Dense(
@@ -369,7 +411,7 @@
             q_opt_state=q_opt_state,
             alpha_tx=alpha_tx,
             alpha_opt_state=alpha_opt_state,
-            target_entropy=-problem.num_states,
+            target_entropy=-problem.num_outputs * FLAGS.target_entropy_scalar,
             mesh=mesh,
             sharding=sharding,
             replicated_sharding=replicated_sharding,
@@ -380,14 +422,15 @@
 
 
 def create_train_state(rng: PRNGKey, problem: Problem, q_learning_rate,
-                       pi_learning_rate):
+                       pi_learning_rate, alpha_learning_rate):
     """Creates initial `TrainState`."""
-    pi = SquashedGaussianMLPActor(activation=nn.activation.gelu,
+    pi = SquashedGaussianMLPActor(activation=nn.activation.silu,
                                   action_space=problem.num_outputs,
                                   action_limit=problem.action_limit)
     # We want q1 and q2 to have different network architectures so they pick up differnet things.
-    q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
-    q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
+    # SiLu is used in DreamerV3 so we use it: https://arxiv.org/pdf/2301.04104v2.
+    q1 = MLPQFunction(activation=nn.activation.silu, hidden_sizes=[128, 256])
+    q2 = MLPQFunction(activation=nn.activation.silu, hidden_sizes=[256, 128])
 
     @jax.jit
     def init_params(rng):
@@ -412,7 +455,7 @@
         )['params']
 
         if FLAGS.alpha < 0.0:
-            logalpha = 0.0
+            logalpha = FLAGS.initial_logalpha
         else:
             logalpha = jax.numpy.log(FLAGS.alpha)
 
@@ -423,9 +466,14 @@
             'logalpha': logalpha,
         }
 
-    pi_tx = optax.sgd(learning_rate=pi_learning_rate)
-    q_tx = optax.sgd(learning_rate=q_learning_rate)
-    alpha_tx = optax.sgd(learning_rate=q_learning_rate)
+    if FLAGS.dreamer_solver:
+        pi_tx = create_dreamer_solver(learning_rate=pi_learning_rate)
+        q_tx = create_dreamer_solver(learning_rate=q_learning_rate)
+        alpha_tx = create_dreamer_solver(learning_rate=alpha_learning_rate)
+    else:
+        pi_tx = optax.adam(learning_rate=pi_learning_rate)
+        q_tx = optax.adam(learning_rate=q_learning_rate)
+        alpha_tx = optax.adam(learning_rate=alpha_learning_rate)
 
     result = TrainState.create(
         problem=problem,
@@ -441,6 +489,89 @@
     return result
 
 
+# Solver from dreamer v3: https://arxiv.org/pdf/2301.04104v2.
+# TODO(austin): How many of these pieces are actually in optax already?
+def scale_by_rms(beta=0.999, eps=1e-8):
+    """Divides updates by the bias-corrected RMS of past updates (DreamerV3)."""
+    def init_fn(params):
+        second_moment = jax.tree_util.tree_map(
+            lambda leaf: jax.numpy.zeros_like(leaf, jax.numpy.float32), params)
+        return (jax.numpy.zeros((), jax.numpy.int32), second_moment)
+
+    def update_fn(updates, state, params=None):
+        del params  # Unused; accepted to satisfy the optax update signature.
+        count, second_moment = state
+        count = optax.safe_int32_increment(count)
+        second_moment = jax.tree_util.tree_map(
+            lambda v, g: beta * v + (1 - beta) * (g * g), second_moment,
+            updates)
+        mean_square = optax.bias_correction(second_moment, beta, count)
+        scaled = jax.tree_util.tree_map(
+            lambda g, v: g / (jax.numpy.sqrt(v) + eps), updates, mean_square)
+        return scaled, (count, second_moment)
+    return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_agc(clip=0.03, pmin=1e-3):
+    """AGC: per-leaf L2 clip of |update| to clip * max(pmin, |param|)."""
+    def init_fn(params):
+        return ()
+
+    def update_fn(updates, state, params=None):
+        if params is None:  # AGC is relative to param norms; they are required.
+            raise ValueError('scale_by_agc requires params to be passed.')
+        def fn(param, update):
+            unorm = jax.numpy.linalg.norm(update.flatten(), 2)
+            pnorm = jax.numpy.linalg.norm(param.flatten(), 2)
+            upper = clip * jax.numpy.maximum(pmin, pnorm)
+            return update * (1 / jax.numpy.maximum(1.0, unorm / upper))
+
+        updates = jax.tree_util.tree_map(fn, params, updates)
+        return updates, ()
+    return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_momentum(beta=0.9, nesterov=False):
+    """Bias-corrected EMA of updates; nesterov applies the EMA once more."""
+
+    def init_fn(params):
+        first_moment = jax.tree_util.tree_map(
+            lambda leaf: jax.numpy.zeros_like(leaf, jax.numpy.float32), params)
+        return (jax.numpy.zeros((), jax.numpy.int32), first_moment)
+
+    def update_fn(updates, state, params=None):
+        del params  # Unused; accepted to satisfy the optax update signature.
+        count, first_moment = state
+        count = optax.safe_int32_increment(count)
+        first_moment = optax.update_moment(updates, first_moment, beta, 1)
+        # Nesterov looks one EMA step ahead of the stored moment.
+        lookahead = (optax.update_moment(updates, first_moment, beta, 1)
+                     if nesterov else first_moment)
+        corrected = optax.bias_correction(lookahead, beta, count)
+        return corrected, (count, first_moment)
+
+    return optax.GradientTransformation(init_fn, update_fn)
+
+
+def create_dreamer_solver(
+    learning_rate,
+    agc: float = 0.3,
+    pmin: float = 1e-3,
+    beta1: float = 0.9,
+    beta2: float = 0.999,
+    eps: float = 1e-20,
+    nesterov: bool = False,
+) -> optax.base.GradientTransformation:
+    """DreamerV3 optimizer: AGC -> RMS scaling -> momentum -> learning rate."""
+    stages = [
+        scale_by_agc(clip=agc, pmin=pmin),  # Adaptive gradient clipping.
+        scale_by_rms(beta=beta2, eps=eps),
+        scale_by_momentum(beta=beta1, nesterov=nesterov),
+        optax.scale_by_learning_rate(learning_rate),  # Sign-flipped by optax.
+    ]
+    return optax.chain(*stages)
+
+
 def create_learning_rate_fn(
     base_learning_rate: float,
     final_learning_rate: float,
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index b0da0b4..accd0d6 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -55,13 +55,17 @@
     def random_states(self, rng: PRNGKey, dimensions=None):
         raise NotImplemented("random_states not implemented")
 
-    def random_actions(self, rng: PRNGKey, dimensions=None):
+    def random_actions(self,
+                       rng: PRNGKey,
+                       X: jax.typing.ArrayLike,
+                       goal: jax.typing.ArrayLike,
+                       dimensions=None):
         """Produces a uniformly random action in the action space."""
         return jax.random.uniform(
             rng,
             (dimensions or FLAGS.num_agents, self.num_outputs),
-            minval=-self.action_limit,
-            maxval=self.action_limit,
+            minval=-1.0,
+            maxval=1.0,
         )
 
     def random_goals(self, rng: PRNGKey, dimensions=None):
@@ -94,12 +98,13 @@
         A_continuous = jax.numpy.array([[0., 1.], [0., -36.85154548]])
         B_continuous = jax.numpy.array([[0.], [56.08534375]])
 
+        U = U * self.action_limit
         return A_continuous @ X + B_continuous @ U
 
-    def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
-             goal: jax.typing.ArrayLike):
-        return (X - goal).T @ jax.numpy.array(
-            self.Q) @ (X - goal) + U.T @ jax.numpy.array(self.R) @ U
+    def reward(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+               goal: jax.typing.ArrayLike):
+        return -(X - goal).T @ jax.numpy.array(
+            self.Q) @ (X - goal) - U.T @ jax.numpy.array(self.R) @ U
 
     def random_states(self, rng: PRNGKey, dimensions=None):
         rng1, rng2 = jax.random.split(rng)
@@ -120,3 +125,123 @@
                                maxval=0.1),
             jax.numpy.zeros((dimensions or FLAGS.num_agents, 1)),
         ))
+
+
+class SwerveProblem(Problem):
+    """Swerve velocity problem; U in [-1, 1] is scaled by action_limit."""
+    def __init__(self, coefficients: jax_dynamics.CoefficientsType):
+        super().__init__(num_states=jax_dynamics.NUM_VELOCITY_STATES,
+                         num_unwrapped_states=17,  # unwrap_angles() width
+                         num_outputs=8,
+                         num_goals=3,
+                         action_limit=40.0)
+        # Passed to jax_dynamics by xdot() and reward().
+        self.coefficients = coefficients
+
+    def random_actions(self,
+                       rng: PRNGKey,
+                       X: jax.typing.ArrayLike,
+                       goal: jax.typing.ArrayLike,
+                       dimensions=None):
+        """Produces a uniformly random action in the action space."""
+        return jax.random.uniform(
+            rng,
+            (dimensions or FLAGS.num_agents, self.num_outputs),
+            minval=-1.0,
+            maxval=1.0,
+        )
+
+    def unwrap_angles(self, X: jax.typing.ArrayLike):
+        return jax.numpy.stack([  # Each angle becomes a (cos, sin) pair.
+            jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS0]),
+            jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS0]),
+            X[..., jax_dynamics.VELOCITY_STATE_OMEGAS0],
+            jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS1]),
+            jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS1]),
+            X[..., jax_dynamics.VELOCITY_STATE_OMEGAS1],
+            jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS2]),
+            jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS2]),
+            X[..., jax_dynamics.VELOCITY_STATE_OMEGAS2],
+            jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS3]),
+            jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS3]),
+            X[..., jax_dynamics.VELOCITY_STATE_OMEGAS3],
+            jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETA]),
+            jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETA]),
+            X[..., jax_dynamics.VELOCITY_STATE_VX],
+            X[..., jax_dynamics.VELOCITY_STATE_VY],
+            X[..., jax_dynamics.VELOCITY_STATE_OMEGA],
+        ],
+                               axis=-1)
+
+    def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+        return jax_dynamics.velocity_dynamics(self.coefficients, X,
+                                              self.action_limit * U)  # unscale
+
+    def reward(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+               goal: jax.typing.ArrayLike):
+        return -jax_dynamics.mpc_cost(coefficients=self.coefficients,
+                                      X=X,
+                                      U=self.action_limit * U,
+                                      goal=goal)
+
+    def random_states(self, rng: PRNGKey, dimensions=None):
+        rng, rng1, rng2, rng3, rng4, rng5, rng6, rng7, rng8, rng9, rng10, rng11 = jax.random.split(
+            rng, num=12)  # The first returned key is not used below.
+
+        return jax.numpy.hstack((
+            # VELOCITY_STATE_THETAS0 = 0
+            self._random_angle(rng1, dimensions),
+            # VELOCITY_STATE_OMEGAS0 = 1
+            self._random_module_velocity(rng2, dimensions),
+            # VELOCITY_STATE_THETAS1 = 2
+            self._random_angle(rng3, dimensions),
+            # VELOCITY_STATE_OMEGAS1 = 3
+            self._random_module_velocity(rng4, dimensions),
+            # VELOCITY_STATE_THETAS2 = 4
+            self._random_angle(rng5, dimensions),
+            # VELOCITY_STATE_OMEGAS2 = 5
+            self._random_module_velocity(rng6, dimensions),
+            # VELOCITY_STATE_THETAS3 = 6
+            self._random_angle(rng7, dimensions),
+            # VELOCITY_STATE_OMEGAS3 = 7
+            self._random_module_velocity(rng8, dimensions),
+            # VELOCITY_STATE_THETA = 8
+            self._random_angle(rng9, dimensions),
+            # VELOCITY_STATE_VX = 9
+            # VELOCITY_STATE_VY = 10
+            self._random_robot_velocity(rng10, dimensions),
+            # VELOCITY_STATE_OMEGA = 11
+            self._random_robot_angular_velocity(rng11, dimensions),
+        ))
+
+    def random_goals(self, rng: PRNGKey, dimensions=None):
+        """Produces a random goal in the goal space."""
+        return jax.numpy.hstack((
+            jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+                               minval=1.0,
+                               maxval=1.0),  # min == max: always 1.0
+            jax.numpy.zeros((dimensions or FLAGS.num_agents, 2)),
+        ))
+
+    MODULE_VELOCITY = 1.0  # +/- bound used by _random_module_velocity.
+    ROBOT_ANGULAR_VELOCITY = 0.5  # +/- bound, _random_robot_angular_velocity.
+
+    def _random_angle(self, rng: PRNGKey, dimensions=None):
+        return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+                                  minval=-0.1 * jax.numpy.pi,
+                                  maxval=0.1 * jax.numpy.pi)
+
+    def _random_module_velocity(self, rng: PRNGKey, dimensions=None):
+        return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+                                  minval=-self.MODULE_VELOCITY,
+                                  maxval=self.MODULE_VELOCITY)
+
+    def _random_robot_velocity(self, rng: PRNGKey, dimensions=None):
+        return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 2),
+                                  minval=0.9,  # NOTE(review): vx AND vy; confirm
+                                  maxval=1.1)
+
+    def _random_robot_angular_velocity(self, rng: PRNGKey, dimensions=None):
+        return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+                                  minval=-self.ROBOT_ANGULAR_VELOCITY,
+                                  maxval=self.ROBOT_ANGULAR_VELOCITY)
diff --git a/frc971/control_loops/swerve/velocity_controller/plot.py b/frc971/control_loops/swerve/velocity_controller/plot.py
new file mode 100644
index 0000000..502a3e2
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/plot.py
@@ -0,0 +1,430 @@
+#!/usr/bin/env python3
+
+import os
+
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach it where to find CUDA
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Use up to 20 cores
+    '--xla_force_host_platform_device_count=20',
+    # Dump XLA to /tmp/foo to aid debugging
+    #'--xla_dump_to=/tmp/foo',
+    #'--xla_gpu_enable_command_buffer='
+])
+
+os.environ['JAX_PLATFORMS'] = 'cpu'
+# Don't pre-allocate memory
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+import absl
+from absl import logging
+from matplotlib.animation import FuncAnimation
+import matplotlib
+import numpy
+import scipy
+import time
+
+matplotlib.use("gtk3agg")
+
+from matplotlib import pylab
+from matplotlib import pyplot
+from flax.training import checkpoints
+import tensorflow as tf
+from jax.experimental.ode import odeint
+import threading
+import collections
+
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+# One rollout for plotting; the q*/pi grid fields are unused zero placeholders.
+Data = collections.namedtuple('Data', [
+    't', 'X', 'U', 'logp_pi', 'cost', 'q1_grid', 'q2_grid', 'q_grid',
+    'target_q_grid', 'pi_grid_U', 'grid_X', 'grid_Y', 'reward', 'step',
+    'rewards'
+])
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+absl.flags.DEFINE_integer(
+    'horizon',
+    default=100,  # Number of problem.dt steps rolled out by generate_data().
+    help='MPC horizon',
+)
+
+numpy.set_printoptions(linewidth=200, )  # Wider numpy output in log lines.
+
+absl.flags.DEFINE_float(
+    'alpha',  # Read by create_train_state() when initializing logalpha.
+    default=0.2,
+    help='Entropy.  If negative, automatically solve for it.',
+)
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+    return checkpoints.restore_checkpoint(ckpt_dir=workdir, target=state)
+
+
+dt = 0.005  # NOTE(review): unused here; generate_data() steps by problem.dt.
+
+
+def X0Full():  # Initial full-dynamics state: vx = 1, steer offsets +/-dtheta.
+    module_theta = 0.0
+    module_omega = 0.0
+    theta = 0.0
+    vx = 1.0
+    dtheta = 0.05  # Steer offset, applied with signs +, +, -, - below.
+    vy = 0.0
+    drive_omega = jax.numpy.hypot(vx, vy) / jax_dynamics.WHEEL_RADIUS  # |v|/r
+    omega = 0.0
+    return jax.numpy.array([  # Order must match jax_dynamics STATE_* indices.
+        module_theta + dtheta,
+        0.0,
+        module_omega,
+        drive_omega,
+        module_theta + dtheta,
+        0.0,
+        module_omega,
+        drive_omega,
+        module_theta - dtheta,
+        0.0,
+        module_omega,
+        drive_omega,
+        module_theta - dtheta,
+        0.0,
+        module_omega,
+        drive_omega,
+        0.0,
+        0.0,
+        theta,
+        vx,
+        vy,
+        omega,
+        0.0,
+        0.0,
+        0.0,
+    ])
+
+
+def generate_data(step=None):
+    """Rolls out the latest checkpoint; returns None if its step == `step`."""
+    grid_X = numpy.arange(-1, 1, 0.1)
+    grid_Y = numpy.arange(-10, 10, 0.1)
+    grid_X, grid_Y = numpy.meshgrid(grid_X, grid_Y)
+    grid_X, grid_Y = jax.numpy.array(grid_X), jax.numpy.array(grid_Y)
+    # Load the training state.
+    physics_constants = jax_dynamics.Coefficients()
+    problem = physics.SwerveProblem(physics_constants)
+
+    rng = jax.random.key(0)
+    rng, init_rng = jax.random.split(rng)
+
+    state = create_train_state(init_rng,
+                               problem,
+                               FLAGS.q_learning_rate,
+                               FLAGS.pi_learning_rate,
+                               alpha_learning_rate=0.001)
+
+    state = restore_checkpoint(state, FLAGS.workdir)
+    if step is not None and state.step == step:
+        return None
+    X = X0Full()
+    goal = jax.numpy.array([1.0, 0.0, 0.0])
+
+    logging.info('X: %s', X)
+    logging.info('goal: %s', goal)
+    logging.debug('params: %s', state.params)
+
+    # Now simulate the robot, accumulating up things to plot.
+    def loop(i, val):
+        X, data, params = val
+        t = data.t.at[i].set(i * problem.dt)
+
+        U, logp_pi, std = state.pi_apply(
+            rng,
+            params,
+            observation=state.problem.unwrap_angles(
+                jax_dynamics.to_velocity_state(X)),
+            R=goal,
+            deterministic=True)
+
+        logp_pi = logp_pi * jax.numpy.exp(params['logalpha'])
+
+        jax.debug.print('mu: {mu} std: {std}', mu=U, std=std)
+
+        step_reward = state.problem.q_reward(jax_dynamics.to_velocity_state(X),
+                                             U, goal)
+        reward = data.reward + step_reward
+
+        cost = jax.numpy.minimum(
+            state.q1_apply(params,
+                           observation=state.problem.unwrap_angles(
+                               jax_dynamics.to_velocity_state(X)),
+                           R=goal,
+                           action=U),
+            state.q2_apply(params,
+                           observation=state.problem.unwrap_angles(
+                               jax_dynamics.to_velocity_state(X)),
+                           R=goal,
+                           action=U))
+
+        U = U * problem.action_limit
+        U_plot = data.U.at[i, :].set(U)
+        rewards = data.rewards.at[i, :].set(step_reward)
+        X_plot = data.X.at[i, :].set(X)
+        cost_plot = data.cost.at[i, :].set(cost)
+        logp_pi_plot = data.logp_pi.at[i, :].set(logp_pi)
+
+        # TODO(austin): I'd really like to visualize the slip angle per wheel.
+        # Maybe also the force deviation, etc.
+        # I think that would help enormously in figuring out how good a specific solution is.
+
+        def fn(X, t):
+            return jax_dynamics.full_dynamics(problem.coefficients, X,
+                                              U).flatten()
+
+        X = odeint(fn, X, jax.numpy.array([0, problem.dt]))
+
+        return X[1, :], data._replace(
+            t=t,
+            U=U_plot,
+            X=X_plot,
+            logp_pi=logp_pi_plot,
+            cost=cost_plot,
+            reward=reward,
+            rewards=rewards,
+        ), params
+
+    # A fresh closure per call, so this re-jits each time it is reached.
+    @jax.jit
+    def integrate(data, X, params):
+        return jax.lax.fori_loop(0, FLAGS.horizon, loop, (X, data, params))
+
+    X, data, params = integrate(
+        Data(
+            t=jax.numpy.zeros((FLAGS.horizon, )),
+            X=jax.numpy.zeros((FLAGS.horizon, jax_dynamics.NUM_STATES)),
+            U=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
+            logp_pi=jax.numpy.zeros((FLAGS.horizon, 1)),
+            rewards=jax.numpy.zeros((FLAGS.horizon, 1)),
+            cost=jax.numpy.zeros((FLAGS.horizon, 1)),
+            q1_grid=jax.numpy.zeros(grid_X.shape),
+            q2_grid=jax.numpy.zeros(grid_X.shape),
+            q_grid=jax.numpy.zeros(grid_X.shape),
+            target_q_grid=jax.numpy.zeros(grid_X.shape),
+            pi_grid_U=jax.numpy.zeros(grid_X.shape),
+            grid_X=grid_X,
+            grid_Y=grid_Y,
+            reward=0.0,
+            step=state.step,
+        ), X, state.params)
+
+    logging.info('Reward: %s', float(data.reward))
+
+    # Convert back to numpy for plotting.
+    return Data(
+        t=numpy.array(data.t),
+        X=numpy.array(data.X),
+        U=numpy.array(data.U),
+        logp_pi=numpy.array(data.logp_pi),
+        cost=numpy.array(data.cost),
+        q1_grid=numpy.array(data.q1_grid),
+        q2_grid=numpy.array(data.q2_grid),
+        q_grid=numpy.array(data.q_grid),
+        target_q_grid=numpy.array(data.target_q_grid),
+        pi_grid_U=numpy.array(data.pi_grid_U),
+        grid_X=numpy.array(data.grid_X),
+        grid_Y=numpy.array(data.grid_Y),
+        rewards=numpy.array(data.rewards),
+        reward=float(data.reward),
+        step=data.step,
+    )
+
+
+class Plotter(object):
+    """Two live figures of a rollout; each redraws only when data.step changes."""
+    def __init__(self, data):
+        # Make all the plots and axes.
+        self.fig0, self.axs0 = pylab.subplots(3)
+        self.fig0.supxlabel('Seconds')
+
+        self.vx, = self.axs0[0].plot([], [], label="vx")
+        self.vy, = self.axs0[0].plot([], [], label="vy")
+        self.omega, = self.axs0[0].plot([], [], label="omega")
+        self.axs0[0].set_ylabel('Velocity')
+        self.axs0[0].legend()
+        self.axs0[0].grid()
+
+        self.steer0, = self.axs0[1].plot([], [], label="Steer0")
+        self.steer1, = self.axs0[1].plot([], [], label="Steer1")
+        self.steer2, = self.axs0[1].plot([], [], label="Steer2")
+        self.steer3, = self.axs0[1].plot([], [], label="Steer3")
+        self.axs0[1].set_ylabel('Amps')
+        self.axs0[1].legend()
+        self.axs0[1].grid()
+
+        self.drive0, = self.axs0[2].plot([], [], label="Drive0")
+        self.drive1, = self.axs0[2].plot([], [], label="Drive1")
+        self.drive2, = self.axs0[2].plot([], [], label="Drive2")
+        self.drive3, = self.axs0[2].plot([], [], label="Drive3")
+        self.axs0[2].set_ylabel('Amps')
+        self.axs0[2].legend()
+        self.axs0[2].grid()
+
+        self.fig1, self.axs1 = pylab.subplots(3)
+        self.fig1.supxlabel('Seconds')
+
+        self.theta0, = self.axs1[0].plot([], [], label='steer position0')
+        self.theta1, = self.axs1[0].plot([], [], label='steer position1')
+        self.theta2, = self.axs1[0].plot([], [], label='steer position2')
+        self.theta3, = self.axs1[0].plot([], [], label='steer position3')
+        self.axs1[0].set_ylabel('Radians')
+        self.axs1[0].legend()
+        self.omega0, = self.axs1[1].plot([], [], label='steer velocity0')
+        self.omega1, = self.axs1[1].plot([], [], label='steer velocity1')
+        self.omega2, = self.axs1[1].plot([], [], label='steer velocity2')
+        self.omega3, = self.axs1[1].plot([], [], label='steer velocity3')
+        self.axs1[1].set_ylabel('Radians/second')
+        self.axs1[1].legend()
+
+        self.logp_axis = self.axs1[2].twinx()
+        self.cost, = self.axs1[2].plot([], [], label='cost')
+        self.reward, = self.axs1[2].plot([], [], label='reward')
+        self.axs1[2].set_ylabel('Q / reward')
+        self.axs1[2].legend()
+
+        self.logp_pi, = self.logp_axis.plot([], [],
+                                            label='logp_pi',
+                                            color='C2')
+        self.logp_axis.set_ylabel('log(likelihood)*alpha')
+        self.logp_axis.legend()
+
+        self.last_robot_step = 0
+        self.last_steer_step = 0
+
+    def update_robot_plot(self, data):
+        if data.step is not None and data.step == self.last_robot_step:
+            return
+        self.last_robot_step = data.step
+        logging.info('Updating robot plots')
+        self.fig0.suptitle(f'Step {data.step}')
+
+        self.vx.set_data(data.t, data.X[:, jax_dynamics.STATE_VX])
+        self.vy.set_data(data.t, data.X[:, jax_dynamics.STATE_VY])
+        self.omega.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGA])
+
+        self.axs0[0].relim()
+        self.axs0[0].autoscale_view()
+
+        self.steer0.set_data(data.t, data.U[:, 0])
+        self.steer1.set_data(data.t, data.U[:, 2])
+        self.steer2.set_data(data.t, data.U[:, 4])
+        self.steer3.set_data(data.t, data.U[:, 6])
+        self.axs0[1].relim()
+        self.axs0[1].autoscale_view()
+
+        self.drive0.set_data(data.t, data.U[:, 1])
+        self.drive1.set_data(data.t, data.U[:, 3])
+        self.drive2.set_data(data.t, data.U[:, 5])
+        self.drive3.set_data(data.t, data.U[:, 7])
+        self.axs0[2].relim()
+        self.axs0[2].autoscale_view()
+
+        return (self.vx, self.vy, self.omega, self.steer0, self.steer1,
+                self.steer2, self.steer3, self.drive0, self.drive1,
+                self.drive2, self.drive3)
+
+    def update_steer_plot(self, data):
+        if data.step is not None and data.step == self.last_steer_step:
+            return
+        self.last_steer_step = data.step
+        logging.info('Updating steer plots')
+        self.fig1.suptitle(f'Step {data.step}')
+
+        self.theta0.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS0])
+        self.theta1.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS1])
+        self.theta2.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS2])
+        self.theta3.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS3])
+        self.axs1[0].relim()
+        self.axs1[0].autoscale_view()
+
+        self.omega0.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS0])
+        self.omega1.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS1])
+        self.omega2.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS2])
+        self.omega3.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS3])
+        self.axs1[1].relim()
+        self.axs1[1].autoscale_view()
+
+        self.cost.set_data(data.t, data.cost)
+        self.reward.set_data(data.t, data.rewards)
+        self.logp_pi.set_data(data.t, data.logp_pi)
+        self.axs1[2].relim()
+        self.axs1[2].autoscale_view()
+        self.logp_axis.relim()
+        self.logp_axis.autoscale_view()
+
+        return (self.theta0, self.theta1, self.theta2, self.theta3,
+                self.omega0, self.omega1, self.omega2, self.omega3, self.cost,
+                self.logp_pi, self.reward)
+
+
+def main(argv):
+    """Plots a rollout of the latest checkpoint, refreshing every few seconds."""
+    if len(argv) > 1:
+        raise absl.app.UsageError('Too many command-line arguments.')
+
+    tf.config.experimental.set_visible_devices([], 'GPU')
+
+    lock = threading.Lock()
+
+    # Load data.
+    data = generate_data()
+    plotter = Plotter(data)
+
+    # Event for shutting down the thread.
+    shutdown = threading.Event()
+
+    # Thread to grab new data periodically.
+    def do_update():
+        nonlocal data
+        while True:
+            my_data = generate_data(data.step)
+
+            if my_data is not None:
+                with lock:
+                    data = my_data
+
+            if shutdown.wait(timeout=3):
+                return
+
+    update_thread = threading.Thread(target=do_update)
+    update_thread.start()
+
+    # Now, update each of the plots every second with the new data.
+    def update0(frame):
+        with lock:
+            my_data = data
+        return plotter.update_robot_plot(my_data)
+
+    def update1(frame):
+        with lock:
+            my_data = data
+        return plotter.update_steer_plot(my_data)
+
+    # Keep references: matplotlib stops animations that get garbage collected.
+    animation0 = FuncAnimation(plotter.fig0, update0, interval=1000)
+    animation1 = FuncAnimation(plotter.fig1, update1, interval=1000)
+
+    try:
+        pyplot.show()
+    finally:
+        # The updater is a non-daemon thread; always stop and join it.
+        shutdown.set()
+        update_thread.join()
+
+
+if __name__ == '__main__':
+    absl.flags.mark_flag_as_required('workdir')
+    absl.app.run(main)
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index cf54bc4..34d904a 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -1,3 +1,6 @@
+# Machine learning based on Soft Actor-Critic (SAC), first proposed in https://arxiv.org/pdf/1801.01290.
+# Our implementation is heavily based on OpenAI's Spinning Up reference implementation: https://spinningup.openai.com/en/latest/algorithms/sac.html.
+
 import absl
 import time
 import collections
@@ -35,7 +38,7 @@
 )
 
 absl.flags.DEFINE_integer(
-    'start_steps',
+    'random_sample_steps',
     default=10000,
     help='Number of steps to randomly sample before using the policy',
 )
@@ -76,6 +79,13 @@
     help='If true, explode on any NaNs found, and print them.',
 )
 
+absl.flags.DEFINE_bool(
+    name='maximum_entropy_q',
+    default=True,  # False: backup is rewards + gamma * q_pi_target only.
+    help=
+    'If false, do not add the maximum entropy term to the bellman backup for Q.',
+)
+
 
 def save_checkpoint(state: TrainState, workdir: str):
     """Saves a checkpoint in the workdir."""
@@ -123,10 +133,10 @@
     R = data['goals']
 
     # Compute the ending actions from the current network.
-    action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
-                                             params=params,
-                                             observation=observations2,
-                                             R=R)
+    action2, logp_pi2, _ = state.pi_apply(rng=rng,
+                                          params=params,
+                                          observation=observations2,
+                                          R=R)
 
     # Compute target network Q values
     q1_pi_target = state.q1_apply(state.target_params,
@@ -142,8 +152,13 @@
     alpha = jax.numpy.exp(params['logalpha'])
 
     # Now we can compute the Bellman backup
-    bellman_backup = jax.lax.stop_gradient(rewards + FLAGS.gamma *
-                                           (q_pi_target - alpha * logp_pi2))
+    # Max entropy SAC is based on https://arxiv.org/pdf/1812.05905.
+    if FLAGS.maximum_entropy_q:
+        bellman_backup = jax.lax.stop_gradient(
+            rewards + FLAGS.gamma * (q_pi_target - alpha * logp_pi2))
+    else:
+        bellman_backup = jax.lax.stop_gradient(rewards +
+                                               FLAGS.gamma * q_pi_target)
 
     # Compute the starting Q values from the Q network being optimized.
     q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)
@@ -156,19 +171,6 @@
 
 
 @jax.jit
-def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
-                           data: ArrayLike):
-
-    def bound_compute_loss_q(rng, data):
-        return compute_loss_q(state, rng, params, data)
-
-    return jax.vmap(bound_compute_loss_q)(
-        jax.random.split(rng, FLAGS.num_agents),
-        data,
-    ).mean()
-
-
-@jax.jit
 def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
     """Computes the Soft Actor-Critic loss for pi."""
     observations1 = data['observations1']
@@ -176,10 +178,10 @@
     # TODO(austin): We've got differentiable policy and differentiable physics.  Can we use those here?  Have Q learn the future, not the current step?
 
     # Compute the action
-    pi, logp_pi, _, _ = state.pi_apply(rng=rng,
-                                       params=params,
-                                       observation=observations1,
-                                       R=R)
+    pi, logp_pi, _ = state.pi_apply(rng=rng,
+                                    params=params,
+                                    observation=observations1,
+                                    R=R)
     q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
                            observation=observations1,
                            R=R,
@@ -199,6 +201,32 @@
 
 
 @jax.jit
+def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
+                       data: ArrayLike):
+    """SAC alpha loss, plus mean log pi (a negative-entropy estimate)."""
+    policy_out = state.pi_apply(rng=rng, params=params,
+                                observation=data['observations1'],
+                                R=data['goals'])
+    _, logp_pi, _ = jax.lax.stop_gradient(policy_out)
+    alpha = jax.numpy.exp(params['logalpha'])
+    loss = -alpha * (logp_pi + state.target_entropy)
+    return loss.mean(), logp_pi.mean()
+
+
+@jax.jit
+def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
+                           data: ArrayLike):
+    """Mean of compute_loss_q vmapped over data's leading (agent) axis."""
+    agent_rngs = jax.random.split(rng, FLAGS.num_agents)
+
+    def per_agent_loss(agent_rng, agent_data):
+        return compute_loss_q(state, agent_rng, params, agent_data)
+
+    losses = jax.vmap(per_agent_loss)(agent_rngs, data)
+    return losses.mean()
+
+
+@jax.jit
 def compute_batched_loss_pi(state: TrainState, rng: PRNGKey, params,
                             data: ArrayLike):
 
@@ -212,33 +240,21 @@
 
 
 @jax.jit
-def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
-                       data: ArrayLike):
-    """Computes the Soft Actor-Critic loss for alpha."""
-    observations1 = data['observations1']
-    R = data['goals']
-    pi, logp_pi, _, _ = jax.lax.stop_gradient(
-        state.pi_apply(rng=rng, params=params, R=R, observation=observations1))
-
-    return (-jax.numpy.exp(params['logalpha']) *
-            (logp_pi + state.target_entropy)).mean()
-
-
-@jax.jit
 def compute_batched_loss_alpha(state: TrainState, rng: PRNGKey, params,
                                data: ArrayLike):
 
     def bound_compute_loss_alpha(rng, data):
         return compute_loss_alpha(state, rng, params, data)
 
-    return jax.vmap(bound_compute_loss_alpha)(
+    loss, entropy = jax.vmap(bound_compute_loss_alpha)(
         jax.random.split(rng, FLAGS.num_agents),
         data,
-    ).mean()
+    )
+    return (loss.mean(), entropy.mean())
 
 
 @jax.jit
-def train_step(state: TrainState, data, action_data, update_rng: PRNGKey,
+def train_step(state: TrainState, data, update_rng: PRNGKey,
                step: int) -> TrainState:
     """Updates the parameters for Q, Pi, target Q, and alpha."""
     update_rng, q_grad_rng = jax.random.split(update_rng)
@@ -253,16 +269,11 @@
 
     state = state.q_apply_gradients(step=step, grads=q_grads)
 
-    update_rng, pi_grad_rng = jax.random.split(update_rng)
-
     # Update pi
+    update_rng, pi_grad_rng = jax.random.split(update_rng)
     pi_grad_fn = jax.value_and_grad(lambda params: compute_batched_loss_pi(
-        state, pi_grad_rng, params, action_data))
+        state, pi_grad_rng, params, data))
     pi_loss, pi_grads = pi_grad_fn(state.params)
-
-    print_nan(step, pi_loss)
-    print_nan(step, pi_grads)
-
     state = state.pi_apply_gradients(step=step, grads=pi_grads)
 
     update_rng, alpha_grad_rng = jax.random.split(update_rng)
@@ -271,15 +282,18 @@
         # Update alpha
         alpha_grad_fn = jax.value_and_grad(
             lambda params: compute_batched_loss_alpha(state, alpha_grad_rng,
-                                                      params, data))
-        alpha_loss, alpha_grads = alpha_grad_fn(state.params)
+                                                      params, data),
+            has_aux=True,
+        )
+        (alpha_loss, entropy), alpha_grads = alpha_grad_fn(state.params)
         print_nan(step, alpha_loss)
         print_nan(step, alpha_grads)
         state = state.alpha_apply_gradients(step=step, grads=alpha_grads)
     else:
+        entropy = 0.0
         alpha_loss = 0.0
 
-    return state, q_loss, pi_loss, alpha_loss
+    return state, q_loss, pi_loss, alpha_loss, entropy
 
 
 @jax.jit
@@ -304,20 +318,22 @@
 
         def true_fn(i):
             # We are at the beginning of the process, pick a random action.
-            return state.problem.random_actions(action_rng, FLAGS.num_agents)
+            return state.problem.random_actions(action_rng,
+                                                X=observation,
+                                                goal=R,
+                                                dimensions=FLAGS.num_agents)
 
         def false_fn(i):
             # We are past the beginning of the process, use the trained network.
-            pi_action, logp_pi, std, random_sample = state.pi_apply(
-                rng=action_rng,
-                params=state.params,
-                observation=observation,
-                R=R,
-                deterministic=False)
+            pi_action, _, _ = state.pi_apply(rng=action_rng,
+                                             params=state.params,
+                                             observation=observation,
+                                             R=R,
+                                             deterministic=False)
             return pi_action
 
         pi_action = jax.lax.cond(
-            step <= FLAGS.start_steps,
+            step <= FLAGS.random_sample_steps,
             true_fn,
             false_fn,
             i,
@@ -328,11 +344,9 @@
             lambda o, pi: state.problem.integrate_dynamics(o, pi),
             in_axes=(0, 0))(observation, pi_action)
 
-        # Soft Actor-Critic is designed to maximize reward.  LQR minimizes
-        # cost.  There is nothing which assumes anything about the sign of
-        # the reward, so use the negative of the cost.
-        reward = -jax.vmap(state.problem.cost)(
-            X=observation2, U=pi_action, goal=R)
+        reward = jax.vmap(state.problem.reward)(X=observation2,
+                                                U=pi_action,
+                                                goal=R)
 
         replay_buffer_state = state.replay_buffer.add(
             replay_buffer_state, {
@@ -357,10 +371,8 @@
                      step: int):
     rng, sample_rng = jax.random.split(rng)
 
-    action_data = state.replay_buffer.sample(replay_buffer_state, sample_rng)
-
     def update_iteration(i, val):
-        rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = val
+        rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = val
         rng, sample_rng, update_rng = jax.random.split(rng, 3)
 
         batch = state.replay_buffer.sample(replay_buffer_state, sample_rng)
@@ -368,22 +380,18 @@
         print_nan(i, replay_buffer_state)
         print_nan(i, batch)
 
-        state, q_loss, pi_loss, alpha_loss = train_step(
-            state,
-            data=batch.experience,
-            action_data=batch.experience,
-            update_rng=update_rng,
-            step=i)
+        state, q_loss, pi_loss, alpha_loss, entropy = train_step(
+            state, data=batch.experience, update_rng=update_rng, step=i)
 
-        return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data
+        return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy
 
-    rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = jax.lax.fori_loop(
+    rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = jax.lax.fori_loop(
         step, step + FLAGS.horizon + 1, update_iteration,
-        (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, action_data))
+        (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, 0.0))
 
     state = state.target_apply_gradients(step=state.step)
 
-    return rng, state, q_loss, pi_loss, alpha_loss
+    return rng, state, q_loss, pi_loss, alpha_loss, entropy
 
 
 def train(workdir: str, problem: Problem) -> train_state.TrainState:
@@ -396,6 +404,8 @@
     run['hparams'] = {
         'q_learning_rate': FLAGS.q_learning_rate,
         'pi_learning_rate': FLAGS.pi_learning_rate,
+        'alpha_learning_rate': FLAGS.alpha_learning_rate,
+        'random_sample_steps': FLAGS.random_sample_steps,
         'batch_size': FLAGS.batch_size,
         'horizon': FLAGS.horizon,
         'warmup_steps': FLAGS.warmup_steps,
@@ -422,11 +432,11 @@
         problem,
         q_learning_rate=q_learning_rate,
         pi_learning_rate=pi_learning_rate,
+        alpha_learning_rate=FLAGS.alpha_learning_rate,
     )
     state = restore_checkpoint(state, workdir)
 
-    state_sharding = nn.get_sharding(state, state.mesh)
-    logging.info(state_sharding)
+    logging.debug(nn.get_sharding(state, state.mesh))
 
     replay_buffer_state = state.replay_buffer.init({
         'observations1':
@@ -438,12 +448,10 @@
         'rewards':
         jax.numpy.zeros((1, )),
         'goals':
-        jax.numpy.zeros((problem.num_states, )),
+        jax.numpy.zeros((problem.num_goals, )),
     })
 
-    replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
-                                                   state.mesh)
-    logging.info(replay_buffer_state_sharding)
+    logging.debug(nn.get_sharding(replay_buffer_state, state.mesh))
 
     # Number of gradients to accumulate before doing decent.
     update_after = FLAGS.batch_size // FLAGS.num_agents
@@ -461,24 +469,25 @@
         )
 
         def nop(rng, state, replay_buffer_state, step):
-            return rng, state.update_step(step=step), 0.0, 0.0, 0.0
+            return rng, state.update_step(step=step), 0.0, 0.0, 0.0, 0.0
 
         # Train
-        rng, state, q_loss, pi_loss, alpha_loss = jax.lax.cond(
+        rng, state, q_loss, pi_loss, alpha_loss, entropy = jax.lax.cond(
             step >= update_after, update_gradients, nop, rng, state,
             replay_buffer_state, step)
 
-        return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss
+        return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy
 
+    last_time = time.time()
     for step in range(0, FLAGS.steps, FLAGS.horizon):
-        state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss = train_loop(
+        state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy = train_loop(
             state, replay_buffer_state, rng, step)
 
         if FLAGS.debug_nan and has_nan(state.params):
             logging.fatal('Nan params, aborting')
 
         logging.info(
-            'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s',
+            'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s, entropy=%s, random=%s',
             step,
             q_loss,
             pi_loss,
@@ -486,6 +495,8 @@
             q_learning_rate(step),
             pi_learning_rate(step),
             jax.numpy.exp(state.params['logalpha']),
+            entropy,
+            step <= FLAGS.random_sample_steps,
         )
 
         run.track(
@@ -493,12 +504,14 @@
                 'q_loss': float(q_loss),
                 'pi_loss': float(pi_loss),
                 'alpha_loss': float(alpha_loss),
-                'alpha': float(jax.numpy.exp(state.params['logalpha']))
+                'alpha': float(jax.numpy.exp(state.params['logalpha'])),
+                'entropy': float(entropy),
             },
             step=step)
 
-        if step % 1000 == 0 and step > update_after:
+        if time.time() > last_time + 3.0 and step > update_after:
             # TODO(austin): Simulate a rollout and accumulate the reward.  How good are we doing?
             save_checkpoint(state, workdir)
+            last_time = time.time()
 
     return state
diff --git a/frc971/imu_fdcan/README.md b/frc971/imu_fdcan/README.md
index 8e82f73..7bf7328 100644
--- a/frc971/imu_fdcan/README.md
+++ b/frc971/imu_fdcan/README.md
@@ -23,7 +23,8 @@
     * The main code lives in [`Dual_IMU/Core/Src`](/Dual_IMU/Core/Src/). Make sure your changes happen inside sections marked `/* USER CODE BEGIN ... */` `/* USER CODE END ... */`. Code outside these markers will be overwritten by CubeIDE when generating code after changes to the `.ioc` file.
 3) Build + Run:
     * Option 1: Open CubeIDE GUI to build, debug, or run.
-    * Option 2: 
+    <!-- TODO(sindy): fix this build script -->
+    * Option 2 (DO NOT USE; NOT SAFE):
         1) SSH onto the build server. 
         2) Run `bazel build -c opt --config=cortex-m4f-imu //frc971/imu_fdcan/Dual_IMU/Core:main.elf`. The output .elf file should be in bazel-bin/frc971/imu_fdcan/Dual_IMU/Core.
         3) (If deploying code locally) Move file to local directory. For example: `scp <username>@build.frc971.org:<path/to/main.elf> <local/path/to/save/file/`. A good spot to put this locally is ./Dual_IMU/Debug/.
diff --git a/frc971/wpilib/talonfx.cc b/frc971/wpilib/talonfx.cc
index 4338cfd..0da7603 100644
--- a/frc971/wpilib/talonfx.cc
+++ b/frc971/wpilib/talonfx.cc
@@ -58,11 +58,12 @@
 
 void TalonFX::WriteConfigs() {
   ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-  current_limits.StatorCurrentLimit = stator_current_limit_;
+  current_limits.StatorCurrentLimit =
+      units::current::ampere_t{stator_current_limit_};
   current_limits.StatorCurrentLimitEnable = true;
-  current_limits.SupplyCurrentLimit = supply_current_limit_;
+  current_limits.SupplyCurrentLimit =
+      units::current::ampere_t{supply_current_limit_};
   current_limits.SupplyCurrentLimitEnable = true;
-  current_limits.SupplyTimeThreshold = 0.0;
 
   ctre::phoenix6::configs::MotorOutputConfigs output_configs;
   output_configs.NeutralMode = neutral_mode_;
diff --git a/frc971/zeroing/continuous_absolute_encoder.cc b/frc971/zeroing/continuous_absolute_encoder.cc
index 3381833..733d46d 100644
--- a/frc971/zeroing/continuous_absolute_encoder.cc
+++ b/frc971/zeroing/continuous_absolute_encoder.cc
@@ -165,4 +165,14 @@
   return builder.Finish();
 }
 
+void ContinuousAbsoluteEncoderZeroingEstimator::GetEstimatorState(
+    AbsoluteEncoderEstimatorStateStatic *fbs) const {
+  errors_.ToStaticFlatbuffer(fbs->add_errors());
+
+  fbs->set_error(error_);
+  fbs->set_zeroed(zeroed_);
+  fbs->set_position(position_);
+  fbs->set_absolute_position(filtered_absolute_encoder_);
+}
+
 }  // namespace frc971::zeroing
diff --git a/frc971/zeroing/continuous_absolute_encoder.h b/frc971/zeroing/continuous_absolute_encoder.h
index 4994280..5e700ee 100644
--- a/frc971/zeroing/continuous_absolute_encoder.h
+++ b/frc971/zeroing/continuous_absolute_encoder.h
@@ -47,6 +47,8 @@
   virtual flatbuffers::Offset<State> GetEstimatorState(
       flatbuffers::FlatBufferBuilder *fbb) const override;
 
+  void GetEstimatorState(AbsoluteEncoderEstimatorStateStatic *fbs) const;
+
  private:
   struct PositionStruct {
     PositionStruct(const AbsolutePosition &position_buffer)
diff --git a/frc971/zeroing/zeroing.h b/frc971/zeroing/zeroing.h
index 68ef38f..01c2c61 100644
--- a/frc971/zeroing/zeroing.h
+++ b/frc971/zeroing/zeroing.h
@@ -10,7 +10,7 @@
 #include "flatbuffers/flatbuffers.h"
 
 #include "frc971/constants.h"
-#include "frc971/control_loops/control_loops_generated.h"
+#include "frc971/control_loops/control_loops_static.h"
 
 // TODO(pschrader): Flag an error if encoder index pulse is not n revolutions
 // away from the last one (i.e. got extra counts from noise, etc..)
diff --git a/scouting/scouting_test.cy.js b/scouting/scouting_test.cy.js
index 990e69f..7157880 100644
--- a/scouting/scouting_test.cy.js
+++ b/scouting/scouting_test.cy.js
@@ -119,7 +119,7 @@
       ' Ended Match; stageType: kHARMONY, trapNote: false, spotlight: false '
     );
   // Ensure that the penalties action is only submitted once.
-  cy.get('#review_data li').contains('Penalties').should('have.length', 1);
+  cy.get('#review_data li:contains("Penalties")').its('length').should('eq', 1);
 
   clickButton('Submit');
   headerShouldBe(teamNumber + ' Success ');
diff --git a/y2020/wpilib_interface.cc b/y2020/wpilib_interface.cc
index 2213c16..8c8b480 100644
--- a/y2020/wpilib_interface.cc
+++ b/y2020/wpilib_interface.cc
@@ -123,9 +123,11 @@
 void WriteConfigs(ctre::phoenix6::hardware::TalonFX *talon,
                   double stator_current_limit, double supply_current_limit) {
   ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-  current_limits.StatorCurrentLimit = stator_current_limit;
+  current_limits.StatorCurrentLimit =
+      units::current::ampere_t{stator_current_limit};
   current_limits.StatorCurrentLimitEnable = true;
-  current_limits.SupplyCurrentLimit = supply_current_limit;
+  current_limits.SupplyCurrentLimit =
+      units::current::ampere_t{supply_current_limit};
   current_limits.SupplyCurrentLimitEnable = true;
 
   ctre::phoenix6::configs::TalonFXConfiguration configuration;
@@ -454,7 +456,8 @@
       ::std::unique_ptr<::ctre::phoenix6::hardware::TalonFX> t) {
     climber_falcon_ = ::std::move(t);
     ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-    current_limits.SupplyCurrentLimit = Values::kClimberSupplyCurrentLimit();
+    current_limits.SupplyCurrentLimit =
+        units::current::ampere_t{Values::kClimberSupplyCurrentLimit()};
     current_limits.SupplyCurrentLimitEnable = true;
 
     ctre::phoenix6::configs::TalonFXConfiguration configuration;
diff --git a/y2021_bot3/wpilib_interface.cc b/y2021_bot3/wpilib_interface.cc
index 4006736..f3c956e 100644
--- a/y2021_bot3/wpilib_interface.cc
+++ b/y2021_bot3/wpilib_interface.cc
@@ -119,9 +119,11 @@
 void WriteConfigs(ctre::phoenix6::hardware::TalonFX *talon,
                   double stator_current_limit, double supply_current_limit) {
   ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-  current_limits.StatorCurrentLimit = stator_current_limit;
+  current_limits.StatorCurrentLimit =
+      units::current::ampere_t{stator_current_limit};
   current_limits.StatorCurrentLimitEnable = true;
-  current_limits.SupplyCurrentLimit = supply_current_limit;
+  current_limits.SupplyCurrentLimit =
+      units::current::ampere_t{supply_current_limit};
   current_limits.SupplyCurrentLimitEnable = true;
 
   ctre::phoenix6::configs::TalonFXConfiguration configuration;
@@ -249,7 +251,8 @@
       ::std::unique_ptr<::ctre::phoenix6::hardware::TalonFX> t) {
     climber_falcon_ = ::std::move(t);
     ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-    current_limits.SupplyCurrentLimit = Values::kClimberSupplyCurrentLimit();
+    current_limits.SupplyCurrentLimit =
+        units::current::ampere_t{Values::kClimberSupplyCurrentLimit()};
     current_limits.SupplyCurrentLimitEnable = true;
 
     ctre::phoenix6::configs::TalonFXConfiguration configuration;
diff --git a/y2022/wpilib_interface.cc b/y2022/wpilib_interface.cc
index 1f82e91..469c93b 100644
--- a/y2022/wpilib_interface.cc
+++ b/y2022/wpilib_interface.cc
@@ -129,9 +129,11 @@
 void WriteConfigs(ctre::phoenix6::hardware::TalonFX *talon,
                   double stator_current_limit, double supply_current_limit) {
   ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-  current_limits.StatorCurrentLimit = stator_current_limit;
+  current_limits.StatorCurrentLimit =
+      units::current::ampere_t{stator_current_limit};
   current_limits.StatorCurrentLimitEnable = true;
-  current_limits.SupplyCurrentLimit = supply_current_limit;
+  current_limits.SupplyCurrentLimit =
+      units::current::ampere_t{supply_current_limit};
   current_limits.SupplyCurrentLimitEnable = true;
 
   ctre::phoenix6::configs::TalonFXConfiguration configuration;
@@ -483,10 +485,10 @@
     for (auto &falcon : {catapult_falcon_1_can_, catapult_falcon_2_can_}) {
       ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
       current_limits.StatorCurrentLimit =
-          Values::kIntakeRollerStatorCurrentLimit();
+          units::current::ampere_t{Values::kIntakeRollerStatorCurrentLimit()};
       current_limits.StatorCurrentLimitEnable = true;
       current_limits.SupplyCurrentLimit =
-          Values::kIntakeRollerSupplyCurrentLimit();
+          units::current::ampere_t{Values::kIntakeRollerSupplyCurrentLimit()};
       current_limits.SupplyCurrentLimitEnable = true;
 
       ctre::phoenix6::configs::TalonFXConfiguration configuration;
diff --git a/y2023/wpilib_interface.cc b/y2023/wpilib_interface.cc
index bc14311..536f7bb 100644
--- a/y2023/wpilib_interface.cc
+++ b/y2023/wpilib_interface.cc
@@ -166,11 +166,11 @@
     inverted_ = invert;
 
     ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-    current_limits.StatorCurrentLimit =
-        constants::Values::kDrivetrainStatorCurrentLimit();
+    current_limits.StatorCurrentLimit = units::current::ampere_t{
+        constants::Values::kDrivetrainStatorCurrentLimit()};
     current_limits.StatorCurrentLimitEnable = true;
-    current_limits.SupplyCurrentLimit =
-        constants::Values::kDrivetrainSupplyCurrentLimit();
+    current_limits.SupplyCurrentLimit = units::current::ampere_t{
+        constants::Values::kDrivetrainSupplyCurrentLimit()};
     current_limits.SupplyCurrentLimitEnable = true;
 
     ctre::phoenix6::configs::MotorOutputConfigs output_configs;
@@ -196,11 +196,11 @@
 
   void WriteRollerConfigs() {
     ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-    current_limits.StatorCurrentLimit =
-        constants::Values::kRollerStatorCurrentLimit();
+    current_limits.StatorCurrentLimit = units::current::ampere_t{
+        constants::Values::kRollerStatorCurrentLimit()};
     current_limits.StatorCurrentLimitEnable = true;
-    current_limits.SupplyCurrentLimit =
-        constants::Values::kRollerSupplyCurrentLimit();
+    current_limits.SupplyCurrentLimit = units::current::ampere_t{
+        constants::Values::kRollerSupplyCurrentLimit()};
     current_limits.SupplyCurrentLimitEnable = true;
 
     ctre::phoenix6::configs::MotorOutputConfigs output_configs;
diff --git a/y2023_bot3/wpilib_interface.cc b/y2023_bot3/wpilib_interface.cc
index d8813be..0e43e10 100644
--- a/y2023_bot3/wpilib_interface.cc
+++ b/y2023_bot3/wpilib_interface.cc
@@ -147,11 +147,11 @@
     inverted_ = invert;
 
     ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
-    current_limits.StatorCurrentLimit =
-        constants::Values::kDrivetrainStatorCurrentLimit();
+    current_limits.StatorCurrentLimit = units::current::ampere_t{
+        constants::Values::kDrivetrainStatorCurrentLimit()};
     current_limits.StatorCurrentLimitEnable = true;
-    current_limits.SupplyCurrentLimit =
-        constants::Values::kDrivetrainSupplyCurrentLimit();
+    current_limits.SupplyCurrentLimit = units::current::ampere_t{
+        constants::Values::kDrivetrainSupplyCurrentLimit()};
     current_limits.SupplyCurrentLimitEnable = true;
 
     ctre::phoenix6::configs::MotorOutputConfigs output_configs;
diff --git a/y2024_swerve/y2024_swerve_roborio.json b/y2024_swerve/y2024_swerve_roborio.json
index 705370d..adf528d 100644
--- a/y2024_swerve/y2024_swerve_roborio.json
+++ b/y2024_swerve/y2024_swerve_roborio.json
@@ -305,8 +305,7 @@
             "name": "wpilib_interface",
             "executable_name": "wpilib_interface",
             "args": [
-                "--nodie_on_malloc",
-                "--ctre_diag_server"
+              "--nodie_on_malloc"
             ],
             "nodes": [
                 "roborio"