Add code to train a simple turret controller

This uses Soft Actor-Critic with automatic temperature adjustment to
train a controller for last year's turret.  This is a stepping stone
towards training a more complicated controller for the swerve.  I still
need to parallelize it, play with the hyper-parameters, and teach it
how to go to a nonzero goal.
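
Assuming the usual bazel py_binary invocation, training can be kicked
off with something like

  bazel run //frc971/control_loops/swerve/velocity_controller:main -- \
      --workdir=/tmp/turret_sac

and the learned policy compared against LQR with

  bazel run //frc971/control_loops/swerve/velocity_controller:lqr_plot -- \
      --workdir=/tmp/turret_sac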

Change-Id: I1357b5fbf8549acac4ee0b94ef8f2636867c28ad
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 6e7ddf4..a45bf46 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -59,6 +59,7 @@
         Xdot = self.position_swerve_full_dynamics(X, U)
         Xdot_jax = jax_dynamics.full_dynamics(self.coefficients, X[:, 0], U[:,
                                                                             0])
+        self.assertEqual(Xdot.shape[0], Xdot_jax.shape[0])
 
         self.assertLess(
             numpy.linalg.norm(Xdot[:, 0] - Xdot_jax),
@@ -71,6 +72,9 @@
         velocity_physics_jax = jax_dynamics.velocity_dynamics(
             self.coefficients, X_velocity[:, 0], U[:, 0])
 
+        self.assertEqual(velocity_physics.shape[0],
+                         velocity_physics_jax.shape[0])
+
         self.assertLess(
             numpy.linalg.norm(velocity_physics[:, 0] - velocity_physics_jax),
             2e-8,
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
index 5e06420..f0e199b 100644
--- a/frc971/control_loops/swerve/velocity_controller/BUILD
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -22,3 +22,49 @@
         "@pip//jax",
     ],
 )
+
+py_binary(
+    name = "main",
+    srcs = [
+        "main.py",
+        "model.py",
+        "physics.py",
+        "train.py",
+    ],
+    deps = [
+        ":experience_buffer",
+        "//frc971/control_loops/swerve:jax_dynamics",
+        "@pip//absl_py",
+        "@pip//aim",
+        "@pip//clu",
+        "@pip//flashbax",
+        "@pip//flax",
+        "@pip//jax",
+        "@pip//jaxtyping",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//tensorflow",
+    ],
+)
+
+py_binary(
+    name = "lqr_plot",
+    srcs = [
+        "lqr_plot.py",
+        "model.py",
+        "physics.py",
+    ],
+    deps = [
+        ":experience_buffer",
+        "//frc971/control_loops/swerve:jax_dynamics",
+        "@pip//absl_py",
+        "@pip//flashbax",
+        "@pip//flax",
+        "@pip//jax",
+        "@pip//jaxtyping",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//tensorflow",
+    ],
+)
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
new file mode 100644
index 0000000..30f5630
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -0,0 +1,489 @@
+#!/usr/bin/env python3
+
+import os
+
+# Set up JAX to run on the CPU.
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach it where to find CUDA
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Use up to 20 cores
+    '--xla_force_host_platform_device_count=20',
+    # Dump XLA to /tmp/foo to aid debugging
+    '--xla_dump_to=/tmp/foo',
+    '--xla_gpu_enable_command_buffer='
+])
+
+os.environ['JAX_PLATFORMS'] = 'cpu'
+# Don't pre-allocate memory
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+import absl
+from absl import logging
+from matplotlib.animation import FuncAnimation
+import matplotlib
+import numpy
+import scipy
+import time
+
+matplotlib.use("gtk3agg")
+
+from matplotlib import pylab
+from matplotlib import pyplot
+from flax.training import checkpoints
+from frc971.control_loops.python import controls
+import tensorflow as tf
+import threading
+import collections
+
+import jax
+
+jax._src.deprecations.accelerate('tracer-hash')
+jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
+
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+# Container for the data.
+Data = collections.namedtuple('Data', [
+    't', 'X', 'X_lqr', 'U', 'U_lqr', 'cost', 'cost_lqr', 'q1_grid', 'q2_grid',
+    'q_grid', 'target_q_grid', 'lqr_grid', 'pi_grid_U', 'lqr_grid_U', 'grid_X',
+    'grid_Y', 'reward', 'reward_lqr', 'step'
+])
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+absl.flags.DEFINE_integer(
+    'horizon',
+    default=100,
+    help='Horizon to simulate',
+)
+
+absl.flags.DEFINE_float(
+    'alpha',
+    default=0.2,
+    help='Entropy temperature.  If negative, automatically solve for it.',
+)
+
+numpy.set_printoptions(linewidth=200, )
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+    return checkpoints.restore_checkpoint(workdir, state)
+
+
+# Hard-coded simulation parameters for the turret.
+dt = 0.005
+A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
+B = numpy.matrix([[0.00065992], [0.25610763]])
+
+Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
+R = numpy.matrix([[0.00694444]])
+
+# Compute the optimal LQR cost + controller.
+F, P = controls.dlqr(A, B, Q, R, optimal_cost_function=True)
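+# F is the LQR feedback gain and P the quadratic cost-to-go matrix; they are
+# used below as the baseline against which the learned policy and Q functions
+# are compared.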
+
+
+def generate_data(step=None):
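+    """Simulates the learned policy and the LQR baseline and samples the cost
+    and controller surfaces for plotting.
+
+    Returns None if the latest checkpoint is still at `step`.
+    """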
+    grid_X = numpy.arange(-1, 1, 0.1)
+    grid_Y = numpy.arange(-10, 10, 0.1)
+    grid_X, grid_Y = numpy.meshgrid(grid_X, grid_Y)
+    grid_X = jax.numpy.array(grid_X)
+    grid_Y = jax.numpy.array(grid_Y)
+    # Load the training state.
+    physics_constants = jax_dynamics.Coefficients()
+
+    rng = jax.random.key(0)
+    rng, init_rng = jax.random.split(rng)
+
+    state = create_train_state(init_rng, physics_constants,
+                               FLAGS.q_learning_rate, FLAGS.pi_learning_rate)
+
+    state = restore_checkpoint(state, FLAGS.workdir)
+    if step is not None and state.step == step:
+        return None
+
+    print('F:', F)
+    print('P:', P)
+
+    X = jax.numpy.array([1.0, 0.0])
+    X_lqr = X.copy()
+    goal = jax.numpy.array([0.0, 0.0])
+
+    logging.info('X: %s', X)
+    logging.info('goal: %s', goal)
+    logging.debug('params: %s', state.params)
+
+    # Compute the various cost surfaces for plotting.
+    def compute_q1(X, Y):
+        return state.q1_apply(
+            state.params,
+            unwrap_angles(jax.numpy.array([X, Y])),
+            jax.numpy.array([0.]),
+        )[0]
+
+    def compute_q2(X, Y):
+        return state.q2_apply(
+            state.params,
+            unwrap_angles(jax.numpy.array([X, Y])),
+            jax.numpy.array([0.]),
+        )[0]
+
+    def lqr_cost(X, Y):
+        x = jax.numpy.array([X, Y])
+        return -x.T @ jax.numpy.array(P) @ x
+
+    def compute_q(params, x, y):
+        X = unwrap_angles(jax.numpy.array([x, y]))
+
+        return jax.numpy.minimum(
+            state.q1_apply(
+                params,
+                X,
+                jax.numpy.array([0.]),
+            )[0],
+            state.q2_apply(
+                params,
+                X,
+                jax.numpy.array([0.]),
+            )[0])
+
+    cost_grid1 = jax.vmap(jax.vmap(compute_q1))(grid_X, grid_Y)
+    cost_grid2 = jax.vmap(jax.vmap(compute_q2))(grid_X, grid_Y)
+    cost_grid = jax.vmap(jax.vmap(lambda x, y: compute_q(state.params, x, y)))(
+        grid_X, grid_Y)
+    target_cost_grid = jax.vmap(
+        jax.vmap(lambda x, y: compute_q(state.target_params, x, y)))(grid_X,
+                                                                     grid_Y)
+    lqr_cost_grid = jax.vmap(jax.vmap(lqr_cost))(grid_X, grid_Y)
+
+    # TODO(austin): Stuff to figure out:
+    # 3: Make it converge faster.  Use both GPUs better?
+    # 4: Can we feed in a reference position and get it to learn to stabilize there?
+
+    # Now compute the two controller surfaces.
+    def compute_lqr_U(X, Y):
+        x = jax.numpy.array([X, Y])
+        return (-jax.numpy.array(F.reshape((2, ))) @ x)[0]
+
+    def compute_pi_U(X, Y):
+        x = jax.numpy.array([X, Y])
+        U, _, _, _ = state.pi_apply(rng,
+                                    state.params,
+                                    observation=unwrap_angles(x),
+                                    deterministic=True)
+        return U[0]
+
+    lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
+    pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
+
+    logging.info('Finished cost')
+
+    # Now simulate the robot, accumulating up things to plot.
+    def loop(i, val):
+        X, X_lqr, data, params = val
+        t = data.t.at[i].set(i * dt)
+
+        U, _, _, _ = state.pi_apply(rng,
+                                    params,
+                                    observation=unwrap_angles(X),
+                                    deterministic=True)
+        U_lqr = F @ (goal - X_lqr)
+
+        cost = jax.numpy.minimum(state.q1_apply(params, unwrap_angles(X), U),
+                                 state.q2_apply(params, unwrap_angles(X), U))
+
+        U_plot = data.U.at[i, :].set(U)
+        U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
+        X_plot = data.X.at[i, :].set(X)
+        X_lqr_plot = data.X_lqr.at[i, :].set(X_lqr)
+        cost_plot = data.cost.at[i, :].set(cost)
+        cost_lqr_plot = data.cost_lqr.at[i, :].set(lqr_cost(*X_lqr))
+
+        X = A @ X + B @ U
+        X_lqr = A @ X_lqr + B @ U_lqr
+
+        reward = data.reward - physics.state_cost(X, U, goal)
+        reward_lqr = data.reward_lqr - physics.state_cost(X_lqr, U_lqr, goal)
+
+        return X, X_lqr, data._replace(
+            t=t,
+            U=U_plot,
+            U_lqr=U_lqr_plot,
+            X=X_plot,
+            X_lqr=X_lqr_plot,
+            cost=cost_plot,
+            cost_lqr=cost_lqr_plot,
+            reward=reward,
+            reward_lqr=reward_lqr,
+        ), params
+
+    # Do it.
+    @jax.jit
+    def integrate(data, X, X_lqr, params):
+        return jax.lax.fori_loop(0, FLAGS.horizon, loop,
+                                 (X, X_lqr, data, params))
+
+    X, X_lqr, data, params = integrate(
+        Data(
+            t=jax.numpy.zeros((FLAGS.horizon, )),
+            X=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
+            X_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
+            U=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+            U_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+            cost=jax.numpy.zeros((FLAGS.horizon, 1)),
+            cost_lqr=jax.numpy.zeros((FLAGS.horizon, 1)),
+            q1_grid=cost_grid1,
+            q2_grid=cost_grid2,
+            q_grid=cost_grid,
+            target_q_grid=target_cost_grid,
+            lqr_grid=lqr_cost_grid,
+            pi_grid_U=pi_cost_U,
+            lqr_grid_U=lqr_cost_U,
+            grid_X=grid_X,
+            grid_Y=grid_Y,
+            reward=0.0,
+            reward_lqr=0.0,
+            step=state.step,
+        ), X, X_lqr, state.params)
+
+    logging.info('Finished integrating, reward of %f, lqr reward of %f',
+                 data.reward, data.reward_lqr)
+
+    # Convert back to numpy for plotting.
+    return Data(
+        t=numpy.array(data.t),
+        X=numpy.array(data.X),
+        X_lqr=numpy.array(data.X_lqr),
+        U=numpy.array(data.U),
+        U_lqr=numpy.array(data.U_lqr),
+        cost=numpy.array(data.cost),
+        cost_lqr=numpy.array(data.cost_lqr),
+        q1_grid=numpy.array(data.q1_grid),
+        q2_grid=numpy.array(data.q2_grid),
+        q_grid=numpy.array(data.q_grid),
+        target_q_grid=numpy.array(data.target_q_grid),
+        lqr_grid=numpy.array(data.lqr_grid),
+        pi_grid_U=numpy.array(data.pi_grid_U),
+        lqr_grid_U=numpy.array(data.lqr_grid_U),
+        grid_X=numpy.array(data.grid_X),
+        grid_Y=numpy.array(data.grid_Y),
+        reward=float(data.reward),
+        reward_lqr=float(data.reward_lqr),
+        step=data.step,
+    )
+
+
+class Plotter(object):
+
+    def __init__(self, data):
+        # Make all the plots and axis.
+        self.fig0, self.axs0 = pylab.subplots(3)
+        self.fig0.supxlabel('Seconds')
+
+        self.x, = self.axs0[0].plot([], [], label="x")
+        self.v, = self.axs0[0].plot([], [], label="v")
+        self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
+        self.v_lqr, = self.axs0[0].plot([], [], label="v_lqr")
+
+        self.axs0[0].set_ylabel('Velocity')
+        self.axs0[0].legend()
+
+        self.uaxis, = self.axs0[1].plot([], [], label="U")
+        self.uaxis_lqr, = self.axs0[1].plot([], [], label="U_lqr")
+
+        self.axs0[1].set_ylabel('Amps')
+        self.axs0[1].legend()
+
+        self.costaxis, = self.axs0[2].plot([], [], label="cost")
+        self.costlqraxis, = self.axs0[2].plot([], [], label="cost lqr")
+        self.axs0[2].set_ylabel('Cost')
+        self.axs0[2].legend()
+
+        self.costfig = pyplot.figure(figsize=pyplot.figaspect(0.5))
+        self.cost3dax = [
+            self.costfig.add_subplot(2, 3, 1, projection='3d'),
+            self.costfig.add_subplot(2, 3, 2, projection='3d'),
+            self.costfig.add_subplot(2, 3, 3, projection='3d'),
+            self.costfig.add_subplot(2, 3, 4, projection='3d'),
+            self.costfig.add_subplot(2, 3, 5, projection='3d'),
+            self.costfig.add_subplot(2, 3, 6, projection='3d'),
+        ]
+
+        self.Ufig = pyplot.figure(figsize=pyplot.figaspect(0.5))
+        self.Uax = [
+            self.Ufig.add_subplot(1, 3, 1, projection='3d'),
+            self.Ufig.add_subplot(1, 3, 2, projection='3d'),
+            self.Ufig.add_subplot(1, 3, 3, projection='3d'),
+        ]
+
+        self.last_trajectory_step = 0
+        self.last_cost_step = 0
+        self.last_U_step = 0
+
+    def update_trajectory_plot(self, data):
+        if data.step == self.last_trajectory_step:
+            return
+        self.last_trajectory_step = data.step
+        logging.info('Updating trajectory plots')
+
+        # Put data in the trajectory plots.
+        self.x.set_data(data.t, data.X[:, 0])
+        self.v.set_data(data.t, data.X[:, 1])
+        self.x_lqr.set_data(data.t, data.X_lqr[:, 0])
+        self.v_lqr.set_data(data.t, data.X_lqr[:, 1])
+
+        self.uaxis.set_data(data.t, data.U[:, 0])
+        self.uaxis_lqr.set_data(data.t, data.U_lqr[:, 0])
+        self.costaxis.set_data(data.t, data.cost[:, 0] - data.cost[-1, 0])
+        self.costlqraxis.set_data(data.t, data.cost_lqr[:, 0])
+
+        self.axs0[0].relim()
+        self.axs0[0].autoscale_view()
+
+        self.axs0[1].relim()
+        self.axs0[1].autoscale_view()
+
+        self.axs0[2].relim()
+        self.axs0[2].autoscale_view()
+
+        return self.x, self.v, self.uaxis, self.costaxis, self.costlqraxis
+
+    def update_cost_plot(self, data):
+        if data.step == self.last_cost_step:
+            return
+        logging.info('Updating cost plots')
+        self.last_cost_step = data.step
+        # Put data in the cost plots.
+        if hasattr(self, 'costsurf'):
+            for surf in self.costsurf:
+                surf.remove()
+
+        plots = [
+            (data.q1_grid, 'q1'),
+            (data.q2_grid, 'q2'),
+            (data.q_grid, 'q'),
+            (data.target_q_grid, 'target q'),
+            (data.lqr_grid, 'lqr'),
+            (data.q_grid - data.q_grid.max() - data.lqr_grid, 'error'),
+        ]
+
+        self.costsurf = [
+            self.cost3dax[i].plot_surface(
+                data.grid_X,
+                data.grid_Y,
+                Z,
+                cmap="magma",
+                label=label,
+            ) for i, (Z, label) in enumerate(plots)
+        ]
+
+        for axis in self.cost3dax:
+            axis.legend()
+            axis.relim()
+            axis.autoscale_view()
+
+        return self.costsurf
+
+    def update_U_plot(self, data):
+        if data.step == self.last_U_step:
+            return
+        self.last_U_step = data.step
+        logging.info('Updating U plots')
+        # Put data in the controller plots.
+        if hasattr(self, 'Usurf'):
+            for surf in self.Usurf:
+                surf.remove()
+
+        plots = [
+            (data.lqr_grid_U, 'lqr'),
+            (data.pi_grid_U, 'pi'),
+            ((data.lqr_grid_U - data.pi_grid_U), 'error'),
+        ]
+
+        self.Usurf = [
+            self.Uax[i].plot_surface(
+                data.grid_X,
+                data.grid_Y,
+                Z,
+                cmap="magma",
+                label=label,
+            ) for i, (Z, label) in enumerate(plots)
+        ]
+
+        for axis in self.Uax:
+            axis.legend()
+            axis.relim()
+            axis.autoscale_view()
+
+        return self.Usurf
+
+
+def main(argv):
+    if len(argv) > 1:
+        raise absl.app.UsageError('Too many command-line arguments.')
+
+    tf.config.experimental.set_visible_devices([], 'GPU')
+
+    lock = threading.Lock()
+
+    # Load data.
+    data = generate_data()
+
+    plotter = Plotter(data)
+
+    # Event for shutting down the thread.
+    shutdown = threading.Event()
+
+    # Thread to grab new data periodically.
+    def do_update():
+        while True:
+            nonlocal data
+
+            my_data = generate_data(data.step)
+
+            if my_data is not None:
+                with lock:
+                    data = my_data
+
+            if shutdown.wait(timeout=3):
+                return
+
+    update_thread = threading.Thread(target=do_update)
+    update_thread.start()
+
+    # Now, update each of the plots every second with the new data.
+    def update0(frame):
+        with lock:
+            my_data = data
+
+        return plotter.update_trajectory_plot(my_data)
+
+    def update1(frame):
+        with lock:
+            my_data = data
+
+        return plotter.update_cost_plot(my_data)
+
+    def update2(frame):
+        with lock:
+            my_data = data
+
+        return plotter.update_U_plot(my_data)
+
+    animation0 = FuncAnimation(plotter.fig0, update0, interval=1000)
+    animation1 = FuncAnimation(plotter.costfig, update1, interval=1000)
+    animation2 = FuncAnimation(plotter.Ufig, update2, interval=1000)
+
+    pyplot.show()
+
+    shutdown.set()
+    update_thread.join()
+
+
+if __name__ == '__main__':
+    absl.flags.mark_flags_as_required(['workdir'])
+    absl.app.run(main)
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
new file mode 100644
index 0000000..7ef687c
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -0,0 +1,72 @@
+import os
+
+# Set up XLA first.
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach it where to find CUDA
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Use up to 20 cores
+    #'--xla_force_host_platform_device_count=6',
+    # Dump XLA to /tmp/foo to aid debugging
+    #'--xla_dump_to=/tmp/foo',
+    #'--xla_gpu_enable_command_buffer='
+    # Dump sharding
+    #"--xla_dump_to=/tmp/foo",
+    #"--xla_dump_hlo_pass_re=spmd|propagation"
+])
+os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
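+# Pre-allocate 50% of the GPU memory up front.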
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
+
+from absl import app
+from absl import flags
+from absl import logging
+from clu import platform
+import jax
+import tensorflow as tf
+from frc971.control_loops.swerve import jax_dynamics
+
+jax._src.deprecations.accelerate('tracer-hash')
+# Enable the compilation cache
+jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
+jax.config.update('jax_threefry_partitionable', True)
+
+import train
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+
+def main(argv):
+    if len(argv) > 1:
+        raise app.UsageError('Too many command-line arguments.')
+
+    # Hide any GPUs from TensorFlow. Otherwise it might reserve memory.
+    tf.config.experimental.set_visible_devices([], 'GPU')
+
+    logging.info('JAX process: %d / %d', jax.process_index(),
+                 jax.process_count())
+    logging.info('JAX local devices: %r', jax.local_devices())
+
+    # Add a note so that we can tell which task is which JAX host.
+    # (Depending on the platform task 0 is not guaranteed to be host 0)
+    platform.work_unit().set_task_status(
+        f'process_index: {jax.process_index()}, '
+        f'process_count: {jax.process_count()}')
+    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
+                                         FLAGS.workdir, 'workdir')
+
+    logging.info(
+        'Visualize results with: bazel run -c opt @pip_deps_tensorboard//:rules_python_wheel_entry_point_tensorboard -- --logdir %s',
+        FLAGS.workdir,
+    )
+
+    physics_constants = jax_dynamics.Coefficients()
+    state = train.train(FLAGS.workdir, physics_constants)
+
+
+if __name__ == '__main__':
+    flags.mark_flags_as_required(['workdir'])
+    app.run(main)
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
new file mode 100644
index 0000000..ae67efe
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -0,0 +1,447 @@
+from __future__ import annotations
+import flax
+import flashbax
+import typing
+from typing import Any, Callable
+import dataclasses
+import absl
+from absl import logging
+import numpy
+import jax
+from flax import linen as nn
+from jaxtyping import Array, ArrayLike
+import optax
+from flax.training import train_state
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve.velocity_controller import physics
+from frc971.control_loops.swerve.velocity_controller import experience_buffer
+
+from flax.typing import PRNGKey
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_integer(
+    'num_agents',
+    default=10,
+    help='Number of parallel simulation agents.',
+)
+
+absl.flags.DEFINE_float(
+    'q_learning_rate',
+    default=0.002,
+    help='Q training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+    'final_q_learning_rate',
+    default=0.00002,
+    help='Final Q training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+    'pi_learning_rate',
+    default=0.002,
+    help='Pi training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+    'final_pi_learning_rate',
+    default=0.00002,
+    help='Final pi training learning rate.',
+)
+
+absl.flags.DEFINE_integer(
+    'replay_size',
+    default=2000000,
+    help='Number of steps to save in our replay buffer',
+)
+
+absl.flags.DEFINE_integer(
+    'batch_size',
+    default=10000,
+    help='Batch size for learning Q and pi',
+)
+
+HIDDEN_WEIGHTS = 256
+
+LOG_STD_MIN = -20
+LOG_STD_MAX = 2
+
+
+def gaussian_likelihood(noise: ArrayLike, log_std: ArrayLike):
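+    """Log probability of a diagonal Gaussian sample.
+
+    `noise` is the standardized sample, (x - mu) / std, and `log_std` is the
+    log of the standard deviation.  The result is summed over the action
+    dimensions.
+    """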
+    pre_sum = -0.5 * (noise**2 + 2 * log_std + jax.numpy.log(2 * jax.numpy.pi))
+
+    if len(pre_sum.shape) > 1:
+        return jax.numpy.sum(pre_sum, keepdims=True, axis=1)
+    else:
+        return jax.numpy.sum(pre_sum, keepdims=True)
+
+
+class SquashedGaussianMLPActor(nn.Module):
+    """Actor model."""
+
+    # Number of dimensions in the action space
+    action_space: int = 8
+
+    hidden_sizes: list[int] = dataclasses.field(
+        default_factory=lambda: [HIDDEN_WEIGHTS, HIDDEN_WEIGHTS])
+
+    # Activation function
+    activation: Callable = nn.activation.relu
+
+    # Max action we can apply
+    action_limit: float = 30.0
+
+    @nn.compact
+    def __call__(self,
+                 observations: ArrayLike,
+                 deterministic: bool = False,
+                 rng: PRNGKey | None = None):
+        x = observations
+        # Apply the dense layers
+        for i, hidden_size in enumerate(self.hidden_sizes):
+            x = nn.Dense(
+                name=f'denselayer{i}',
+                features=hidden_size,
+            )(x)
+            x = self.activation(x)
+
+        # Average policy is a dense function of the space.
+        mu = nn.Dense(
+            features=self.action_space,
+            name='mu',
+            kernel_init=nn.initializers.zeros,
+        )(x)
+
+        log_std_layer = nn.Dense(features=self.action_space,
+                                 name='log_std_layer')(x)
+
+        # Clip the log of the standard deviation in a soft manner.
+        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (
+            flax.linen.activation.tanh(log_std_layer) + 1)
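+        # (tanh maps the layer output into (-1, 1), which is then rescaled to
+        # the (LOG_STD_MIN, LOG_STD_MAX) range.)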
+
+        std = jax.numpy.exp(log_std)
+
+        if rng is None:
+            rng = self.make_rng('pi')
+
+        # Grab a random sample
+        random_sample = jax.random.normal(rng, shape=std.shape)
+
+        if deterministic:
+            # We are testing the optimal policy, just use the mean.
+            pi_action = mu
+        else:
+            # Use the reparameterization trick.  Scale and shift the unit
+            # Gaussian sample so the result stays differentiable with respect
+            # to mu and std.
+            pi_action = random_sample * std + mu
+
+        logp_pi = gaussian_likelihood(random_sample, log_std)
+        # Adjustment to log prob
+        # NOTE: This formula is a little bit magic. To get an understanding of where it
+        # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
+        # appendix C. This is a more numerically-stable equivalent to Eq 21.
+        delta = (2.0 * (jax.numpy.log(2.0) - pi_action -
+                        flax.linen.softplus(-2.0 * pi_action)))
+
+        if len(delta.shape) > 1:
+            delta = jax.numpy.sum(delta, keepdims=True, axis=1)
+        else:
+            delta = jax.numpy.sum(delta, keepdims=True)
+
+        logp_pi = logp_pi - delta
+
+        # Now, saturate the action to the limit using tanh
+        pi_action = self.action_limit * flax.linen.activation.tanh(pi_action)
+
+        return pi_action, logp_pi, self.action_limit * std, random_sample
+
+
+class MLPQFunction(nn.Module):
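+    """Q function model: estimates Q(observation, action) with a dense network."""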
+
+    # Number and size of the hidden layers.
+    hidden_sizes: list[int] = dataclasses.field(
+        default_factory=lambda: [HIDDEN_WEIGHTS, HIDDEN_WEIGHTS])
+
+    activation: Callable = nn.activation.tanh
+
+    @nn.compact
+    def __call__(self, observations, actions):
+        # Estimate Q with a simple multi-layer dense network.
+        x = jax.numpy.hstack((observations, actions))
+        for i, hidden_size in enumerate(self.hidden_sizes):
+            x = nn.Dense(
+                name=f'denselayer{i}',
+                features=hidden_size,
+            )(x)
+            x = self.activation(x)
+
+        x = nn.Dense(name=f'q', features=1,
+                     kernel_init=nn.initializers.zeros)(x)
+
+        return x
+
+
+class TrainState(flax.struct.PyTreeNode):
+    physics_constants: jax_dynamics.CoefficientsType = flax.struct.field(
+        pytree_node=False)
+
+    step: int | jax.Array = flax.struct.field(pytree_node=True)
+    substep: int | jax.Array = flax.struct.field(pytree_node=True)
+
+    params: flax.core.FrozenDict[str, typing.Any] = flax.struct.field(
+        pytree_node=True)
+
+    target_params: flax.core.FrozenDict[str, typing.Any] = flax.struct.field(
+        pytree_node=True)
+
+    pi_apply_fn: Callable = flax.struct.field(pytree_node=False)
+    q1_apply_fn: Callable = flax.struct.field(pytree_node=False)
+    q2_apply_fn: Callable = flax.struct.field(pytree_node=False)
+
+    pi_tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
+    pi_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+    q_tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
+    q_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+
+    alpha_tx: optax.GradientTransformation = flax.struct.field(
+        pytree_node=False)
+    alpha_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+
+    target_entropy: float = flax.struct.field(pytree_node=True)
+
+    mesh: Mesh = flax.struct.field(pytree_node=False)
+    sharding: NamedSharding = flax.struct.field(pytree_node=False)
+    replicated_sharding: NamedSharding = flax.struct.field(pytree_node=False)
+
+    replay_buffer: flashbax.buffers.trajectory_buffer.TrajectoryBuffer = flax.struct.field(
+        pytree_node=False)
+
+    def pi_apply(self,
+                 rng: PRNGKey,
+                 params,
+                 observation,
+                 deterministic: bool = False):
+        return self.pi_apply_fn({'params': params['pi']},
+                                physics.unwrap_angles(observation),
+                                deterministic=deterministic,
+                                rngs={'pi': rng})
+
+    def q1_apply(self, params, observation, action):
+        return self.q1_apply_fn({'params': params['q1']},
+                                physics.unwrap_angles(observation), action)
+
+    def q2_apply(self, params, observation, action):
+        return self.q2_apply_fn({'params': params['q2']},
+                                physics.unwrap_angles(observation), action)
+
+    def pi_apply_gradients(self, step, grads):
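+        """Applies one optimizer step to the policy parameters only."""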
+        updates, new_pi_opt_state = self.pi_tx.update(grads, self.pi_opt_state,
+                                                      self.params)
+        new_params = optax.apply_updates(self.params, updates)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            params=new_params,
+            pi_opt_state=new_pi_opt_state,
+        )
+
+    def q_apply_gradients(self, step, grads):
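+        """Applies one optimizer step to the Q1 and Q2 parameters only."""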
+        updates, new_q_opt_state = self.q_tx.update(grads, self.q_opt_state,
+                                                    self.params)
+        new_params = optax.apply_updates(self.params, updates)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            params=new_params,
+            q_opt_state=new_q_opt_state,
+        )
+
+    def target_apply_gradients(self, step):
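+        """Polyak-averages the current parameters into the target network."""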
+        new_target_params = optax.incremental_update(self.params,
+                                                     self.target_params,
+                                                     1 - FLAGS.polyak)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            target_params=new_target_params,
+        )
+
+    def alpha_apply_gradients(self, step, grads):
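+        """Applies one optimizer step to the entropy temperature (logalpha)."""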
+        updates, new_alpha_opt_state = self.alpha_tx.update(
+            grads, self.alpha_opt_state, self.params)
+        new_params = optax.apply_updates(self.params, updates)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            params=new_params,
+            alpha_opt_state=new_alpha_opt_state,
+        )
+
+    def update_step(self, step):
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+        )
+
+    @classmethod
+    def create(cls, *, physics_constants: jax_dynamics.CoefficientsType,
+               params, pi_tx, q_tx, alpha_tx, pi_apply_fn, q1_apply_fn,
+               q2_apply_fn, **kwargs):
+        """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
+
+        pi_tx = optax.multi_transform(
+            {
+                'train': pi_tx,
+                'freeze': optax.set_to_zero()
+            },
+            param_labels=flax.traverse_util.path_aware_map(
+                lambda path, x: 'train'
+                if path[0] == 'pi' else 'freeze', params),
+        )
+        pi_opt_state = pi_tx.init(params)
+
+        q_tx = optax.multi_transform(
+            {
+                'train': q_tx,
+                'freeze': optax.set_to_zero()
+            },
+            param_labels=flax.traverse_util.path_aware_map(
+                lambda path, x: 'train'
+                if path[0] == 'q1' or path[0] == 'q2' else 'freeze', params),
+        )
+        q_opt_state = q_tx.init(params)
+
+        alpha_tx = optax.multi_transform(
+            {
+                'train': alpha_tx,
+                'freeze': optax.set_to_zero()
+            },
+            param_labels=flax.traverse_util.path_aware_map(
+                lambda path, x: 'train'
+                if path[0] == 'logalpha' else 'freeze', params),
+        )
+        alpha_opt_state = alpha_tx.init(params)
+
+        mesh = Mesh(
+            devices=mesh_utils.create_device_mesh(len(jax.devices())),
+            axis_names=('batch', ),
+        )
+        print('Devices:', jax.devices())
+        sharding = NamedSharding(mesh, PartitionSpec('batch'))
+        replicated_sharding = NamedSharding(mesh, PartitionSpec())
+
+        replay_buffer = experience_buffer.make_experience_buffer(
+            num_agents=FLAGS.num_agents,
+            sample_batch_size=FLAGS.batch_size,
+            length=FLAGS.replay_size)
+
+        result = cls(
+            physics_constants=physics_constants,
+            step=0,
+            substep=0,
+            params=params,
+            target_params=params,
+            q1_apply_fn=q1_apply_fn,
+            q2_apply_fn=q2_apply_fn,
+            pi_apply_fn=pi_apply_fn,
+            pi_tx=pi_tx,
+            pi_opt_state=pi_opt_state,
+            q_tx=q_tx,
+            q_opt_state=q_opt_state,
+            alpha_tx=alpha_tx,
+            alpha_opt_state=alpha_opt_state,
+            target_entropy=-physics.NUM_STATES,
+            mesh=mesh,
+            sharding=sharding,
+            replicated_sharding=replicated_sharding,
+            replay_buffer=replay_buffer,
+        )
+
+        return jax.device_put(result, replicated_sharding)
+
+
+def create_train_state(rng, physics_constants: jax_dynamics.CoefficientsType,
+                       q_learning_rate, pi_learning_rate):
+    """Creates initial `TrainState`."""
+    pi = SquashedGaussianMLPActor(action_space=physics.NUM_OUTPUTS)
+    # We want q1 and q2 to have different network architectures so they pick up different things.
+    q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
+    q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
+
+    @jax.jit
+    def init_params(rng):
+        pi_rng, q1_rng, q2_rng = jax.random.split(rng, num=3)
+
+        pi_params = pi.init(
+            pi_rng,
+            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+        )['params']
+        q1_params = q1.init(
+            q1_rng,
+            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+            jax.numpy.ones([physics.NUM_OUTPUTS]),
+        )['params']
+        q2_params = q2.init(
+            q2_rng,
+            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+            jax.numpy.ones([physics.NUM_OUTPUTS]),
+        )['params']
+
+        if FLAGS.alpha < 0.0:
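+            # Alpha will be auto-tuned; start the temperature at 1.0.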
+            logalpha = 0.0
+        else:
+            logalpha = jax.numpy.log(FLAGS.alpha)
+
+        return {
+            'q1': q1_params,
+            'q2': q2_params,
+            'pi': pi_params,
+            'logalpha': logalpha,
+        }
+
+    pi_tx = optax.sgd(learning_rate=pi_learning_rate)
+    q_tx = optax.sgd(learning_rate=q_learning_rate)
+    alpha_tx = optax.sgd(learning_rate=q_learning_rate)
+
+    result = TrainState.create(
+        physics_constants=physics_constants,
+        params=init_params(rng),
+        pi_tx=pi_tx,
+        q_tx=q_tx,
+        alpha_tx=alpha_tx,
+        pi_apply_fn=pi.apply,
+        q1_apply_fn=q1.apply,
+        q2_apply_fn=q2.apply,
+    )
+
+    return result
+
+
+def create_learning_rate_fn(
+    base_learning_rate: float,
+    final_learning_rate: float,
+):
+    """Create learning rate schedule."""
+    warmup_fn = optax.linear_schedule(
+        init_value=base_learning_rate,
+        end_value=base_learning_rate,
+        transition_steps=FLAGS.warmup_steps,
+    )
+
+    cosine_epochs = max(FLAGS.steps - FLAGS.warmup_steps, 1)
+    cosine_fn = optax.cosine_decay_schedule(init_value=base_learning_rate,
+                                            decay_steps=cosine_epochs,
+                                            alpha=final_learning_rate)
+
+    schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, cosine_fn],
+        boundaries=[FLAGS.warmup_steps],
+    )
+    return schedule_fn
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
new file mode 100644
index 0000000..9454f14
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -0,0 +1,69 @@
+import jax
+from functools import partial
+from absl import flags
+from absl import logging
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve import jax_dynamics
+
+# FLAGS is needed by random_states below.
+FLAGS = flags.FLAGS
+
+
+@partial(jax.jit, static_argnums=[0])
+def xdot_physics(physics_constants: jax_dynamics.CoefficientsType, X, U):
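+    # Continuous-time linear model of last year's turret: X = [position,
+    # velocity], U is the single motor input.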
+    A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+    B = jax.numpy.array([[0.], [56.08534375]])
+
+    return A @ X + B @ U
+
+
+def unwrap_angles(X):
+    return X
+
+
+@jax.jit
+def swerve_cost(X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+                goal: jax.typing.ArrayLike):
+
+    Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
+    R = jax.numpy.array([[0.00694444]])
+
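+    # Note: goal is currently unused; the cost is measured about the origin.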
+    return (X).T @ Q @ (X) + U.T @ R @ U
+
+
+@partial(jax.jit, static_argnums=[0])
+def integrate_dynamics(physics_constants: jax_dynamics.CoefficientsType, X, U):
+    m = 2  # RK4 steps per interval
+    dt = 0.005 / m
+
+    def iteration(i, X):
+        weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0], [1.0, 2.0, 2.0, 1.0]])
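+        # Classic RK4: the first row holds the stage offsets and the second row
+        # the combination weights (divided by 6 below).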
+
+        def rk_iteration(i, val):
+            kx_previous, weighted_sum = val
+            kx = xdot_physics(physics_constants,
+                              X + dt * weights[0, i] * kx_previous, U)
+            weighted_sum += dt * weights[1, i] * kx / 6.0
+            return (kx, weighted_sum)
+
+        return jax.lax.fori_loop(lower=0,
+                                 upper=4,
+                                 body_fun=rk_iteration,
+                                 init_val=(X, X))[1]
+
+    return jax.lax.fori_loop(lower=0, upper=m, body_fun=iteration, init_val=X)
+
+
+state_cost = swerve_cost
+NUM_STATES = 2
+NUM_UNWRAPPED_STATES = 2
+NUM_OUTPUTS = 1
+ACTION_LIMIT = 10.0
+
+
+def random_states(rng, dimensions=None):
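+    """Samples random [position, velocity] states in [-1, 1] x [-10, 10]."""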
+    rng1, rng2 = jax.random.split(rng)
+
+    return jax.numpy.hstack(
+        (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
+                            minval=-1.0,
+                            maxval=1.0),
+         jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
+                            minval=-10.0,
+                            maxval=10.0)))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
new file mode 100644
index 0000000..5002eee
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -0,0 +1,511 @@
+import absl
+import time
+import collections
+from absl import logging
+import flax
+import matplotlib
+from matplotlib import pyplot
+from flax import linen as nn
+from flax.training import train_state
+from flax.training import checkpoints
+import jax
+import inspect
+import aim
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+import optax
+import numpy
+from frc971.control_loops.swerve import jax_dynamics
+from functools import partial
+import flashbax
+from jax.experimental.ode import odeint
+from frc971.control_loops.swerve import dynamics
+import orbax.checkpoint
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+numpy.set_printoptions(linewidth=200, )
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_integer(
+    'horizon',
+    default=25,
+    help='Rollout horizon to simulate while training',
+)
+
+absl.flags.DEFINE_integer(
+    'start_steps',
+    default=10000,
+    help='Number of steps to randomly sample before using the policy',
+)
+
+absl.flags.DEFINE_integer(
+    'steps',
+    default=400000,
+    help='Number of steps to run and train the agent',
+)
+
+absl.flags.DEFINE_integer(
+    'warmup_steps',
+    default=300000,
+    help='Number of steps to warm up training',
+)
+
+absl.flags.DEFINE_float(
+    'gamma',
+    default=0.999,
+    help='Future discount.',
+)
+
+absl.flags.DEFINE_float(
+    'alpha',
+    default=0.2,
+    help='Entropy temperature.  If negative, automatically solve for it.',
+)
+
+absl.flags.DEFINE_float(
+    'polyak',
+    default=0.995,
+    help='Time constant in polyak averaging for the target network.',
+)
+
+absl.flags.DEFINE_bool(
+    'debug_nan',
+    default=False,
+    help='If true, explode on any NaNs found, and print them.',
+)
+
+
+def random_actions(rng, dimensions=None):
+    """Produces a uniformly random action in the action space."""
+    return jax.random.uniform(
+        rng,
+        (dimensions or FLAGS.num_agents, physics.NUM_OUTPUTS),
+        minval=-physics.ACTION_LIMIT,
+        maxval=physics.ACTION_LIMIT,
+    )
+
+
+def save_checkpoint(state: TrainState, workdir: str):
+    """Saves a checkpoint in the workdir."""
+    # TODO(austin): use orbax directly.
+    step = int(state.step)
+    logging.info('Saving checkpoint step %d.', step)
+    checkpoints.save_checkpoint_multiprocess(workdir, state, step, keep=10)
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+    """Restores the latest checkpoint from the workdir."""
+    return checkpoints.restore_checkpoint(workdir, state)
+
+
+def has_nan(x):
+    if isinstance(x, jnp.ndarray):
+        return jnp.any(jnp.isnan(x))
+    else:
+        return jnp.any(
+            jax.numpy.array([has_nan(v)
+                             for v in jax.tree_util.tree_leaves(x)]))
+
+
+def print_nan(step, x):
+    if not FLAGS.debug_nan:
+        return
+
+    caller = inspect.getframeinfo(inspect.stack()[1][0])
+    # TODO(austin): It is helpful to sometimes start printing at a certain step.
+    jax.lax.cond(
+        has_nan(x), lambda: jax.debug.print(caller.filename +
+                                            ':{l} step {s} isnan(X) -> {x}',
+                                            l=caller.lineno,
+                                            s=step,
+                                            x=x), lambda: None)
+
+
+@jax.jit
+def compute_loss_q(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
+    """Computes the Soft Actor-Critic loss for Q1 and Q2."""
+    observations1 = data['observations1']
+    actions = data['actions']
+    rewards = data['rewards']
+    observations2 = data['observations2']
+
+    # Compute the ending actions from the current network.
+    action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
+                                             params=params,
+                                             observation=observations2)
+
+    # Compute target network Q values
+    q1_pi_target = state.q1_apply(state.target_params, observations2, action2)
+    q2_pi_target = state.q2_apply(state.target_params, observations2, action2)
+    q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
+
+    alpha = jax.numpy.exp(params['logalpha'])
+
+    # Now we can compute the Bellman backup
+    bellman_backup = jax.lax.stop_gradient(rewards + FLAGS.gamma *
+                                           (q_pi_target - alpha * logp_pi2))
+
+    # Compute the starting Q values from the Q network being optimized.
+    q1 = state.q1_apply(params, observations1, actions)
+    q2 = state.q2_apply(params, observations1, actions)
+
+    # Mean squared error loss against Bellman backup
+    q1_loss = ((q1 - bellman_backup)**2).mean()
+    q2_loss = ((q2 - bellman_backup)**2).mean()
+    return q1_loss + q2_loss
+
+
+@jax.jit
+def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
+                           data: ArrayLike):
+
+    def bound_compute_loss_q(rng, data):
+        return compute_loss_q(state, rng, params, data)
+
+    return jax.vmap(bound_compute_loss_q)(
+        jax.random.split(rng, FLAGS.num_agents),
+        data,
+    ).mean()
+
+
+@jax.jit
+def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
+    """Computes the Soft Actor-Critic loss for pi."""
+    observations1 = data['observations1']
+    # TODO(austin): We've got differentiable policy and differentiable physics.  Can we use those here?  Have Q learn the future, not the current step?
+
+    # Compute the action
+    pi, logp_pi, _, _ = state.pi_apply(rng=rng,
+                                       params=params,
+                                       observation=observations1)
+    q1_pi = state.q1_apply(jax.lax.stop_gradient(params), observations1, pi)
+    q2_pi = state.q2_apply(jax.lax.stop_gradient(params), observations1, pi)
+
+    # And compute the Q of that action.
+    q_pi = jax.numpy.minimum(q1_pi, q2_pi)
+
+    alpha = jax.lax.stop_gradient(jax.numpy.exp(params['logalpha']))
+
+    # Compute the entropy-regularized policy loss
+    return (alpha * logp_pi - q_pi).mean()
+
+
+@jax.jit
+def compute_batched_loss_pi(state: TrainState, rng: PRNGKey, params,
+                            data: ArrayLike):
+
+    def bound_compute_loss_pi(rng, data):
+        return compute_loss_pi(state, rng, params, data)
+
+    return jax.vmap(bound_compute_loss_pi)(
+        jax.random.split(rng, FLAGS.num_agents),
+        data,
+    ).mean()
+
+
+@jax.jit
+def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
+                       data: ArrayLike):
+    """Computes the Soft Actor-Critic loss for alpha."""
+    observations1 = data['observations1']
+    pi, logp_pi, _, _ = jax.lax.stop_gradient(
+        state.pi_apply(rng=rng, params=params, observation=observations1))
+
+    return (-jax.numpy.exp(params['logalpha']) *
+            (logp_pi + state.target_entropy)).mean()
+
+
+@jax.jit
+def compute_batched_loss_alpha(state: TrainState, rng: PRNGKey, params,
+                               data: ArrayLike):
+
+    def bound_compute_loss_alpha(rng, data):
+        return compute_loss_alpha(state, rng, params, data)
+
+    return jax.vmap(bound_compute_loss_alpha)(
+        jax.random.split(rng, FLAGS.num_agents),
+        data,
+    ).mean()
+
+
+@jax.jit
+def train_step(state: TrainState, data, action_data, update_rng: PRNGKey,
+               step: int) -> TrainState:
+    """Updates the parameters for Q, Pi, target Q, and alpha."""
+    update_rng, q_grad_rng = jax.random.split(update_rng)
+    print_nan(step, data)
+
+    # Update Q
+    q_grad_fn = jax.value_and_grad(
+        lambda params: compute_batched_loss_q(state, q_grad_rng, params, data))
+    q_loss, q_grads = q_grad_fn(state.params)
+    print_nan(step, q_loss)
+    print_nan(step, q_grads)
+
+    state = state.q_apply_gradients(step=step, grads=q_grads)
+
+    update_rng, pi_grad_rng = jax.random.split(update_rng)
+
+    # Update pi
+    pi_grad_fn = jax.value_and_grad(lambda params: compute_batched_loss_pi(
+        state, pi_grad_rng, params, action_data))
+    pi_loss, pi_grads = pi_grad_fn(state.params)
+
+    print_nan(step, pi_loss)
+    print_nan(step, pi_grads)
+
+    state = state.pi_apply_gradients(step=step, grads=pi_grads)
+
+    update_rng, alpha_grad_rng = jax.random.split(update_rng)
+
+    if FLAGS.alpha < 0.0:
+        # Update alpha
+        alpha_grad_fn = jax.value_and_grad(
+            lambda params: compute_batched_loss_alpha(state, alpha_grad_rng,
+                                                      params, data))
+        alpha_loss, alpha_grads = alpha_grad_fn(state.params)
+        print_nan(step, alpha_loss)
+        print_nan(step, alpha_grads)
+        state = state.alpha_apply_gradients(step=step, grads=alpha_grads)
+    else:
+        alpha_loss = 0.0
+
+    return state, q_loss, pi_loss, alpha_loss
+
+
+@jax.jit
+def collect_experience(state: TrainState, replay_buffer_state, R: ArrayLike,
+                       step_rng: PRNGKey, step):
+    """Collects experience by simulating."""
+    pi_rng = jax.random.fold_in(step_rng, step)
+    pi_rng, initialization_rng = jax.random.split(pi_rng)
+
+    observation = jax.lax.with_sharding_constraint(
+        physics.random_states(initialization_rng, FLAGS.num_agents),
+        state.sharding)
+
+    def loop(i, val):
+        """Runs 1 step of our simulation."""
+        observation, pi_rng, replay_buffer_state = val
+        pi_rng, action_rng = jax.random.split(pi_rng)
+        logging.info('Observation shape: %s', observation.shape)
+
+        def true_fn(i):
+            # We are at the beginning of the process, pick a random action.
+            return random_actions(action_rng, FLAGS.num_agents)
+
+        def false_fn(i):
+            # We are past the beginning of the process, use the trained network.
+            pi_action, logp_pi, std, random_sample = state.pi_apply(
+                rng=action_rng,
+                params=state.params,
+                observation=observation,
+                deterministic=False)
+            return pi_action
+
+        pi_action = jax.lax.cond(
+            step <= FLAGS.start_steps,
+            true_fn,
+            false_fn,
+            i,
+        )
+
+        # Compute the destination state.
+        observation2 = jax.vmap(lambda o, pi: physics.integrate_dynamics(
+            state.physics_constants, o, pi),
+                                in_axes=(0, 0))(observation, pi_action)
+
+        # Soft Actor-Critic is designed to maximize reward while LQR minimizes
+        # cost.  Nothing assumes a particular sign for the reward, so use the
+        # negative of the cost as the reward.
+        reward = (-jax.vmap(state_cost)(observation2, pi_action, R)).reshape(
+            (FLAGS.num_agents, 1))
+
+        logging.info('Observation shape: %s', observation.shape)
+        logging.info('Observation2 shape: %s', observation2.shape)
+        logging.info('Action shape: %s', pi_action.shape)
+        replay_buffer_state = state.replay_buffer.add(
+            replay_buffer_state, {
+                'observations1': observation,
+                'observations2': observation2,
+                'actions': pi_action,
+                'rewards': reward,
+            })
+
+        return observation2, pi_rng, replay_buffer_state
+
+    # Run 1 horizon of simulation
+    final_observation, final_pi_rng, final_replay_buffer_state = jax.lax.fori_loop(
+        0, FLAGS.horizon + 1, loop, (observation, pi_rng, replay_buffer_state))
+
+    return state, final_replay_buffer_state
+
+
+@jax.jit
+def update_gradients(rng: PRNGKey, state: TrainState, replay_buffer_state,
+                     step: int):
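+    """Samples the replay buffer and runs FLAGS.horizon + 1 gradient updates,
+    then Polyak-updates the target network."""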
+    rng, sample_rng = jax.random.split(rng)
+
+    action_data = state.replay_buffer.sample(replay_buffer_state, sample_rng)
+
+    def update_iteration(i, val):
+        rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = val
+        rng, sample_rng, update_rng = jax.random.split(rng, 3)
+
+        batch = state.replay_buffer.sample(replay_buffer_state, sample_rng)
+
+        print_nan(i, replay_buffer_state)
+        print_nan(i, batch)
+
+        state, q_loss, pi_loss, alpha_loss = train_step(
+            state,
+            data=batch.experience,
+            action_data=batch.experience,
+            update_rng=update_rng,
+            step=i)
+
+        return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data
+
+    rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = jax.lax.fori_loop(
+        step, step + FLAGS.horizon + 1, update_iteration,
+        (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, action_data))
+
+    state = state.target_apply_gradients(step=state.step)
+
+    return rng, state, q_loss, pi_loss, alpha_loss
+
+
+def train(
+    workdir: str, physics_constants: jax_dynamics.CoefficientsType
+) -> train_state.TrainState:
+    """Trains a Soft Actor-Critic controller."""
+    rng = jax.random.key(0)
+    rng, r_rng = jax.random.split(rng)
+
+    run = aim.Run(repo='aim://127.0.0.1:53800')
+
+    run['hparams'] = {
+        'q_learning_rate': FLAGS.q_learning_rate,
+        'pi_learning_rate': FLAGS.pi_learning_rate,
+        'batch_size': FLAGS.batch_size,
+        'horizon': FLAGS.horizon,
+        'warmup_steps': FLAGS.warmup_steps,
+    }
+
+    # Setup TrainState
+    rng, init_rng = jax.random.split(rng)
+    q_learning_rate = create_learning_rate_fn(
+        base_learning_rate=FLAGS.q_learning_rate,
+        final_learning_rate=FLAGS.final_q_learning_rate)
+    pi_learning_rate = create_learning_rate_fn(
+        base_learning_rate=FLAGS.pi_learning_rate,
+        final_learning_rate=FLAGS.final_pi_learning_rate)
+    state = create_train_state(
+        init_rng,
+        physics_constants,
+        q_learning_rate=q_learning_rate,
+        pi_learning_rate=pi_learning_rate,
+    )
+    state = restore_checkpoint(state, workdir)
+
+    state_sharding = nn.get_sharding(state, state.mesh)
+    logging.info(state_sharding)
+
+    # TODO(austin): Let the goal change.
+    R = jax.numpy.hstack((
+        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=0.9, maxval=1.1),
+        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=-0.1,
+                           maxval=0.1),
+        jax.numpy.zeros((FLAGS.num_agents, 1)),
+    ))
+
+    replay_buffer_state = state.replay_buffer.init({
+        'observations1':
+        jax.numpy.zeros((physics.NUM_STATES, )),
+        'observations2':
+        jax.numpy.zeros((physics.NUM_STATES, )),
+        'actions':
+        jax.numpy.zeros((physics.NUM_OUTPUTS, )),
+        'rewards':
+        jax.numpy.zeros((1, )),
+    })
+
+    replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
+                                                   state.mesh)
+    logging.info(replay_buffer_state_sharding)
+
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['observations1'].shape,
+                 replay_buffer_state.experience['observations1'].sharding)
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['observations2'].shape,
+                 replay_buffer_state.experience['observations2'].sharding)
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['actions'].shape,
+                 replay_buffer_state.experience['actions'].sharding)
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['rewards'].shape,
+                 replay_buffer_state.experience['rewards'].sharding)
+
+    # Number of simulation steps to take (enough to fill one sample batch)
+    # before starting gradient descent.
+    update_after = FLAGS.batch_size // FLAGS.num_agents
+
+    @partial(jax.jit, donate_argnums=(1, 2))
+    def train_loop(state: TrainState, replay_buffer_state, rng: PRNGKey,
+                   step: int):
+        rng, step_rng = jax.random.split(rng)
+        # Collect experience
+        state, replay_buffer_state = collect_experience(
+            state,
+            replay_buffer_state,
+            R,
+            step_rng,
+            step,
+        )
+
+        def nop(rng, state, replay_buffer_state, step):
+            return rng, state.update_step(step=step), 0.0, 0.0, 0.0
+
+        # Train
+        rng, state, q_loss, pi_loss, alpha_loss = jax.lax.cond(
+            step >= update_after, update_gradients, nop, rng, state,
+            replay_buffer_state, step)
+
+        return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss
+
+    for step in range(0, FLAGS.steps, FLAGS.horizon):
+        state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss = train_loop(
+            state, replay_buffer_state, rng, step)
+
+        if FLAGS.debug_nan and has_nan(state.params):
+            logging.fatal('Nan params, aborting')
+
+        logging.info(
+            'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s',
+            step,
+            q_loss,
+            pi_loss,
+            alpha_loss,
+            q_learning_rate(step),
+            pi_learning_rate(step),
+            jax.numpy.exp(state.params['logalpha']),
+        )
+
+        run.track(
+            {
+                'q_loss': float(q_loss),
+                'pi_loss': float(pi_loss),
+                'alpha_loss': float(alpha_loss),
+                'alpha': float(jax.numpy.exp(state.params['logalpha']))
+            },
+            step=step)
+
+        if step % 1000 == 0 and step > update_after:
+            # TODO(austin): Simulate a rollout and accumulate the reward.  How good are we doing?
+            save_checkpoint(state, workdir)
+
+    return state