Add code to train a simple turret controller

This uses Soft Actor-Critic with automatic temperature adjustment to
train a controller for last year's turret.  This is a stepping step to
training a more complicated turret for the swerve.  I need to still
parallelize it, play with the hyper-parameters, and teach it how to go
to a nonzero goal.

Change-Id: I1357b5fbf8549acac4ee0b94ef8f2636867c28ad
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 6e7ddf4..a45bf46 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -59,6 +59,7 @@
         Xdot = self.position_swerve_full_dynamics(X, U)
         Xdot_jax = jax_dynamics.full_dynamics(self.coefficients, X[:, 0], U[:,
                                                                             0])
+        self.assertEqual(Xdot.shape[0], Xdot_jax.shape[0])
 
         self.assertLess(
             numpy.linalg.norm(Xdot[:, 0] - Xdot_jax),
@@ -71,6 +72,9 @@
         velocity_physics_jax = jax_dynamics.velocity_dynamics(
             self.coefficients, X_velocity[:, 0], U[:, 0])
 
+        self.assertEqual(velocity_physics.shape[0],
+                         velocity_physics_jax.shape[0])
+
         self.assertLess(
             numpy.linalg.norm(velocity_physics[:, 0] - velocity_physics_jax),
             2e-8,
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
index 5e06420..f0e199b 100644
--- a/frc971/control_loops/swerve/velocity_controller/BUILD
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -22,3 +22,49 @@
         "@pip//jax",
     ],
 )
+
+py_binary(
+    name = "main",
+    srcs = [
+        "main.py",
+        "model.py",
+        "physics.py",
+        "train.py",
+    ],
+    deps = [
+        ":experience_buffer",
+        "//frc971/control_loops/swerve:jax_dynamics",
+        "@pip//absl_py",
+        "@pip//aim",
+        "@pip//clu",
+        "@pip//flashbax",
+        "@pip//flax",
+        "@pip//jax",
+        "@pip//jaxtyping",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//tensorflow",
+    ],
+)
+
+py_binary(
+    name = "lqr_plot",
+    srcs = [
+        "lqr_plot.py",
+        "model.py",
+        "physics.py",
+    ],
+    deps = [
+        ":experience_buffer",
+        "//frc971/control_loops/swerve:jax_dynamics",
+        "@pip//absl_py",
+        "@pip//flashbax",
+        "@pip//flax",
+        "@pip//jax",
+        "@pip//jaxtyping",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//tensorflow",
+    ],
+)
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
new file mode 100644
index 0000000..30f5630
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -0,0 +1,489 @@
+#!/usr/bin/env python3
+
+import os
+
+# Setup JAX to run on the CPU
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach it where to find CUDA
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Use up to 20 cores
+    '--xla_force_host_platform_device_count=20',
+    # Dump XLA to /tmp/foo to aid debugging
+    '--xla_dump_to=/tmp/foo',
+    '--xla_gpu_enable_command_buffer='
+])
+
+os.environ['JAX_PLATFORMS'] = 'cpu'
+# Don't pre-allocate memory
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+import absl
+from absl import logging
+from matplotlib.animation import FuncAnimation
+import matplotlib
+import numpy
+import scipy
+import time
+
+matplotlib.use("gtk3agg")
+
+from matplotlib import pylab
+from matplotlib import pyplot
+from flax.training import checkpoints
+from frc971.control_loops.python import controls
+import tensorflow as tf
+import threading
+import collections
+
+import jax
+
+jax._src.deprecations.accelerate('tracer-hash')
+jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
+
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+# Container for the data.
+Data = collections.namedtuple('Data', [
+    't', 'X', 'X_lqr', 'U', 'U_lqr', 'cost', 'cost_lqr', 'q1_grid', 'q2_grid',
+    'q_grid', 'target_q_grid', 'lqr_grid', 'pi_grid_U', 'lqr_grid_U', 'grid_X',
+    'grid_Y', 'reward', 'reward_lqr', 'step'
+])
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+absl.flags.DEFINE_integer(
+    'horizon',
+    default=100,
+    help='Horizon to simulate',
+)
+
+absl.flags.DEFINE_float(
+    'alpha',
+    default=0.2,
+    help='Entropy.  If negative, automatically solve for it.',
+)
+
+numpy.set_printoptions(linewidth=200, )
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+    return checkpoints.restore_checkpoint(workdir, state)
+
+
+# Hard-coded simulation parameters for the turret.
+dt = 0.005
+A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
+B = numpy.matrix([[0.00065992], [0.25610763]])
+
+Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
+R = numpy.matrix([[0.00694444]])
+
+# Compute the optimal LQR cost + controller.
+F, P = controls.dlqr(A, B, Q, R, optimal_cost_function=True)
+
+
+def generate_data(step=None):
+    grid_X = numpy.arange(-1, 1, 0.1)
+    grid_Y = numpy.arange(-10, 10, 0.1)
+    grid_X, grid_Y = numpy.meshgrid(grid_X, grid_Y)
+    grid_X = jax.numpy.array(grid_X)
+    grid_Y = jax.numpy.array(grid_Y)
+    # Load the training state.
+    physics_constants = jax_dynamics.Coefficients()
+
+    rng = jax.random.key(0)
+    rng, init_rng = jax.random.split(rng)
+
+    state = create_train_state(init_rng, physics_constants,
+                               FLAGS.q_learning_rate, FLAGS.pi_learning_rate)
+
+    state = restore_checkpoint(state, FLAGS.workdir)
+    if step is not None and state.step == step:
+        return None
+
+    print('F:', F)
+    print('P:', P)
+
+    X = jax.numpy.array([1.0, 0.0])
+    X_lqr = X.copy()
+    goal = jax.numpy.array([0.0, 0.0])
+
+    logging.info('X: %s', X)
+    logging.info('goal: %s', goal)
+    logging.debug('params: %s', state.params)
+
+    # Compute the various cost surfaces for plotting.
+    def compute_q1(X, Y):
+        return state.q1_apply(
+            state.params,
+            unwrap_angles(jax.numpy.array([X, Y])),
+            jax.numpy.array([0.]),
+        )[0]
+
+    def compute_q2(X, Y):
+        return state.q2_apply(
+            state.params,
+            unwrap_angles(jax.numpy.array([X, Y])),
+            jax.numpy.array([0.]),
+        )[0]
+
+    def lqr_cost(X, Y):
+        x = jax.numpy.array([X, Y])
+        return -x.T @ jax.numpy.array(P) @ x
+
+    def compute_q(params, x, y):
+        X = unwrap_angles(jax.numpy.array([x, y]))
+
+        return jax.numpy.minimum(
+            state.q1_apply(
+                params,
+                X,
+                jax.numpy.array([0.]),
+            )[0],
+            state.q2_apply(
+                params,
+                X,
+                jax.numpy.array([0.]),
+            )[0])
+
+    cost_grid1 = jax.vmap(jax.vmap(compute_q1))(grid_X, grid_Y)
+    cost_grid2 = jax.vmap(jax.vmap(compute_q2))(grid_X, grid_Y)
+    cost_grid = jax.vmap(jax.vmap(lambda x, y: compute_q(state.params, x, y)))(
+        grid_X, grid_Y)
+    target_cost_grid = jax.vmap(
+        jax.vmap(lambda x, y: compute_q(state.target_params, x, y)))(grid_X,
+                                                                     grid_Y)
+    lqr_cost_grid = jax.vmap(jax.vmap(lqr_cost))(grid_X, grid_Y)
+
+    # TODO(austin): Stuff to figure out:
+    # 3: Make it converge faster.  Use both GPUs better?
+    # 4: Can we feed in a reference position and get it to learn to stabilize there?
+
+    # Now compute the two controller surfaces.
+    def compute_lqr_U(X, Y):
+        x = jax.numpy.array([X, Y])
+        return (-jax.numpy.array(F.reshape((2, ))) @ x)[0]
+
+    def compute_pi_U(X, Y):
+        x = jax.numpy.array([X, Y])
+        U, _, _, _ = state.pi_apply(rng,
+                                    state.params,
+                                    observation=unwrap_angles(x),
+                                    deterministic=True)
+        return U[0]
+
+    lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
+    pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
+
+    logging.info('Finished cost')
+
+    # Now simulate the robot, accumulating up things to plot.
+    def loop(i, val):
+        X, X_lqr, data, params = val
+        t = data.t.at[i].set(i * dt)
+
+        U, _, _, _ = state.pi_apply(rng,
+                                    params,
+                                    observation=unwrap_angles(X),
+                                    deterministic=True)
+        U_lqr = F @ (goal - X_lqr)
+
+        cost = jax.numpy.minimum(state.q1_apply(params, unwrap_angles(X), U),
+                                 state.q2_apply(params, unwrap_angles(X), U))
+
+        U_plot = data.U.at[i, :].set(U)
+        U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
+        X_plot = data.X.at[i, :].set(X)
+        X_lqr_plot = data.X_lqr.at[i, :].set(X_lqr)
+        cost_plot = data.cost.at[i, :].set(cost)
+        cost_lqr_plot = data.cost_lqr.at[i, :].set(lqr_cost(*X_lqr))
+
+        X = A @ X + B @ U
+        X_lqr = A @ X_lqr + B @ U_lqr
+
+        reward = data.reward - physics.state_cost(X, U, goal)
+        reward_lqr = data.reward_lqr - physics.state_cost(X_lqr, U_lqr, goal)
+
+        return X, X_lqr, data._replace(
+            t=t,
+            U=U_plot,
+            U_lqr=U_lqr_plot,
+            X=X_plot,
+            X_lqr=X_lqr_plot,
+            cost=cost_plot,
+            cost_lqr=cost_lqr_plot,
+            reward=reward,
+            reward_lqr=reward_lqr,
+        ), params
+
+    # Do it.
+    @jax.jit
+    def integrate(data, X, X_lqr, params):
+        return jax.lax.fori_loop(0, FLAGS.horizon, loop,
+                                 (X, X_lqr, data, params))
+
+    X, X_lqr, data, params = integrate(
+        Data(
+            t=jax.numpy.zeros((FLAGS.horizon, )),
+            X=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
+            X_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
+            U=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+            U_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+            cost=jax.numpy.zeros((FLAGS.horizon, 1)),
+            cost_lqr=jax.numpy.zeros((FLAGS.horizon, 1)),
+            q1_grid=cost_grid1,
+            q2_grid=cost_grid2,
+            q_grid=cost_grid,
+            target_q_grid=target_cost_grid,
+            lqr_grid=lqr_cost_grid,
+            pi_grid_U=pi_cost_U,
+            lqr_grid_U=lqr_cost_U,
+            grid_X=grid_X,
+            grid_Y=grid_Y,
+            reward=0.0,
+            reward_lqr=0.0,
+            step=state.step,
+        ), X, X_lqr, state.params)
+
+    logging.info('Finished integrating, reward of %f, lqr reward of %f',
+                 data.reward, data.reward_lqr)
+
+    # Convert back to numpy for plotting.
+    return Data(
+        t=numpy.array(data.t),
+        X=numpy.array(data.X),
+        X_lqr=numpy.array(data.X_lqr),
+        U=numpy.array(data.U),
+        U_lqr=numpy.array(data.U_lqr),
+        cost=numpy.array(data.cost),
+        cost_lqr=numpy.array(data.cost_lqr),
+        q1_grid=numpy.array(data.q1_grid),
+        q2_grid=numpy.array(data.q2_grid),
+        q_grid=numpy.array(data.q_grid),
+        target_q_grid=numpy.array(data.target_q_grid),
+        lqr_grid=numpy.array(data.lqr_grid),
+        pi_grid_U=numpy.array(data.pi_grid_U),
+        lqr_grid_U=numpy.array(data.lqr_grid_U),
+        grid_X=numpy.array(data.grid_X),
+        grid_Y=numpy.array(data.grid_Y),
+        reward=float(data.reward),
+        reward_lqr=float(data.reward_lqr),
+        step=data.step,
+    )
+
+
+class Plotter(object):
+
+    def __init__(self, data):
+        # Make all the plots and axis.
+        self.fig0, self.axs0 = pylab.subplots(3)
+        self.fig0.supxlabel('Seconds')
+
+        self.x, = self.axs0[0].plot([], [], label="x")
+        self.v, = self.axs0[0].plot([], [], label="v")
+        self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
+        self.v_lqr, = self.axs0[0].plot([], [], label="v_lqr")
+
+        self.axs0[0].set_ylabel('Velocity')
+        self.axs0[0].legend()
+
+        self.uaxis, = self.axs0[1].plot([], [], label="U")
+        self.uaxis_lqr, = self.axs0[1].plot([], [], label="U_lqr")
+
+        self.axs0[1].set_ylabel('Amps')
+        self.axs0[1].legend()
+
+        self.costaxis, = self.axs0[2].plot([], [], label="cost")
+        self.costlqraxis, = self.axs0[2].plot([], [], label="cost lqr")
+        self.axs0[2].set_ylabel('Cost')
+        self.axs0[2].legend()
+
+        self.costfig = pyplot.figure(figsize=pyplot.figaspect(0.5))
+        self.cost3dax = [
+            self.costfig.add_subplot(2, 3, 1, projection='3d'),
+            self.costfig.add_subplot(2, 3, 2, projection='3d'),
+            self.costfig.add_subplot(2, 3, 3, projection='3d'),
+            self.costfig.add_subplot(2, 3, 4, projection='3d'),
+            self.costfig.add_subplot(2, 3, 5, projection='3d'),
+            self.costfig.add_subplot(2, 3, 6, projection='3d'),
+        ]
+
+        self.Ufig = pyplot.figure(figsize=pyplot.figaspect(0.5))
+        self.Uax = [
+            self.Ufig.add_subplot(1, 3, 1, projection='3d'),
+            self.Ufig.add_subplot(1, 3, 2, projection='3d'),
+            self.Ufig.add_subplot(1, 3, 3, projection='3d'),
+        ]
+
+        self.last_trajectory_step = 0
+        self.last_cost_step = 0
+        self.last_U_step = 0
+
+    def update_trajectory_plot(self, data):
+        if data.step == self.last_trajectory_step:
+            return
+        self.last_trajectory_step = data.step
+        logging.info('Updating trajectory plots')
+
+        # Put data in the trajectory plots.
+        self.x.set_data(data.t, data.X[:, 0])
+        self.v.set_data(data.t, data.X[:, 1])
+        self.x_lqr.set_data(data.t, data.X_lqr[:, 0])
+        self.v_lqr.set_data(data.t, data.X_lqr[:, 1])
+
+        self.uaxis.set_data(data.t, data.U[:, 0])
+        self.uaxis_lqr.set_data(data.t, data.U_lqr[:, 0])
+        self.costaxis.set_data(data.t, data.cost[:, 0] - data.cost[-1, 0])
+        self.costlqraxis.set_data(data.t, data.cost_lqr[:, 0])
+
+        self.axs0[0].relim()
+        self.axs0[0].autoscale_view()
+
+        self.axs0[1].relim()
+        self.axs0[1].autoscale_view()
+
+        self.axs0[2].relim()
+        self.axs0[2].autoscale_view()
+
+        return self.x, self.v, self.uaxis, self.costaxis, self.costlqraxis
+
+    def update_cost_plot(self, data):
+        if data.step == self.last_cost_step:
+            return
+        logging.info('Updating cost plots')
+        self.last_cost_step = data.step
+        # Put data in the cost plots.
+        if hasattr(self, 'costsurf'):
+            for surf in self.costsurf:
+                surf.remove()
+
+        plots = [
+            (data.q1_grid, 'q1'),
+            (data.q2_grid, 'q2'),
+            (data.q_grid, 'q'),
+            (data.target_q_grid, 'target q'),
+            (data.lqr_grid, 'lqr'),
+            (data.q_grid - data.q_grid.max() - data.lqr_grid, 'error'),
+        ]
+
+        self.costsurf = [
+            self.cost3dax[i].plot_surface(
+                data.grid_X,
+                data.grid_Y,
+                Z,
+                cmap="magma",
+                label=label,
+            ) for i, (Z, label) in enumerate(plots)
+        ]
+
+        for axis in self.cost3dax:
+            axis.legend()
+            axis.relim()
+            axis.autoscale_view()
+
+        return self.costsurf
+
+    def update_U_plot(self, data):
+        if data.step == self.last_U_step:
+            return
+        self.last_U_step = data.step
+        logging.info('Updating U plots')
+        # Put data in the controller plots.
+        if hasattr(self, 'Usurf'):
+            for surf in self.Usurf:
+                surf.remove()
+
+        plots = [
+            (data.lqr_grid_U, 'lqr'),
+            (data.pi_grid_U, 'pi'),
+            ((data.lqr_grid_U - data.pi_grid_U), 'error'),
+        ]
+
+        self.Usurf = [
+            self.Uax[i].plot_surface(
+                data.grid_X,
+                data.grid_Y,
+                Z,
+                cmap="magma",
+                label=label,
+            ) for i, (Z, label) in enumerate(plots)
+        ]
+
+        for axis in self.Uax:
+            axis.legend()
+            axis.relim()
+            axis.autoscale_view()
+
+        return self.Usurf
+
+
+def main(argv):
+    if len(argv) > 1:
+        raise absl.app.UsageError('Too many command-line arguments.')
+
+    tf.config.experimental.set_visible_devices([], 'GPU')
+
+    lock = threading.Lock()
+
+    # Load data.
+    data = generate_data()
+
+    plotter = Plotter(data)
+
+    # Event for shutting down the thread.
+    shutdown = threading.Event()
+
+    # Thread to grab new data periodically.
+    def do_update():
+        while True:
+            nonlocal data
+
+            my_data = generate_data(data.step)
+
+            if my_data is not None:
+                with lock:
+                    data = my_data
+
+            if shutdown.wait(timeout=3):
+                return
+
+    update_thread = threading.Thread(target=do_update)
+    update_thread.start()
+
+    # Now, update each of the plots every second with the new data.
+    def update0(frame):
+        with lock:
+            my_data = data
+
+        return plotter.update_trajectory_plot(my_data)
+
+    def update1(frame):
+        with lock:
+            my_data = data
+
+        return plotter.update_cost_plot(my_data)
+
+    def update2(frame):
+        with lock:
+            my_data = data
+
+        return plotter.update_U_plot(my_data)
+
+    animation0 = FuncAnimation(plotter.fig0, update0, interval=1000)
+    animation1 = FuncAnimation(plotter.costfig, update1, interval=1000)
+    animation2 = FuncAnimation(plotter.Ufig, update2, interval=1000)
+
+    pyplot.show()
+
+    shutdown.set()
+    update_thread.join()
+
+
+if __name__ == '__main__':
+    absl.flags.mark_flags_as_required(['workdir'])
+    absl.app.run(main)
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
new file mode 100644
index 0000000..7ef687c
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -0,0 +1,72 @@
+import os
+
+# Setup XLA first.
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach it where to find CUDA
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Use up to 20 cores
+    #'--xla_force_host_platform_device_count=6',
+    # Dump XLA to /tmp/foo to aid debugging
+    #'--xla_dump_to=/tmp/foo',
+    #'--xla_gpu_enable_command_buffer='
+    # Dump sharding
+    #"--xla_dump_to=/tmp/foo",
+    #"--xla_dump_hlo_pass_re=spmd|propagation"
+])
+os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
+
+from absl import app
+from absl import flags
+from absl import logging
+from clu import platform
+import jax
+import tensorflow as tf
+from frc971.control_loops.swerve import jax_dynamics
+
+jax._src.deprecations.accelerate('tracer-hash')
+# Enable the compilation cache
+jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
+jax.config.update('jax_threefry_partitionable', True)
+
+import train
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+
+def main(argv):
+    if len(argv) > 1:
+        raise app.UsageError('Too many command-line arguments.')
+
+    # Hide any GPUs from TensorFlow. Otherwise it might reserve memory.
+    tf.config.experimental.set_visible_devices([], 'GPU')
+
+    logging.info('JAX process: %d / %d', jax.process_index(),
+                 jax.process_count())
+    logging.info('JAX local devices: %r', jax.local_devices())
+
+    # Add a note so that we can tell which task is which JAX host.
+    # (Depending on the platform task 0 is not guaranteed to be host 0)
+    platform.work_unit().set_task_status(
+        f'process_index: {jax.process_index()}, '
+        f'process_count: {jax.process_count()}')
+    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
+                                         FLAGS.workdir, 'workdir')
+
+    logging.info(
+        'Visualize results with: bazel run -c opt @pip_deps_tensorboard//:rules_python_wheel_entry_point_tensorboard -- --logdir %s',
+        FLAGS.workdir,
+    )
+
+    physics_constants = jax_dynamics.Coefficients()
+    state = train.train(FLAGS.workdir, physics_constants)
+
+
+if __name__ == '__main__':
+    flags.mark_flags_as_required(['workdir'])
+    app.run(main)
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
new file mode 100644
index 0000000..ae67efe
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -0,0 +1,447 @@
+from __future__ import annotations
+import flax
+import flashbax
+from typing import Any
+import dataclasses
+import absl
+from absl import logging
+import numpy
+import jax
+from flax import linen as nn
+from jaxtyping import Array, ArrayLike
+import optax
+from flax.training import train_state
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve.velocity_controller import physics
+from frc971.control_loops.swerve.velocity_controller import experience_buffer
+
+from flax.typing import PRNGKey
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_integer(
+    'num_agents',
+    default=10,
+    help='Training batch size.',
+)
+
+absl.flags.DEFINE_float(
+    'q_learning_rate',
+    default=0.002,
+    help='Training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+    'final_q_learning_rate',
+    default=0.00002,
+    help='Training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+    'pi_learning_rate',
+    default=0.002,
+    help='Training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+    'final_pi_learning_rate',
+    default=0.00002,
+    help='Training learning rate.',
+)
+
+absl.flags.DEFINE_integer(
+    'replay_size',
+    default=2000000,
+    help='Number of steps to save in our replay buffer',
+)
+
+absl.flags.DEFINE_integer(
+    'batch_size',
+    default=10000,
+    help='Batch size for learning Q and pi',
+)
+
+HIDDEN_WEIGHTS = 256
+
+LOG_STD_MIN = -20
+LOG_STD_MAX = 2
+
+
+def gaussian_likelihood(noise: ArrayLike, log_std: ArrayLike):
+    pre_sum = -0.5 * (noise**2 + 2 * log_std + jax.numpy.log(2 * jax.numpy.pi))
+
+    if len(pre_sum.shape) > 1:
+        return jax.numpy.sum(pre_sum, keepdims=True, axis=1)
+    else:
+        return jax.numpy.sum(pre_sum, keepdims=True)
+
+
+class SquashedGaussianMLPActor(nn.Module):
+    """Actor model."""
+
+    # Number of dimensions in the action space
+    action_space: int = 8
+
+    hidden_sizes: list[int] = dataclasses.field(
+        default_factory=lambda: [HIDDEN_WEIGHTS, HIDDEN_WEIGHTS])
+
+    # Activation function
+    activation: Callable = nn.activation.relu
+
+    # Max action we can apply
+    action_limit: float = 30.0
+
+    @nn.compact
+    def __call__(self,
+                 observations: ArrayLike,
+                 deterministic: bool = False,
+                 rng: PRNGKey | None = None):
+        x = observations
+        # Apply the dense layers
+        for i, hidden_size in enumerate(self.hidden_sizes):
+            x = nn.Dense(
+                name=f'denselayer{i}',
+                features=hidden_size,
+            )(x)
+            x = self.activation(x)
+
+        # Average policy is a dense function of the space.
+        mu = nn.Dense(
+            features=self.action_space,
+            name='mu',
+            kernel_init=nn.initializers.zeros,
+        )(x)
+
+        log_std_layer = nn.Dense(features=self.action_space,
+                                 name='log_std_layer')(x)
+
+        # Clip the log of the standard deviation in a soft manner.
+        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (
+            flax.linen.activation.tanh(log_std_layer) + 1)
+
+        std = jax.numpy.exp(log_std)
+
+        if rng is None:
+            rng = self.make_rng('pi')
+
+        # Grab a random sample
+        random_sample = jax.random.normal(rng, shape=std.shape)
+
+        if deterministic:
+            # We are testing the optimal policy, just use the mean.
+            pi_action = mu
+        else:
+            # Use the reparameterization trick.  Adjust the unit gausian with
+            # something we can solve for to get the desired noise.
+            pi_action = random_sample * std + mu
+
+        logp_pi = gaussian_likelihood(random_sample, log_std)
+        # Adjustment to log prob
+        # NOTE: This formula is a little bit magic. To get an understanding of where it
+        # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
+        # appendix C. This is a more numerically-stable equivalent to Eq 21.
+        delta = (2.0 * (jax.numpy.log(2.0) - pi_action -
+                        flax.linen.softplus(-2.0 * pi_action)))
+
+        if len(delta.shape) > 1:
+            delta = jax.numpy.sum(delta, keepdims=True, axis=1)
+        else:
+            delta = jax.numpy.sum(delta, keepdims=True)
+
+        logp_pi = logp_pi - delta
+
+        # Now, saturate the action to the limit using tanh
+        pi_action = self.action_limit * flax.linen.activation.tanh(pi_action)
+
+        return pi_action, logp_pi, self.action_limit * std, random_sample
+
+
+class MLPQFunction(nn.Module):
+
+    # Number and size of the hidden layers.
+    hidden_sizes: list[int] = dataclasses.field(
+        default_factory=lambda: [HIDDEN_WEIGHTS, HIDDEN_WEIGHTS])
+
+    activation: Callable = nn.activation.tanh
+
+    @nn.compact
+    def __call__(self, observations, actions):
+        # Estimate Q with a simple multi layer dense network.
+        x = jax.numpy.hstack((observations, actions))
+        for i, hidden_size in enumerate(self.hidden_sizes):
+            x = nn.Dense(
+                name=f'denselayer{i}',
+                features=hidden_size,
+            )(x)
+            x = self.activation(x)
+
+        x = nn.Dense(name=f'q', features=1,
+                     kernel_init=nn.initializers.zeros)(x)
+
+        return x
+
+
+class TrainState(flax.struct.PyTreeNode):
+    physics_constants: jax_dynamics.CoefficientsType = flax.struct.field(
+        pytree_node=False)
+
+    step: int | jax.Array = flax.struct.field(pytree_node=True)
+    substep: int | jax.Array = flax.struct.field(pytree_node=True)
+
+    params: flax.core.FrozenDict[str, typing.Any] = flax.struct.field(
+        pytree_node=True)
+
+    target_params: flax.core.FrozenDict[str, typing.Any] = flax.struct.field(
+        pytree_node=True)
+
+    pi_apply_fn: Callable = flax.struct.field(pytree_node=False)
+    q1_apply_fn: Callable = flax.struct.field(pytree_node=False)
+    q2_apply_fn: Callable = flax.struct.field(pytree_node=False)
+
+    pi_tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
+    pi_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+    q_tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
+    q_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+
+    alpha_tx: optax.GradientTransformation = flax.struct.field(
+        pytree_node=False)
+    alpha_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+
+    target_entropy: float = flax.struct.field(pytree_node=True)
+
+    mesh: Mesh = flax.struct.field(pytree_node=False)
+    sharding: NamedSharding = flax.struct.field(pytree_node=False)
+    replicated_sharding: NamedSharding = flax.struct.field(pytree_node=False)
+
+    replay_buffer: flashbax.buffers.trajectory_buffer.TrajectoryBuffer = flax.struct.field(
+        pytree_node=False)
+
+    def pi_apply(self,
+                 rng: PRNGKey,
+                 params,
+                 observation,
+                 deterministic: bool = False):
+        return self.pi_apply_fn({'params': params['pi']},
+                                physics.unwrap_angles(observation),
+                                deterministic=deterministic,
+                                rngs={'pi': rng})
+
+    def q1_apply(self, params, observation, action):
+        return self.q1_apply_fn({'params': params['q1']},
+                                physics.unwrap_angles(observation), action)
+
+    def q2_apply(self, params, observation, action):
+        return self.q2_apply_fn({'params': params['q2']},
+                                physics.unwrap_angles(observation), action)
+
+    def pi_apply_gradients(self, step, grads):
+        updates, new_pi_opt_state = self.pi_tx.update(grads, self.pi_opt_state,
+                                                      self.params)
+        new_params = optax.apply_updates(self.params, updates)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            params=new_params,
+            pi_opt_state=new_pi_opt_state,
+        )
+
+    def q_apply_gradients(self, step, grads):
+        updates, new_q_opt_state = self.q_tx.update(grads, self.q_opt_state,
+                                                    self.params)
+        new_params = optax.apply_updates(self.params, updates)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            params=new_params,
+            q_opt_state=new_q_opt_state,
+        )
+
+    def target_apply_gradients(self, step):
+        new_target_params = optax.incremental_update(self.params,
+                                                     self.target_params,
+                                                     1 - FLAGS.polyak)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            target_params=new_target_params,
+        )
+
+    def alpha_apply_gradients(self, step, grads):
+        updates, new_alpha_opt_state = self.alpha_tx.update(
+            grads, self.alpha_opt_state, self.params)
+        new_params = optax.apply_updates(self.params, updates)
+
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+            params=new_params,
+            alpha_opt_state=new_alpha_opt_state,
+        )
+
+    def update_step(self, step):
+        return self.replace(
+            step=step,
+            substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+        )
+
+    @classmethod
+    def create(cls, *, physics_constants: jax_dynamics.CoefficientsType,
+               params, pi_tx, q_tx, alpha_tx, pi_apply_fn, q1_apply_fn,
+               q2_apply_fn, **kwargs):
+        """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
+
+        pi_tx = optax.multi_transform(
+            {
+                'train': pi_tx,
+                'freeze': optax.set_to_zero()
+            },
+            param_labels=flax.traverse_util.path_aware_map(
+                lambda path, x: 'train'
+                if path[0] == 'pi' else 'freeze', params),
+        )
+        pi_opt_state = pi_tx.init(params)
+
+        q_tx = optax.multi_transform(
+            {
+                'train': q_tx,
+                'freeze': optax.set_to_zero()
+            },
+            param_labels=flax.traverse_util.path_aware_map(
+                lambda path, x: 'train'
+                if path[0] == 'q1' or path[0] == 'q2' else 'freeze', params),
+        )
+        q_opt_state = q_tx.init(params)
+
+        alpha_tx = optax.multi_transform(
+            {
+                'train': alpha_tx,
+                'freeze': optax.set_to_zero()
+            },
+            param_labels=flax.traverse_util.path_aware_map(
+                lambda path, x: 'train'
+                if path[0] == 'logalpha' else 'freeze', params),
+        )
+        alpha_opt_state = alpha_tx.init(params)
+
+        mesh = Mesh(
+            devices=mesh_utils.create_device_mesh(len(jax.devices())),
+            axis_names=('batch', ),
+        )
+        print('Devices:', jax.devices())
+        sharding = NamedSharding(mesh, PartitionSpec('batch'))
+        replicated_sharding = NamedSharding(mesh, PartitionSpec())
+
+        replay_buffer = experience_buffer.make_experience_buffer(
+            num_agents=FLAGS.num_agents,
+            sample_batch_size=FLAGS.batch_size,
+            length=FLAGS.replay_size)
+
+        result = cls(
+            physics_constants=physics_constants,
+            step=0,
+            substep=0,
+            params=params,
+            target_params=params,
+            q1_apply_fn=q1_apply_fn,
+            q2_apply_fn=q2_apply_fn,
+            pi_apply_fn=pi_apply_fn,
+            pi_tx=pi_tx,
+            pi_opt_state=pi_opt_state,
+            q_tx=q_tx,
+            q_opt_state=q_opt_state,
+            alpha_tx=alpha_tx,
+            alpha_opt_state=alpha_opt_state,
+            target_entropy=-physics.NUM_STATES,
+            mesh=mesh,
+            sharding=sharding,
+            replicated_sharding=replicated_sharding,
+            replay_buffer=replay_buffer,
+        )
+
+        return jax.device_put(result, replicated_sharding)
+
+
+def create_train_state(rng, physics_constants: jax_dynamics.CoefficientsType,
+                       q_learning_rate, pi_learning_rate):
+    """Creates initial `TrainState`."""
+    pi = SquashedGaussianMLPActor(action_space=physics.NUM_OUTPUTS)
+    # We want q1 and q2 to have different network architectures so they pick up differnet things.
+    q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
+    q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
+
+    @jax.jit
+    def init_params(rng):
+        pi_rng, q1_rng, q2_rng = jax.random.split(rng, num=3)
+
+        pi_params = pi.init(
+            pi_rng,
+            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+        )['params']
+        q1_params = q1.init(
+            q1_rng,
+            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+            jax.numpy.ones([physics.NUM_OUTPUTS]),
+        )['params']
+        q2_params = q2.init(
+            q2_rng,
+            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+            jax.numpy.ones([physics.NUM_OUTPUTS]),
+        )['params']
+
+        if FLAGS.alpha < 0.0:
+            logalpha = 0.0
+        else:
+            logalpha = jax.numpy.log(FLAGS.alpha)
+
+        return {
+            'q1': q1_params,
+            'q2': q2_params,
+            'pi': pi_params,
+            'logalpha': logalpha,
+        }
+
+    pi_tx = optax.sgd(learning_rate=pi_learning_rate)
+    q_tx = optax.sgd(learning_rate=q_learning_rate)
+    alpha_tx = optax.sgd(learning_rate=q_learning_rate)
+
+    result = TrainState.create(
+        physics_constants=physics_constants,
+        params=init_params(rng),
+        pi_tx=pi_tx,
+        q_tx=q_tx,
+        alpha_tx=alpha_tx,
+        pi_apply_fn=pi.apply,
+        q1_apply_fn=q1.apply,
+        q2_apply_fn=q2.apply,
+    )
+
+    return result
+
+
+def create_learning_rate_fn(
+    base_learning_rate: float,
+    final_learning_rate: float,
+):
+    """Create learning rate schedule."""
+    warmup_fn = optax.linear_schedule(
+        init_value=base_learning_rate,
+        end_value=base_learning_rate,
+        transition_steps=FLAGS.warmup_steps,
+    )
+
+    cosine_epochs = max(FLAGS.steps - FLAGS.warmup_steps, 1)
+    cosine_fn = optax.cosine_decay_schedule(init_value=base_learning_rate,
+                                            decay_steps=cosine_epochs,
+                                            alpha=final_learning_rate)
+
+    schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, cosine_fn],
+        boundaries=[FLAGS.warmup_steps],
+    )
+    return schedule_fn
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
new file mode 100644
index 0000000..9454f14
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -0,0 +1,69 @@
+import jax
+from functools import partial
+from frc971.control_loops.swerve import dynamics
+from absl import logging
+from frc971.control_loops.swerve import jax_dynamics
+
+
+@partial(jax.jit, static_argnums=[0])
+def xdot_physics(physics_constants: jax_dynamics.CoefficientsType, X, U):
+    A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+    B = jax.numpy.array([[0.], [56.08534375]])
+
+    return A @ X + B @ U
+
+
+def unwrap_angles(X):
+    return X
+
+
+@jax.jit
+def swerve_cost(X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+                goal: jax.typing.ArrayLike):
+
+    Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
+    R = jax.numpy.array([[0.00694444]])
+
+    return (X).T @ Q @ (X) + U.T @ R @ U
+
+
+@partial(jax.jit, static_argnums=[0])
+def integrate_dynamics(physics_constants: jax_dynamics.CoefficientsType, X, U):
+    m = 2  # RK4 steps per interval
+    dt = 0.005 / m
+
+    def iteration(i, X):
+        weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0], [1.0, 2.0, 2.0, 1.0]])
+
+        def rk_iteration(i, val):
+            kx_previous, weighted_sum = val
+            kx = xdot_physics(physics_constants,
+                              X + dt * weights[0, i] * kx_previous, U)
+            weighted_sum += dt * weights[1, i] * kx / 6.0
+            return (kx, weighted_sum)
+
+        return jax.lax.fori_loop(lower=0,
+                                 upper=4,
+                                 body_fun=rk_iteration,
+                                 init_val=(X, X))[1]
+
+    return jax.lax.fori_loop(lower=0, upper=m, body_fun=iteration, init_val=X)
+
+
+state_cost = swerve_cost
+NUM_STATES = 2
+NUM_UNWRAPPED_STATES = 2
+NUM_OUTPUTS = 1
+ACTION_LIMIT = 10.0
+
+
+def random_states(rng, dimensions=None):
+    rng1, rng2 = jax.random.split(rng)
+
+    return jax.numpy.hstack(
+        (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
+                            minval=-1.0,
+                            maxval=1.0),
+         jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
+                            minval=-10.0,
+                            maxval=10.0)))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
new file mode 100644
index 0000000..5002eee
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -0,0 +1,511 @@
+import absl
+import time
+import collections
+from absl import logging
+import flax
+import matplotlib
+from matplotlib import pyplot
+from flax import linen as nn
+from flax.training import train_state
+from flax.training import checkpoints
+import jax
+import inspect
+import aim
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+import optax
+import numpy
+from frc971.control_loops.swerve import jax_dynamics
+from functools import partial
+import flashbax
+from jax.experimental.ode import odeint
+from frc971.control_loops.swerve import dynamics
+import orbax.checkpoint
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+numpy.set_printoptions(linewidth=200, )
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_integer(
+    'horizon',
+    default=25,
+    help='MPC horizon',
+)
+
+absl.flags.DEFINE_integer(
+    'start_steps',
+    default=10000,
+    help='Number of steps to randomly sample before using the policy',
+)
+
+absl.flags.DEFINE_integer(
+    'steps',
+    default=400000,
+    help='Number of steps to run and train the agent',
+)
+
+absl.flags.DEFINE_integer(
+    'warmup_steps',
+    default=300000,
+    help='Number of steps to warm up training',
+)
+
+absl.flags.DEFINE_float(
+    'gamma',
+    default=0.999,
+    help='Future discount.',
+)
+
+absl.flags.DEFINE_float(
+    'alpha',
+    default=0.2,
+    help='Entropy.  If negative, automatically solve for it.',
+)
+
+absl.flags.DEFINE_float(
+    'polyak',
+    default=0.995,
+    help='Time constant in polyak averaging for the target network.',
+)
+
+absl.flags.DEFINE_bool(
+    'debug_nan',
+    default=False,
+    help='If true, explode on any NaNs found, and print them.',
+)
+
+
+def random_actions(rng, dimensions=None):
+    """Produces a uniformly random action in the action space."""
+    return jax.random.uniform(
+        rng,
+        (dimensions or FLAGS.num_agents, physics.NUM_OUTPUTS),
+        minval=-physics.ACTION_LIMIT,
+        maxval=physics.ACTION_LIMIT,
+    )
+
+
+def save_checkpoint(state: TrainState, workdir: str):
+    """Saves a checkpoint in the workdir."""
+    # TODO(austin): use orbax directly.
+    step = int(state.step)
+    logging.info('Saving checkpoint step %d.', step)
+    checkpoints.save_checkpoint_multiprocess(workdir, state, step, keep=10)
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+    """Restores the latest checkpoint from the workdir."""
+    return checkpoints.restore_checkpoint(workdir, state)
+
+
+def has_nan(x):
+    if isinstance(x, jnp.ndarray):
+        return jnp.any(jnp.isnan(x))
+    else:
+        return jnp.any(
+            jax.numpy.array([has_nan(v)
+                             for v in jax.tree_util.tree_leaves(x)]))
+
+
+def print_nan(step, x):
+    if not FLAGS.debug_nan:
+        return
+
+    caller = inspect.getframeinfo(inspect.stack()[1][0])
+    # TODO(austin): It is helpful to sometimes start printing at a certain step.
+    jax.lax.cond(
+        has_nan(x), lambda: jax.debug.print(caller.filename +
+                                            ':{l} step {s} isnan(X) -> {x}',
+                                            l=caller.lineno,
+                                            s=step,
+                                            x=x), lambda: None)
+
+
+@jax.jit
+def compute_loss_q(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
+    """Computes the Soft Actor-Critic loss for Q1 and Q2."""
+    observations1 = data['observations1']
+    actions = data['actions']
+    rewards = data['rewards']
+    observations2 = data['observations2']
+
+    # Compute the ending actions from the current network.
+    action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
+                                             params=params,
+                                             observation=observations2)
+
+    # Compute target network Q values
+    q1_pi_target = state.q1_apply(state.target_params, observations2, action2)
+    q2_pi_target = state.q2_apply(state.target_params, observations2, action2)
+    q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
+
+    alpha = jax.numpy.exp(params['logalpha'])
+
+    # Now we can compute the Bellman backup
+    bellman_backup = jax.lax.stop_gradient(rewards + FLAGS.gamma *
+                                           (q_pi_target - alpha * logp_pi2))
+
+    # Compute the starting Q values from the Q network being optimized.
+    q1 = state.q1_apply(params, observations1, actions)
+    q2 = state.q2_apply(params, observations1, actions)
+
+    # Mean squared error loss against Bellman backup
+    q1_loss = ((q1 - bellman_backup)**2).mean()
+    q2_loss = ((q2 - bellman_backup)**2).mean()
+    return q1_loss + q2_loss
+
+
+@jax.jit
+def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
+                           data: ArrayLike):
+
+    def bound_compute_loss_q(rng, data):
+        return compute_loss_q(state, rng, params, data)
+
+    return jax.vmap(bound_compute_loss_q)(
+        jax.random.split(rng, FLAGS.num_agents),
+        data,
+    ).mean()
+
+
+@jax.jit
+def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
+    """Computes the Soft Actor-Critic loss for pi."""
+    observations1 = data['observations1']
+    # TODO(austin): We've got differentiable policy and differentiable physics.  Can we use those here?  Have Q learn the future, not the current step?
+
+    # Compute the action
+    pi, logp_pi, _, _ = state.pi_apply(rng=rng,
+                                       params=params,
+                                       observation=observations1)
+    q1_pi = state.q1_apply(jax.lax.stop_gradient(params), observations1, pi)
+    q2_pi = state.q2_apply(jax.lax.stop_gradient(params), observations1, pi)
+
+    # And compute the Q of that action.
+    q_pi = jax.numpy.minimum(q1_pi, q2_pi)
+
+    alpha = jax.lax.stop_gradient(jax.numpy.exp(params['logalpha']))
+
+    # Compute the entropy-regularized policy loss
+    return (alpha * logp_pi - q_pi).mean()
+
+
+@jax.jit
+def compute_batched_loss_pi(state: TrainState, rng: PRNGKey, params,
+                            data: ArrayLike):
+
+    def bound_compute_loss_pi(rng, data):
+        return compute_loss_pi(state, rng, params, data)
+
+    return jax.vmap(bound_compute_loss_pi)(
+        jax.random.split(rng, FLAGS.num_agents),
+        data,
+    ).mean()
+
+
+@jax.jit
+def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
+                       data: ArrayLike):
+    """Computes the Soft Actor-Critic loss for alpha."""
+    observations1 = data['observations1']
+    pi, logp_pi, _, _ = jax.lax.stop_gradient(
+        state.pi_apply(rng=rng, params=params, observation=observations1))
+
+    return (-jax.numpy.exp(params['logalpha']) *
+            (logp_pi + state.target_entropy)).mean()
+
+
+@jax.jit
+def compute_batched_loss_alpha(state: TrainState, rng: PRNGKey, params,
+                               data: ArrayLike):
+
+    def bound_compute_loss_alpha(rng, data):
+        return compute_loss_alpha(state, rng, params, data)
+
+    return jax.vmap(bound_compute_loss_alpha)(
+        jax.random.split(rng, FLAGS.num_agents),
+        data,
+    ).mean()
+
+
+@jax.jit
+def train_step(state: TrainState, data, action_data, update_rng: PRNGKey,
+               step: int) -> TrainState:
+    """Updates the parameters for Q, Pi, target Q, and alpha."""
+    update_rng, q_grad_rng = jax.random.split(update_rng)
+    print_nan(step, data)
+
+    # Update Q
+    q_grad_fn = jax.value_and_grad(
+        lambda params: compute_batched_loss_q(state, q_grad_rng, params, data))
+    q_loss, q_grads = q_grad_fn(state.params)
+    print_nan(step, q_loss)
+    print_nan(step, q_grads)
+
+    state = state.q_apply_gradients(step=step, grads=q_grads)
+
+    update_rng, pi_grad_rng = jax.random.split(update_rng)
+
+    # Update pi
+    pi_grad_fn = jax.value_and_grad(lambda params: compute_batched_loss_pi(
+        state, pi_grad_rng, params, action_data))
+    pi_loss, pi_grads = pi_grad_fn(state.params)
+
+    print_nan(step, pi_loss)
+    print_nan(step, pi_grads)
+
+    state = state.pi_apply_gradients(step=step, grads=pi_grads)
+
+    update_rng, alpha_grad_rng = jax.random.split(update_rng)
+
+    if FLAGS.alpha < 0.0:
+        # Update alpha
+        alpha_grad_fn = jax.value_and_grad(
+            lambda params: compute_batched_loss_alpha(state, alpha_grad_rng,
+                                                      params, data))
+        alpha_loss, alpha_grads = alpha_grad_fn(state.params)
+        print_nan(step, alpha_loss)
+        print_nan(step, alpha_grads)
+        state = state.alpha_apply_gradients(step=step, grads=alpha_grads)
+    else:
+        alpha_loss = 0.0
+
+    return state, q_loss, pi_loss, alpha_loss
+
+
+@jax.jit
+def collect_experience(state: TrainState, replay_buffer_state, R: ArrayLike,
+                       step_rng: PRNGKey, step):
+    """Collects experience by simulating."""
+    pi_rng = jax.random.fold_in(step_rng, step)
+    pi_rng, initialization_rng = jax.random.split(pi_rng)
+
+    observation = jax.lax.with_sharding_constraint(
+        physics.random_states(initialization_rng, FLAGS.num_agents),
+        state.sharding)
+
+    def loop(i, val):
+        """Runs 1 step of our simulation."""
+        observation, pi_rng, replay_buffer_state = val
+        pi_rng, action_rng = jax.random.split(pi_rng)
+        logging.info('Observation shape: %s', observation.shape)
+
+        def true_fn(i):
+            # We are at the beginning of the process, pick a random action.
+            return random_actions(action_rng, FLAGS.num_agents)
+
+        def false_fn(i):
+            # We are past the beginning of the process, use the trained network.
+            pi_action, logp_pi, std, random_sample = state.pi_apply(
+                rng=action_rng,
+                params=state.params,
+                observation=observation,
+                deterministic=False)
+            return pi_action
+
+        pi_action = jax.lax.cond(
+            step <= FLAGS.start_steps,
+            true_fn,
+            false_fn,
+            i,
+        )
+
+        # Compute the destination state.
+        observation2 = jax.vmap(lambda o, pi: physics.integrate_dynamics(
+            state.physics_constants, o, pi),
+                                in_axes=(0, 0))(observation, pi_action)
+
+        # Soft Actor-Critic is designed to maximize reward.  LQR minimizes
+        # cost.  There is nothing which assumes anything about the sign of
+        # the reward, so use the negative of the cost.
+        reward = (-jax.vmap(state_cost)(observation2, pi_action, R)).reshape(
+            (FLAGS.num_agents, 1))
+
+        logging.info('Observation shape: %s', observation.shape)
+        logging.info('Observation2 shape: %s', observation2.shape)
+        logging.info('Action shape: %s', pi_action.shape)
+        logging.info('Action shape: %s', pi_action.shape)
+        replay_buffer_state = state.replay_buffer.add(
+            replay_buffer_state, {
+                'observations1': observation,
+                'observations2': observation2,
+                'actions': pi_action,
+                'rewards': reward,
+            })
+
+        return observation2, pi_rng, replay_buffer_state
+
+    # Run 1 horizon of simulation
+    final_observation, final_pi_rng, final_replay_buffer_state = jax.lax.fori_loop(
+        0, FLAGS.horizon + 1, loop, (observation, pi_rng, replay_buffer_state))
+
+    return state, final_replay_buffer_state
+
+
+@jax.jit
+def update_gradients(rng: PRNGKey, state: TrainState, replay_buffer_state,
+                     step: int):
+    rng, sample_rng = jax.random.split(rng)
+
+    action_data = state.replay_buffer.sample(replay_buffer_state, sample_rng)
+
+    def update_iteration(i, val):
+        rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = val
+        rng, sample_rng, update_rng = jax.random.split(rng, 3)
+
+        batch = state.replay_buffer.sample(replay_buffer_state, sample_rng)
+
+        print_nan(i, replay_buffer_state)
+        print_nan(i, batch)
+
+        state, q_loss, pi_loss, alpha_loss = train_step(
+            state,
+            data=batch.experience,
+            action_data=batch.experience,
+            update_rng=update_rng,
+            step=i)
+
+        return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data
+
+    rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = jax.lax.fori_loop(
+        step, step + FLAGS.horizon + 1, update_iteration,
+        (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, action_data))
+
+    state = state.target_apply_gradients(step=state.step)
+
+    return rng, state, q_loss, pi_loss, alpha_loss
+
+
+def train(
+    workdir: str, physics_constants: jax_dynamics.CoefficientsType
+) -> train_state.TrainState:
+    """Trains a Soft Actor-Critic controller."""
+    rng = jax.random.key(0)
+    rng, r_rng = jax.random.split(rng)
+
+    run = aim.Run(repo='aim://127.0.0.1:53800')
+
+    run['hparams'] = {
+        'q_learning_rate': FLAGS.q_learning_rate,
+        'pi_learning_rate': FLAGS.pi_learning_rate,
+        'batch_size': FLAGS.batch_size,
+        'horizon': FLAGS.horizon,
+        'warmup_steps': FLAGS.warmup_steps,
+    }
+
+    # Setup TrainState
+    rng, init_rng = jax.random.split(rng)
+    q_learning_rate = create_learning_rate_fn(
+        base_learning_rate=FLAGS.q_learning_rate,
+        final_learning_rate=FLAGS.final_q_learning_rate)
+    pi_learning_rate = create_learning_rate_fn(
+        base_learning_rate=FLAGS.pi_learning_rate,
+        final_learning_rate=FLAGS.final_pi_learning_rate)
+    state = create_train_state(
+        init_rng,
+        physics_constants,
+        q_learning_rate=q_learning_rate,
+        pi_learning_rate=pi_learning_rate,
+    )
+    state = restore_checkpoint(state, workdir)
+
+    state_sharding = nn.get_sharding(state, state.mesh)
+    logging.info(state_sharding)
+
+    # TODO(austin): Let the goal change.
+    R = jax.numpy.hstack((
+        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=0.9, maxval=1.1),
+        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=-0.1,
+                           maxval=0.1),
+        jax.numpy.zeros((FLAGS.num_agents, 1)),
+    ))
+
+    replay_buffer_state = state.replay_buffer.init({
+        'observations1':
+        jax.numpy.zeros((physics.NUM_STATES, )),
+        'observations2':
+        jax.numpy.zeros((physics.NUM_STATES, )),
+        'actions':
+        jax.numpy.zeros((physics.NUM_OUTPUTS, )),
+        'rewards':
+        jax.numpy.zeros((1, )),
+    })
+
+    replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
+                                                   state.mesh)
+    logging.info(replay_buffer_state_sharding)
+
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['observations1'].shape,
+                 replay_buffer_state.experience['observations1'].sharding)
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['observations2'].shape,
+                 replay_buffer_state.experience['observations2'].sharding)
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['actions'].shape,
+                 replay_buffer_state.experience['actions'].sharding)
+    logging.info('Shape: %s, Sharding %s',
+                 replay_buffer_state.experience['rewards'].shape,
+                 replay_buffer_state.experience['rewards'].sharding)
+
+    # Number of gradients to accumulate before doing decent.
+    update_after = FLAGS.batch_size // FLAGS.num_agents
+
+    @partial(jax.jit, donate_argnums=(1, 2))
+    def train_loop(state: TrainState, replay_buffer_state, rng: PRNGKey,
+                   step: int):
+        rng, step_rng = jax.random.split(rng)
+        # Collect experience
+        state, replay_buffer_state = collect_experience(
+            state,
+            replay_buffer_state,
+            R,
+            step_rng,
+            step,
+        )
+
+        def nop(rng, state, replay_buffer_state, step):
+            return rng, state.update_step(step=step), 0.0, 0.0, 0.0
+
+        # Train
+        rng, state, q_loss, pi_loss, alpha_loss = jax.lax.cond(
+            step >= update_after, update_gradients, nop, rng, state,
+            replay_buffer_state, step)
+
+        return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss
+
+    for step in range(0, FLAGS.steps, FLAGS.horizon):
+        state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss = train_loop(
+            state, replay_buffer_state, rng, step)
+
+        if FLAGS.debug_nan and has_nan(state.params):
+            logging.fatal('Nan params, aborting')
+
+        logging.info(
+            'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s',
+            step,
+            q_loss,
+            pi_loss,
+            alpha_loss,
+            q_learning_rate(step),
+            pi_learning_rate(step),
+            jax.numpy.exp(state.params['logalpha']),
+        )
+
+        run.track(
+            {
+                'q_loss': float(q_loss),
+                'pi_loss': float(pi_loss),
+                'alpha_loss': float(alpha_loss),
+                'alpha': float(jax.numpy.exp(state.params['logalpha']))
+            },
+            step=step)
+
+        if step % 1000 == 0 and step > update_after:
+            # TODO(austin): Simulate a rollout and accumulate the reward.  How good are we doing?
+            save_checkpoint(state, workdir)
+
+    return state
diff --git a/tools/python/requirements.lock.txt b/tools/python/requirements.lock.txt
index 93171a9..4fb225e 100644
--- a/tools/python/requirements.lock.txt
+++ b/tools/python/requirements.lock.txt
@@ -20,6 +20,115 @@
     #   tensorflow
     #   tensorflow-datasets
     #   tensorflow-metadata
+aim==3.24.0 \
+    --hash=sha256:1266fac2453f2d356d2e39c98d15b46ba7a71ecab14cdaa6eaa6d390f5add898 \
+    --hash=sha256:1322659dcecb701c264f5067375c74811faf3514bb316e1ca1e53b9de6b62766 \
+    --hash=sha256:13dff702a680ad2d25344543e3b912c217a0fd305d5355f388fcd14f030af4f0 \
+    --hash=sha256:14a70cd508b4761d7d4e1f302207feb58906ac3ab27d2d64eeb4f1cbb8d4dc47 \
+    --hash=sha256:1789bcc9edf69ae90bb18a6f3fa4862095b8020eadc7e5e0e93ba4f242250947 \
+    --hash=sha256:29949cf7e5b5a46cff6a1d868805962e6563929934cd1d5d47678a58a4c9b777 \
+    --hash=sha256:2a9158717267e4f04ac12ffef8ef22c3e72af3b284ce4a86906a511d57ee98b9 \
+    --hash=sha256:2b0e41e6b46b6e435be845aeda4792cc6b392fb17bb5096e6ca186fffeb96547 \
+    --hash=sha256:2ea588e075a5508014af74be4821adc58041ac034af81b106936482043fcb124 \
+    --hash=sha256:39462568e7e4270574c20812b2ede4747e908c0320f947c39b94cdcbda63141b \
+    --hash=sha256:430000ece9e045d4b5313b0e683c363730b712260147834c4a6f89a2fd06513d \
+    --hash=sha256:4903a1ead28681dd5a51b4c8b613e6af5a158bde11d77678a64d6eeab9d9e97f \
+    --hash=sha256:663f939ebd89d053ca2a31f921cf2389f94b75fe3a97a66840b33ac0c8888b46 \
+    --hash=sha256:6937cb8736063e3734292f32aed8de141173be83d6dc9c49d8f05a58bb99a713 \
+    --hash=sha256:696aa9bb1c4316fb9171410daaff640293d5d2ee215055dd7eda27445318240b \
+    --hash=sha256:6b1b12e6724677c1bf9d9c62a0ddb30cc5b130080765fd217792e36a6622a3e3 \
+    --hash=sha256:8328b7703b7626f1f5082e947c5e203011536de039c64e551364444299ff68e1 \
+    --hash=sha256:8aff59858e5944ede0eb7ecbd56d4131418913a5cb13436dbe55e1795b85a838 \
+    --hash=sha256:92a794755de87c920bc1de49856a0ec864a903c2e4fb090a4fe6b0a71136c865 \
+    --hash=sha256:b2a2755115aa9b8efc57d262bef9ce9f13ca56068e6833bf359884228cab9803 \
+    --hash=sha256:c0f8dd9542bdf251b195d8c1680367b7c043a1fb59fd170944679157f2cf2b05 \
+    --hash=sha256:cb9d3e638598309e7e1d9970c62ec93d329ffcdeab7753886eee985b2b15a207 \
+    --hash=sha256:d52419e43d5a879455ca874843e5e45d4725022665a7c560f6b2193ad6a181a4 \
+    --hash=sha256:d63186f069f097de107a373fca2de9a5e7662c7d2ed4364b8937b5a5544e92f7 \
+    --hash=sha256:e481ef1e6e3ea17632fd4ecdee89cabb3f815929a5b080ee0e868e553f893c55 \
+    --hash=sha256:eaf98b0cba60d97cbce75a88ae96ce4c2439b3ba77572b9f17a5885fb76e9154 \
+    --hash=sha256:ec3bd85fdb5b85b9655d77fd0474e3f5c78f5065ca77c59efd77dbfae8005e40 \
+    --hash=sha256:f98f52fab910cdaee91256cf722e51b3edb131d9a0331519bccc9b6f06c28f6a
+    # via -r tools/python/requirements.txt
+aim-ui==3.24.0 \
+    --hash=sha256:d24ca6059095f76f29b4bee211b71dcfa5ed2e4fa6d16d82a343c900d5a37498
+    # via aim
+aimrecords==0.0.7 \
+    --hash=sha256:9b562fa5b5109b4b3dd4f83be0061cadbac63fa8031f281b8b5c8ae29967072f \
+    --hash=sha256:b9276890891c5fd68f817e20fc5d466a80c01e22fa468eaa979331448a75d601
+    # via aim
+aimrocks==0.5.2 \
+    --hash=sha256:03ca9bd3a7d379f40c678e648d3ec1445738a32fee337009cfb6aa9aedc51964 \
+    --hash=sha256:081b59cf3a02056420e32d8fa851859d314ae227e975d6febba67e9341135208 \
+    --hash=sha256:14982839f451f8990e9a1b5c4c06ffc77cafeb3b4a7f372f92a1da19a52a8a11 \
+    --hash=sha256:18574bab2cc060231a3da26a3ca2b18b6482b79649217e8fdcf6bc29efcdf973 \
+    --hash=sha256:1c1754ba5da8f2e016ed96b85eeb31048cc325cc351c32431d668ab226a0d986 \
+    --hash=sha256:2cfc3f4f1e4b105c1973007e4798aa3cedc0fc81436f81c220e02a00d796071c \
+    --hash=sha256:2d120e114c11882e8ac7dcd76a745b21da1fd0cd10aeaf525ec8a48a08556b3f \
+    --hash=sha256:3e580c5640c61e47591873448ddfd5741979f2bccf40809157fb260c6956f1e5 \
+    --hash=sha256:3f65583d29bcfc3baa422e45e73de89c4c781490664eb49c1a4c21c074f5bbfa \
+    --hash=sha256:4aaa2ffae1dbdcd2be21535a2866ec4a9a2fbb4338cc5b955ad6ca3fae22461c \
+    --hash=sha256:4ce617cf9f11e81a70070ea1d14fa2204996c651984cbd19178246eb33b143d0 \
+    --hash=sha256:533eda940f4bd1641fee15da09595c965d6890e449706fb3442174472b468a19 \
+    --hash=sha256:5c92e843818e7b764725c1dca1ca6744202ac46b5c246b407e39e8a28e0266a1 \
+    --hash=sha256:691621709f02841a248ed2555ec61346c50bfc07df2553be54a355c9010676e3 \
+    --hash=sha256:6b11d27df8ebec63bb9a121c55bc19c8b93801a6a9064533835e056bc5a11bb6 \
+    --hash=sha256:7452149a119d4b3620254a567c3c68ecfa3e016f58f443847e4fb70b85186593 \
+    --hash=sha256:76258350f2715d686d5da12a5a2df0d7b88e1b33e45052e0ddeb549c7497a56e \
+    --hash=sha256:762f7b41793165717a9e0589658cd81bffb54161295ec7403534d40692ac9281 \
+    --hash=sha256:7635741ed2b8dbf59c7564446bd0716dd5ea431c82753000ca4851bed9e76911 \
+    --hash=sha256:79400c6f4c72bcc4485f2a4411a3e6c1f6ead7a3928a00a72739abb1ef9ec0d3 \
+    --hash=sha256:7f47ae2f2183e1c1c9299be79ac9704b1d47e66c8ae7d41356385ad33e9def4b \
+    --hash=sha256:8b3364c566e547cd1855700cdec07149a139f6665a6ad60275e3ee3679945dbd \
+    --hash=sha256:8d685e8092db34c68ff8208c4961918345e14a0bdddf0ee22017346433950cc8 \
+    --hash=sha256:95dae89ed772439a12c845d013dca1dc3abb88ccd71ee50bc8728d43afddd7d7 \
+    --hash=sha256:96d6877437108ca8f8c3c72f27aeeb987af881ca6fe78a46962d3bf96346fd23 \
+    --hash=sha256:9becbd34b2bac33dae7db5ce85f9ef70b83b20fd547794a40b7a7bd846be45fa \
+    --hash=sha256:9c88bdacf4d977f80b3c9a7555b5e152945d66feccd6e0cc7d2135a7f477d6bc \
+    --hash=sha256:9d870581e402c718549385704d3bbb373aa6eb684b1e1bc5cb935af8ebba91b6 \
+    --hash=sha256:9db95c611b04cc4fb1796436c3e09483414636282261f5eae46c73f68bd9dce9 \
+    --hash=sha256:9de44918367c5f8f2f9b638f97c720320c5bb4fd400a76e4e94d34b6d7d41cc6 \
+    --hash=sha256:9faaecf4fe0335c27e63523f6a25e038877a33c63a261ff2192582e52493b39b \
+    --hash=sha256:a3be10ecd3373c35ce51b8b531c2ac41e11ff954ee678377512e8210a01b593b \
+    --hash=sha256:a9ba647f32934ac999c4119cbb8b59510dfe69aec98558539b84db7db9f20acf \
+    --hash=sha256:aba6d805b5370eda1c946c269d6df926a083819798603652c8f3b16469fea1d2 \
+    --hash=sha256:ada99031f85364232a5f8b6c3be0ef4835d26f03b09529c86b2b80c8b027428c \
+    --hash=sha256:b26c2d9adee42ce1611add5220567cab1a831813b0f711d7921cbaa0bf633a28 \
+    --hash=sha256:b5777e4aa7a5d2715c5bf698d4b4049d3f2b95bc0704d455ac486498820cd963 \
+    --hash=sha256:bba97b8bbe82b41e7aa52b64a80e1da5279e54153cb46a0364f9947c0655e5a8 \
+    --hash=sha256:bc9a4007ffad3d9a84188f8062ec2d2122283769956895a73e2336febc5ae8f1 \
+    --hash=sha256:be3a48210a2dc25633d53b0b7c33b362b7fce8f9fe2ad4ebb4cdc03471bdeb62 \
+    --hash=sha256:c338d07e80344e15e6c49d8c125f31492ac58c8236c82f7c0d8171c81027b4dc \
+    --hash=sha256:c98ca6955a43ed2c968ae9fe44bb1049f52d59ef319d9f78eb99dd4d3359a580 \
+    --hash=sha256:d4cc1523e7cd766937600da915b061538a44efec1f814b08b056eef3876e89cf \
+    --hash=sha256:d5e34571e930f99df9832cfeb7ad9917cbf0245ef2e3177cb82f86c29e1b273d \
+    --hash=sha256:d5e916f34a5d34d4da6a9199ece1c0d51efa93a30984ed89ad4ccaa1fb7a51c2 \
+    --hash=sha256:d74170021b17451881df18683eb0aa97417cb8b030b3dea425d7580891c22608 \
+    --hash=sha256:e018df19877ed13e93bbc8e8c32664cedade08d483eb0aa7077fc41b8eefd005 \
+    --hash=sha256:e6a79076d669dbf221442399da893baff6c1549031edbb5d556042d1b9b6525c \
+    --hash=sha256:e9a62a21266f88337e58d443ca58e85293232f543bcf0a66832fc89d4f9a320d \
+    --hash=sha256:eb4adb6bc4db3f0d3a1fe6cfec05846f76fabe5a3faeabc294c9777535351864 \
+    --hash=sha256:f479567d8514a63ee7f227d0841dc886870c37c3f6e17a8724ecaebfaa1331b2 \
+    --hash=sha256:f99ae65f910cf4a505457a280c9296212c9d844f5ece8c7d28813edf62787602 \
+    --hash=sha256:fe5b69e7dc54a07188d06fba9da012318223b60c07a96d39f90ccf1b01f833f7 \
+    --hash=sha256:ff6334af4ac403438eae330bf25fff5b3a63ba9f8f87a77ebb2d34815ef36431
+    # via aim
+aiofiles==24.1.0 \
+    --hash=sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c \
+    --hash=sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5
+    # via aim
+alembic==1.13.3 \
+    --hash=sha256:203503117415561e203aa14541740643a611f641517f0209fcae63e9fa09f1a2 \
+    --hash=sha256:908e905976d15235fae59c9ac42c4c5b75cfcefe3d27c0fbf7ae15a37715d80e
+    # via aim
+annotated-types==0.7.0 \
+    --hash=sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53 \
+    --hash=sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89
+    # via pydantic
+anyio==4.6.0 \
+    --hash=sha256:137b4559cbb034c477165047febb6ff83f390fc3b20bf181c1fc0a728cb8beeb \
+    --hash=sha256:c7d2e9d63e31599eeb636c8c5c03a7e108d73b345f064f1c19fdc87b79036a9a
+    # via starlette
 array-record==0.5.1 \
     --hash=sha256:248fb29086cb3a6322a5d8b8332d77713a030bc54f0bacdf215a6d3185f73f90 \
     --hash=sha256:6ebe99f37e3a797322f4f5cfc6902b5e852012ba2729fac628aad6affb225247 \
@@ -34,6 +143,10 @@
     --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \
     --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8
     # via tensorflow
+base58==2.0.1 \
+    --hash=sha256:365c9561d9babac1b5f18ee797508cd54937a724b6e419a130abad69cec5ca79 \
+    --hash=sha256:447adc750d6b642987ffc6d397ecd15a799852d5f6a1d308d384500243825058
+    # via aimrecords
 blinker==1.8.2 \
     --hash=sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01 \
     --hash=sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83
@@ -42,6 +155,20 @@
     --hash=sha256:b7c22fb0f7004b04f12e1b7b26ee0269a26737a08ded848fb58f6a34ec1eb155 \
     --hash=sha256:c6f33817f866fc67fbeb5df79cd13a8bb592c05c591f3fd7f4f22b824f7afa01
     # via -r tools/python/requirements.txt
+boto3==1.35.27 \
+    --hash=sha256:10d0fe15670b83a3f26572ab20d9152a064cee4c54b5ea9a1eeb1f0c3b807a7b \
+    --hash=sha256:3da139ca038032e92086e26d23833b557f0c257520162bfd3d6f580bf8032c86
+    # via aim
+botocore==1.35.27 \
+    --hash=sha256:c299c70b5330a8634e032883ce8a72c2c6d9fdbc985d8191199cb86b92e7cbbd \
+    --hash=sha256:f68875c26cd57a9d22c0f7a981ecb1636d7ce4d0e35797e04765b53e7bfed3e7
+    # via
+    #   boto3
+    #   s3transfer
+cachetools==5.5.0 \
+    --hash=sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292 \
+    --hash=sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a
+    # via aim
 casadi==3.6.6 \
     --hash=sha256:0870df9ac7040c14b35fdc82b74578ccfe8f1d9d8615eb79693a560fefb42307 \
     --hash=sha256:0fd493876c673ad149b03513c4f72275611643f2225f4f5d7c7ff828f75805a1 \
@@ -97,6 +224,75 @@
     --hash=sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8 \
     --hash=sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9
     # via requests
+cffi==1.17.1 \
+    --hash=sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8 \
+    --hash=sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2 \
+    --hash=sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1 \
+    --hash=sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15 \
+    --hash=sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36 \
+    --hash=sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824 \
+    --hash=sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8 \
+    --hash=sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36 \
+    --hash=sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17 \
+    --hash=sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf \
+    --hash=sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc \
+    --hash=sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3 \
+    --hash=sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed \
+    --hash=sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702 \
+    --hash=sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1 \
+    --hash=sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8 \
+    --hash=sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903 \
+    --hash=sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6 \
+    --hash=sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d \
+    --hash=sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b \
+    --hash=sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e \
+    --hash=sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be \
+    --hash=sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c \
+    --hash=sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683 \
+    --hash=sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9 \
+    --hash=sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c \
+    --hash=sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8 \
+    --hash=sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1 \
+    --hash=sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4 \
+    --hash=sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655 \
+    --hash=sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67 \
+    --hash=sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595 \
+    --hash=sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0 \
+    --hash=sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65 \
+    --hash=sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41 \
+    --hash=sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6 \
+    --hash=sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401 \
+    --hash=sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6 \
+    --hash=sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3 \
+    --hash=sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16 \
+    --hash=sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93 \
+    --hash=sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e \
+    --hash=sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4 \
+    --hash=sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964 \
+    --hash=sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c \
+    --hash=sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576 \
+    --hash=sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0 \
+    --hash=sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3 \
+    --hash=sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662 \
+    --hash=sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3 \
+    --hash=sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff \
+    --hash=sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5 \
+    --hash=sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd \
+    --hash=sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f \
+    --hash=sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5 \
+    --hash=sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14 \
+    --hash=sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d \
+    --hash=sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9 \
+    --hash=sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7 \
+    --hash=sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382 \
+    --hash=sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a \
+    --hash=sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e \
+    --hash=sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a \
+    --hash=sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4 \
+    --hash=sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99 \
+    --hash=sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87 \
+    --hash=sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b
+    # via cryptography
 charset-normalizer==3.3.2 \
     --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \
     --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \
@@ -199,9 +395,11 @@
     --hash=sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28 \
     --hash=sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de
     # via
+    #   aim
     #   flask
     #   mkdocs
     #   tensorflow-datasets
+    #   uvicorn
 clu==0.0.12 \
     --hash=sha256:0d183e7d25f7dd0700444510a264e24700e2f068bdabd199ed22866f7e54edba \
     --hash=sha256:f71eaa1afbd30f57f7709257ba7e1feb8ad5c1c3dcae3606672a138678bb3ce4
@@ -279,6 +477,35 @@
     # via
     #   bokeh
     #   matplotlib
+cryptography==43.0.1 \
+    --hash=sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494 \
+    --hash=sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806 \
+    --hash=sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d \
+    --hash=sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062 \
+    --hash=sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2 \
+    --hash=sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4 \
+    --hash=sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1 \
+    --hash=sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85 \
+    --hash=sha256:5b43d1ea6b378b54a1dc99dd8a2b5be47658fe9a7ce0a58ff0b55f4b43ef2b84 \
+    --hash=sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042 \
+    --hash=sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d \
+    --hash=sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962 \
+    --hash=sha256:7c05650fe8023c5ed0d46793d4b7d7e6cd9c04e68eabe5b0aeea836e37bdcec2 \
+    --hash=sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa \
+    --hash=sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d \
+    --hash=sha256:88cce104c36870d70c49c7c8fd22885875d950d9ee6ab54df2745f83ba0dc365 \
+    --hash=sha256:9d3cdb25fa98afdd3d0892d132b8d7139e2c087da1712041f6b762e4f807cc96 \
+    --hash=sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47 \
+    --hash=sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d \
+    --hash=sha256:c1332724be35d23a854994ff0b66530119500b6053d0bd3363265f7e5e77288d \
+    --hash=sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c \
+    --hash=sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb \
+    --hash=sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277 \
+    --hash=sha256:e710bf40870f4db63c3d7d929aa9e09e4e7ee219e703f949ec4073b4294f6172 \
+    --hash=sha256:ea25acb556320250756e53f9e20a4177515f012c9eaea17eb7587a8c4d8ae034 \
+    --hash=sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a \
+    --hash=sha256:fba1007b3ef89946dbbb515aeeb41e30203b004f0b4b00e5e16078b518563289
+    # via aim
 cycler==0.12.1 \
     --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \
     --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c
@@ -340,6 +567,18 @@
     #   optax
     #   orbax-checkpoint
     #   tensorflow-datasets
+exceptiongroup==1.2.2 \
+    --hash=sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b \
+    --hash=sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc
+    # via anyio
+fastapi==0.115.0 \
+    --hash=sha256:17ea427674467486e997206a5ab25760f6b09e069f099b96f5b55a32fb6f1631 \
+    --hash=sha256:f93b4ca3529a8ebc6fc3fcf710e5efa8de3df9b41570958abf1d97d843138004
+    # via aim
+filelock==3.16.1 \
+    --hash=sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0 \
+    --hash=sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435
+    # via aim
 flashbax==0.1.2 \
     --hash=sha256:ac50b2580808ce63787da0ae240db14986e3404ade98a33335e6d7a5efe84859 \
     --hash=sha256:b566ac5a78975b3e0a0a404a8844a26aa45e9cacfaad2829dcbcac2ffb3d5f7a
@@ -424,6 +663,81 @@
     --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \
     --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e
     # via tensorflow
+greenlet==3.1.1 \
+    --hash=sha256:0153404a4bb921f0ff1abeb5ce8a5131da56b953eda6e14b88dc6bbc04d2049e \
+    --hash=sha256:03a088b9de532cbfe2ba2034b2b85e82df37874681e8c470d6fb2f8c04d7e4b7 \
+    --hash=sha256:04b013dc07c96f83134b1e99888e7a79979f1a247e2a9f59697fa14b5862ed01 \
+    --hash=sha256:05175c27cb459dcfc05d026c4232f9de8913ed006d42713cb8a5137bd49375f1 \
+    --hash=sha256:09fc016b73c94e98e29af67ab7b9a879c307c6731a2c9da0db5a7d9b7edd1159 \
+    --hash=sha256:0bbae94a29c9e5c7e4a2b7f0aae5c17e8e90acbfd3bf6270eeba60c39fce3563 \
+    --hash=sha256:0fde093fb93f35ca72a556cf72c92ea3ebfda3d79fc35bb19fbe685853869a83 \
+    --hash=sha256:1443279c19fca463fc33e65ef2a935a5b09bb90f978beab37729e1c3c6c25fe9 \
+    --hash=sha256:1776fd7f989fc6b8d8c8cb8da1f6b82c5814957264d1f6cf818d475ec2bf6395 \
+    --hash=sha256:1d3755bcb2e02de341c55b4fca7a745a24a9e7212ac953f6b3a48d117d7257aa \
+    --hash=sha256:23f20bb60ae298d7d8656c6ec6db134bca379ecefadb0b19ce6f19d1f232a942 \
+    --hash=sha256:275f72decf9932639c1c6dd1013a1bc266438eb32710016a1c742df5da6e60a1 \
+    --hash=sha256:2846930c65b47d70b9d178e89c7e1a69c95c1f68ea5aa0a58646b7a96df12441 \
+    --hash=sha256:3319aa75e0e0639bc15ff54ca327e8dc7a6fe404003496e3c6925cd3142e0e22 \
+    --hash=sha256:346bed03fe47414091be4ad44786d1bd8bef0c3fcad6ed3dee074a032ab408a9 \
+    --hash=sha256:36b89d13c49216cadb828db8dfa6ce86bbbc476a82d3a6c397f0efae0525bdd0 \
+    --hash=sha256:37b9de5a96111fc15418819ab4c4432e4f3c2ede61e660b1e33971eba26ef9ba \
+    --hash=sha256:396979749bd95f018296af156201d6211240e7a23090f50a8d5d18c370084dc3 \
+    --hash=sha256:3b2813dc3de8c1ee3f924e4d4227999285fd335d1bcc0d2be6dc3f1f6a318ec1 \
+    --hash=sha256:411f015496fec93c1c8cd4e5238da364e1da7a124bcb293f085bf2860c32c6f6 \
+    --hash=sha256:47da355d8687fd65240c364c90a31569a133b7b60de111c255ef5b606f2ae291 \
+    --hash=sha256:48ca08c771c268a768087b408658e216133aecd835c0ded47ce955381105ba39 \
+    --hash=sha256:4afe7ea89de619adc868e087b4d2359282058479d7cfb94970adf4b55284574d \
+    --hash=sha256:4ce3ac6cdb6adf7946475d7ef31777c26d94bccc377e070a7986bd2d5c515467 \
+    --hash=sha256:4ead44c85f8ab905852d3de8d86f6f8baf77109f9da589cb4fa142bd3b57b475 \
+    --hash=sha256:54558ea205654b50c438029505def3834e80f0869a70fb15b871c29b4575ddef \
+    --hash=sha256:5e06afd14cbaf9e00899fae69b24a32f2196c19de08fcb9f4779dd4f004e5e7c \
+    --hash=sha256:62ee94988d6b4722ce0028644418d93a52429e977d742ca2ccbe1c4f4a792511 \
+    --hash=sha256:63e4844797b975b9af3a3fb8f7866ff08775f5426925e1e0bbcfe7932059a12c \
+    --hash=sha256:6510bf84a6b643dabba74d3049ead221257603a253d0a9873f55f6a59a65f822 \
+    --hash=sha256:667a9706c970cb552ede35aee17339a18e8f2a87a51fba2ed39ceeeb1004798a \
+    --hash=sha256:6ef9ea3f137e5711f0dbe5f9263e8c009b7069d8a1acea822bd5e9dae0ae49c8 \
+    --hash=sha256:7017b2be767b9d43cc31416aba48aab0d2309ee31b4dbf10a1d38fb7972bdf9d \
+    --hash=sha256:7124e16b4c55d417577c2077be379514321916d5790fa287c9ed6f23bd2ffd01 \
+    --hash=sha256:73aaad12ac0ff500f62cebed98d8789198ea0e6f233421059fa68a5aa7220145 \
+    --hash=sha256:77c386de38a60d1dfb8e55b8c1101d68c79dfdd25c7095d51fec2dd800892b80 \
+    --hash=sha256:7876452af029456b3f3549b696bb36a06db7c90747740c5302f74a9e9fa14b13 \
+    --hash=sha256:7939aa3ca7d2a1593596e7ac6d59391ff30281ef280d8632fa03d81f7c5f955e \
+    --hash=sha256:8320f64b777d00dd7ccdade271eaf0cad6636343293a25074cc5566160e4de7b \
+    --hash=sha256:85f3ff71e2e60bd4b4932a043fbbe0f499e263c628390b285cb599154a3b03b1 \
+    --hash=sha256:8b8b36671f10ba80e159378df9c4f15c14098c4fd73a36b9ad715f057272fbef \
+    --hash=sha256:93147c513fac16385d1036b7e5b102c7fbbdb163d556b791f0f11eada7ba65dc \
+    --hash=sha256:935e943ec47c4afab8965954bf49bfa639c05d4ccf9ef6e924188f762145c0ff \
+    --hash=sha256:94b6150a85e1b33b40b1464a3f9988dcc5251d6ed06842abff82e42632fac120 \
+    --hash=sha256:94ebba31df2aa506d7b14866fed00ac141a867e63143fe5bca82a8e503b36437 \
+    --hash=sha256:95ffcf719966dd7c453f908e208e14cde192e09fde6c7186c8f1896ef778d8cd \
+    --hash=sha256:98884ecf2ffb7d7fe6bd517e8eb99d31ff7855a840fa6d0d63cd07c037f6a981 \
+    --hash=sha256:99cfaa2110534e2cf3ba31a7abcac9d328d1d9f1b95beede58294a60348fba36 \
+    --hash=sha256:9e8f8c9cb53cdac7ba9793c276acd90168f416b9ce36799b9b885790f8ad6c0a \
+    --hash=sha256:a0dfc6c143b519113354e780a50381508139b07d2177cb6ad6a08278ec655798 \
+    --hash=sha256:b2795058c23988728eec1f36a4e5e4ebad22f8320c85f3587b539b9ac84128d7 \
+    --hash=sha256:b42703b1cf69f2aa1df7d1030b9d77d3e584a70755674d60e710f0af570f3761 \
+    --hash=sha256:b7cede291382a78f7bb5f04a529cb18e068dd29e0fb27376074b6d0317bf4dd0 \
+    --hash=sha256:b8a678974d1f3aa55f6cc34dc480169d58f2e6d8958895d68845fa4ab566509e \
+    --hash=sha256:b8da394b34370874b4572676f36acabac172602abf054cbc4ac910219f3340af \
+    --hash=sha256:c3a701fe5a9695b238503ce5bbe8218e03c3bcccf7e204e455e7462d770268aa \
+    --hash=sha256:c4aab7f6381f38a4b42f269057aee279ab0fc7bf2e929e3d4abfae97b682a12c \
+    --hash=sha256:ca9d0ff5ad43e785350894d97e13633a66e2b50000e8a183a50a88d834752d42 \
+    --hash=sha256:d0028e725ee18175c6e422797c407874da24381ce0690d6b9396c204c7f7276e \
+    --hash=sha256:d21e10da6ec19b457b82636209cbe2331ff4306b54d06fa04b7c138ba18c8a81 \
+    --hash=sha256:d5e975ca70269d66d17dd995dafc06f1b06e8cb1ec1e9ed54c1d1e4a7c4cf26e \
+    --hash=sha256:da7a9bff22ce038e19bf62c4dd1ec8391062878710ded0a845bcf47cc0200617 \
+    --hash=sha256:db32b5348615a04b82240cc67983cb315309e88d444a288934ee6ceaebcad6cc \
+    --hash=sha256:dcc62f31eae24de7f8dce72134c8651c58000d3b1868e01392baea7c32c247de \
+    --hash=sha256:dfc59d69fc48664bc693842bd57acfdd490acafda1ab52c7836e3fc75c90a111 \
+    --hash=sha256:e347b3bfcf985a05e8c0b7d462ba6f15b1ee1c909e2dcad795e49e91b152c383 \
+    --hash=sha256:e4d333e558953648ca09d64f13e6d8f0523fa705f51cae3f03b5983489958c70 \
+    --hash=sha256:ed10eac5830befbdd0c32f83e8aa6288361597550ba669b04c48f0f9a2c843c6 \
+    --hash=sha256:efc0f674aa41b92da8c49e0346318c6075d734994c3c4e4430b1c3f853e498e4 \
+    --hash=sha256:f1695e76146579f8c06c1509c7ce4dfe0706f49c6831a817ac04eebb2fd02011 \
+    --hash=sha256:f1d4aeb8891338e60d1ab6127af1fe45def5259def8094b9c7e34690c8858803 \
+    --hash=sha256:f406b22b7c9a9b4f8aa9d2ab13d6ae0ac3e85c9a809bd590ad53fed2bf70dc79 \
+    --hash=sha256:f6ff3b14f2df4c41660a7dec01045a045653998784bf8cfcb5a525bdffffbc8f
+    # via sqlalchemy
 grpcio==1.66.1 \
     --hash=sha256:0e6c9b42ded5d02b6b1fea3a25f036a2236eeb75d0579bfd43c0018c88bf0a3e \
     --hash=sha256:161d5c535c2bdf61b95080e7f0f017a1dfcb812bf54093e71e5562b16225b4ce \
@@ -474,6 +788,10 @@
     # via
     #   tensorboard
     #   tensorflow
+h11==0.14.0 \
+    --hash=sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d \
+    --hash=sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761
+    # via uvicorn
 h5py==3.11.0 \
     --hash=sha256:083e0329ae534a264940d6513f47f5ada617da536d8dccbafc3026aefc33c90e \
     --hash=sha256:1625fd24ad6cfc9c1ccd44a66dac2396e7ee74940776792772819fc69f3a3731 \
@@ -506,7 +824,9 @@
 idna==3.8 \
     --hash=sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac \
     --hash=sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603
-    # via requests
+    # via
+    #   anyio
+    #   requests
 importlib-metadata==8.4.0 \
     --hash=sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1 \
     --hash=sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5
@@ -589,9 +909,16 @@
     --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d
     # via
     #   -r tools/python/requirements.txt
+    #   aim
     #   bokeh
     #   flask
     #   mkdocs
+jmespath==1.0.1 \
+    --hash=sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980 \
+    --hash=sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe
+    # via
+    #   boto3
+    #   botocore
 keras==3.5.0 \
     --hash=sha256:53ae4f9472ec9d9c6941c82a3fda86969724ace3b7630a94ba0a1f17ba1065c3 \
     --hash=sha256:d37a3c623935713473ceb25241b52bce9d1e0ff5b36e5d0f6f47ed55f8500c9a
@@ -703,6 +1030,10 @@
     --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \
     --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe
     # via tensorflow
+mako==1.3.5 \
+    --hash=sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a \
+    --hash=sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc
+    # via alembic
 markdown==3.7 \
     --hash=sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2 \
     --hash=sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803
@@ -776,6 +1107,7 @@
     --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68
     # via
     #   jinja2
+    #   mako
     #   mkdocs
     #   werkzeug
 matplotlib==3.9.2 \
@@ -976,6 +1308,7 @@
     --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f
     # via
     #   -r tools/python/requirements.txt
+    #   aim
     #   bokeh
     #   casadi
     #   chex
@@ -1195,6 +1528,7 @@
     --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \
     --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124
     # via
+    #   aim
     #   bokeh
     #   clu
     #   keras
@@ -1319,6 +1653,7 @@
     --hash=sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e \
     --hash=sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1
     # via
+    #   aim
     #   bokeh
     #   matplotlib
 pkginfo==1.11.1 \
@@ -1381,7 +1716,9 @@
     --hash=sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132 \
     --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \
     --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0
-    # via tensorflow-datasets
+    # via
+    #   aim
+    #   tensorflow-datasets
 pycairo==1.26.1 \
     --hash=sha256:067191315c3b4d09cad1ec57cdb8fc1d72e2574e89389c268a94f22d4fa98b5f \
     --hash=sha256:22e1db531d4ed3167a98f0ea165bfa2a30df9d6eb22361c38158c031065999a4 \
@@ -1399,6 +1736,105 @@
     --hash=sha256:ce049930e294c29b53c68dcaab3df97cc5de7eb1d3d8e8a9f5c77e7164cd6e85 \
     --hash=sha256:e68300d1c2196d1d34de3432885ae9ff78e10426fa16f765742a11c6f8fe0a71
     # via pygobject
+pycparser==2.22 \
+    --hash=sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6 \
+    --hash=sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc
+    # via cffi
+pydantic==2.9.2 \
+    --hash=sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f \
+    --hash=sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12
+    # via fastapi
+pydantic-core==2.23.4 \
+    --hash=sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36 \
+    --hash=sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05 \
+    --hash=sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071 \
+    --hash=sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327 \
+    --hash=sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c \
+    --hash=sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36 \
+    --hash=sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29 \
+    --hash=sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744 \
+    --hash=sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d \
+    --hash=sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec \
+    --hash=sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e \
+    --hash=sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e \
+    --hash=sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577 \
+    --hash=sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232 \
+    --hash=sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863 \
+    --hash=sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6 \
+    --hash=sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368 \
+    --hash=sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480 \
+    --hash=sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2 \
+    --hash=sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2 \
+    --hash=sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6 \
+    --hash=sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769 \
+    --hash=sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d \
+    --hash=sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2 \
+    --hash=sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84 \
+    --hash=sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166 \
+    --hash=sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271 \
+    --hash=sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5 \
+    --hash=sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb \
+    --hash=sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13 \
+    --hash=sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323 \
+    --hash=sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556 \
+    --hash=sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665 \
+    --hash=sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef \
+    --hash=sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb \
+    --hash=sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119 \
+    --hash=sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126 \
+    --hash=sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510 \
+    --hash=sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b \
+    --hash=sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87 \
+    --hash=sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f \
+    --hash=sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc \
+    --hash=sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8 \
+    --hash=sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21 \
+    --hash=sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f \
+    --hash=sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6 \
+    --hash=sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658 \
+    --hash=sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b \
+    --hash=sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3 \
+    --hash=sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb \
+    --hash=sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59 \
+    --hash=sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24 \
+    --hash=sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9 \
+    --hash=sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3 \
+    --hash=sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd \
+    --hash=sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753 \
+    --hash=sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55 \
+    --hash=sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad \
+    --hash=sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a \
+    --hash=sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605 \
+    --hash=sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e \
+    --hash=sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b \
+    --hash=sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433 \
+    --hash=sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8 \
+    --hash=sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07 \
+    --hash=sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728 \
+    --hash=sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0 \
+    --hash=sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327 \
+    --hash=sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555 \
+    --hash=sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64 \
+    --hash=sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6 \
+    --hash=sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea \
+    --hash=sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b \
+    --hash=sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df \
+    --hash=sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e \
+    --hash=sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd \
+    --hash=sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068 \
+    --hash=sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3 \
+    --hash=sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040 \
+    --hash=sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12 \
+    --hash=sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916 \
+    --hash=sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f \
+    --hash=sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f \
+    --hash=sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801 \
+    --hash=sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231 \
+    --hash=sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5 \
+    --hash=sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8 \
+    --hash=sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee \
+    --hash=sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607
+    # via pydantic
 pygments==2.18.0 \
     --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \
     --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a
@@ -1414,6 +1850,8 @@
     --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \
     --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427
     # via
+    #   aim
+    #   botocore
     #   ghp-import
     #   matplotlib
     #   pandas
@@ -1425,7 +1863,9 @@
 pytz==2024.1 \
     --hash=sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812 \
     --hash=sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319
-    # via pandas
+    # via
+    #   aim
+    #   pandas
 pyyaml==6.0.2 \
     --hash=sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff \
     --hash=sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48 \
@@ -1534,14 +1974,23 @@
     --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6
     # via
     #   -r tools/python/requirements.txt
+    #   aim
     #   tensorflow
     #   tensorflow-datasets
+restrictedpython==7.2 \
+    --hash=sha256:139cb41da6e57521745a566d05825f7a09e6a884f7fa922568cff0a70b84ce6b \
+    --hash=sha256:4d1d30f709a6621ca7c4236f08b67b732a651c8099145f137078c7dd4bed3d21
+    # via aim
 rich==13.8.0 \
     --hash=sha256:2e85306a063b9492dffc86278197a60cbece75bcb766022f3436f567cae11bdc \
     --hash=sha256:a5ac1f1cd448ade0d59cc3356f7db7a7ccda2c8cbae9c7a90c28ff463d3e91f4
     # via
     #   flax
     #   keras
+s3transfer==0.10.2 \
+    --hash=sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6 \
+    --hash=sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69
+    # via boto3
 scipy==1.13.1 \
     --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \
     --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \
@@ -1630,6 +2079,67 @@
     #   python-dateutil
     #   tensorboard
     #   tensorflow
+sniffio==1.3.1 \
+    --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \
+    --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc
+    # via anyio
+sqlalchemy==2.0.35 \
+    --hash=sha256:016b2e665f778f13d3c438651dd4de244214b527a275e0acf1d44c05bc6026a9 \
+    --hash=sha256:032d979ce77a6c2432653322ba4cbeabf5a6837f704d16fa38b5a05d8e21fa00 \
+    --hash=sha256:0375a141e1c0878103eb3d719eb6d5aa444b490c96f3fedab8471c7f6ffe70ee \
+    --hash=sha256:042622a5306c23b972192283f4e22372da3b8ddf5f7aac1cc5d9c9b222ab3ff6 \
+    --hash=sha256:05c3f58cf91683102f2f0265c0db3bd3892e9eedabe059720492dbaa4f922da1 \
+    --hash=sha256:0630774b0977804fba4b6bbea6852ab56c14965a2b0c7fc7282c5f7d90a1ae72 \
+    --hash=sha256:0f9f3f9a3763b9c4deb8c5d09c4cc52ffe49f9876af41cc1b2ad0138878453cf \
+    --hash=sha256:1b56961e2d31389aaadf4906d453859f35302b4eb818d34a26fab72596076bb8 \
+    --hash=sha256:22b83aed390e3099584b839b93f80a0f4a95ee7f48270c97c90acd40ee646f0b \
+    --hash=sha256:25b0f63e7fcc2a6290cb5f7f5b4fc4047843504983a28856ce9b35d8f7de03cc \
+    --hash=sha256:2a275a806f73e849e1c309ac11108ea1a14cd7058577aba962cd7190e27c9e3c \
+    --hash=sha256:2ab3f0336c0387662ce6221ad30ab3a5e6499aab01b9790879b6578fd9b8faa1 \
+    --hash=sha256:2e795c2f7d7249b75bb5f479b432a51b59041580d20599d4e112b5f2046437a3 \
+    --hash=sha256:3655af10ebcc0f1e4e06c5900bb33e080d6a1fa4228f502121f28a3b1753cde5 \
+    --hash=sha256:4668bd8faf7e5b71c0319407b608f278f279668f358857dbfd10ef1954ac9f90 \
+    --hash=sha256:4c31943b61ed8fdd63dfd12ccc919f2bf95eefca133767db6fbbd15da62078ec \
+    --hash=sha256:4fdcd72a789c1c31ed242fd8c1bcd9ea186a98ee8e5408a50e610edfef980d71 \
+    --hash=sha256:627dee0c280eea91aed87b20a1f849e9ae2fe719d52cbf847c0e0ea34464b3f7 \
+    --hash=sha256:67219632be22f14750f0d1c70e62f204ba69d28f62fd6432ba05ab295853de9b \
+    --hash=sha256:6921ee01caf375363be5e9ae70d08ce7ca9d7e0e8983183080211a062d299468 \
+    --hash=sha256:69683e02e8a9de37f17985905a5eca18ad651bf592314b4d3d799029797d0eb3 \
+    --hash=sha256:6a93c5a0dfe8d34951e8a6f499a9479ffb9258123551fa007fc708ae2ac2bc5e \
+    --hash=sha256:732e026240cdd1c1b2e3ac515c7a23820430ed94292ce33806a95869c46bd139 \
+    --hash=sha256:7befc148de64b6060937231cbff8d01ccf0bfd75aa26383ffdf8d82b12ec04ff \
+    --hash=sha256:890da8cd1941fa3dab28c5bac3b9da8502e7e366f895b3b8e500896f12f94d11 \
+    --hash=sha256:89b64cd8898a3a6f642db4eb7b26d1b28a497d4022eccd7717ca066823e9fb01 \
+    --hash=sha256:8a6219108a15fc6d24de499d0d515c7235c617b2540d97116b663dade1a54d62 \
+    --hash=sha256:8cdf1a0dbe5ced887a9b127da4ffd7354e9c1a3b9bb330dce84df6b70ccb3a8d \
+    --hash=sha256:8d625eddf7efeba2abfd9c014a22c0f6b3796e0ffb48f5d5ab106568ef01ff5a \
+    --hash=sha256:93a71c8601e823236ac0e5d087e4f397874a421017b3318fd92c0b14acf2b6db \
+    --hash=sha256:9509c4123491d0e63fb5e16199e09f8e262066e58903e84615c301dde8fa2e87 \
+    --hash=sha256:a29762cd3d116585278ffb2e5b8cc311fb095ea278b96feef28d0b423154858e \
+    --hash=sha256:a62dd5d7cc8626a3634208df458c5fe4f21200d96a74d122c83bc2015b333bc1 \
+    --hash=sha256:ada603db10bb865bbe591939de854faf2c60f43c9b763e90f653224138f910d9 \
+    --hash=sha256:aee110e4ef3c528f3abbc3c2018c121e708938adeeff9006428dd7c8555e9b3f \
+    --hash=sha256:b76d63495b0508ab9fc23f8152bac63205d2a704cd009a2b0722f4c8e0cba8e0 \
+    --hash=sha256:c0d8326269dbf944b9201911b0d9f3dc524d64779a07518199a58384c3d37a44 \
+    --hash=sha256:c41411e192f8d3ea39ea70e0fae48762cd11a2244e03751a98bd3c0ca9a4e936 \
+    --hash=sha256:c68fe3fcde03920c46697585620135b4ecfdfc1ed23e75cc2c2ae9f8502c10b8 \
+    --hash=sha256:cb8bea573863762bbf45d1e13f87c2d2fd32cee2dbd50d050f83f87429c9e1ea \
+    --hash=sha256:cc32b2990fc34380ec2f6195f33a76b6cdaa9eecf09f0c9404b74fc120aef36f \
+    --hash=sha256:ccae5de2a0140d8be6838c331604f91d6fafd0735dbdcee1ac78fc8fbaba76b4 \
+    --hash=sha256:d299797d75cd747e7797b1b41817111406b8b10a4f88b6e8fe5b5e59598b43b0 \
+    --hash=sha256:e04b622bb8a88f10e439084486f2f6349bf4d50605ac3e445869c7ea5cf0fa8c \
+    --hash=sha256:e11d7ea4d24f0a262bccf9a7cd6284c976c5369dac21db237cff59586045ab9f \
+    --hash=sha256:e21f66748ab725ade40fa7af8ec8b5019c68ab00b929f6643e1b1af461eddb60 \
+    --hash=sha256:eb60b026d8ad0c97917cb81d3662d0b39b8ff1335e3fabb24984c6acd0c900a2 \
+    --hash=sha256:f021d334f2ca692523aaf7bbf7592ceff70c8594fad853416a81d66b35e3abf9 \
+    --hash=sha256:f552023710d4b93d8fb29a91fadf97de89c5926c6bd758897875435f2a939f33
+    # via
+    #   aim
+    #   alembic
+starlette==0.38.6 \
+    --hash=sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05 \
+    --hash=sha256:863a1588f5574e70a821dadefb41e4881ea451a47a3cd1b4df359d4ffefe5ead
+    # via fastapi
 sympy==1.13.2 \
     --hash=sha256:401449d84d07be9d0c7a46a64bd54fe097667d5e7181bfe67ec777be9e01cb13 \
     --hash=sha256:c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9
@@ -1749,7 +2259,9 @@
 tqdm==4.66.5 \
     --hash=sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd \
     --hash=sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad
-    # via tensorflow-datasets
+    # via
+    #   aim
+    #   tensorflow-datasets
 typeguard==2.13.3 \
     --hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
     --hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
@@ -1758,22 +2270,36 @@
     --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
     --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
     # via
+    #   alembic
+    #   anyio
     #   chex
     #   clu
     #   etils
+    #   fastapi
     #   flashbax
     #   flax
     #   optree
     #   orbax-checkpoint
+    #   pydantic
+    #   pydantic-core
+    #   sqlalchemy
+    #   starlette
     #   tensorflow
+    #   uvicorn
 tzdata==2024.1 \
     --hash=sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd \
     --hash=sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252
     # via pandas
-urllib3==2.2.2 \
-    --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \
-    --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168
-    # via requests
+urllib3==1.26.20 \
+    --hash=sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e \
+    --hash=sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32
+    # via
+    #   botocore
+    #   requests
+uvicorn==0.30.6 \
+    --hash=sha256:4b15decdda1e72be08209e860a1e10e92439ad5b97cf44cc945fcbee66fc5788 \
+    --hash=sha256:65fd46fe3fda5bdc1b03b94eb634923ff18cd35b2f084813ea79d1f103f711b5
+    # via aim
 validators==0.34.0 \
     --hash=sha256:647fe407b45af9a74d245b943b18e6a816acf4926974278f6dd617778e1e781f \
     --hash=sha256:c804b476e3e6d3786fa07a30073a4ef694e617805eb1946ceee3fe5a9b8b1321
@@ -1810,6 +2336,94 @@
     --hash=sha256:f8b2918c19e0d48f5f20df458c84692e2a054f02d9df25e6c3c930063eca64c1 \
     --hash=sha256:fb223456db6e5f7bd9bbd5cd969f05aae82ae21acc00643b60d81c770abd402b
     # via mkdocs
+websockets==13.1 \
+    --hash=sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a \
+    --hash=sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54 \
+    --hash=sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23 \
+    --hash=sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7 \
+    --hash=sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135 \
+    --hash=sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700 \
+    --hash=sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf \
+    --hash=sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5 \
+    --hash=sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e \
+    --hash=sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c \
+    --hash=sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02 \
+    --hash=sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a \
+    --hash=sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418 \
+    --hash=sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f \
+    --hash=sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3 \
+    --hash=sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68 \
+    --hash=sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978 \
+    --hash=sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20 \
+    --hash=sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295 \
+    --hash=sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b \
+    --hash=sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6 \
+    --hash=sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb \
+    --hash=sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a \
+    --hash=sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa \
+    --hash=sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0 \
+    --hash=sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a \
+    --hash=sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238 \
+    --hash=sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c \
+    --hash=sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084 \
+    --hash=sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19 \
+    --hash=sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d \
+    --hash=sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7 \
+    --hash=sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9 \
+    --hash=sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79 \
+    --hash=sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96 \
+    --hash=sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6 \
+    --hash=sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe \
+    --hash=sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842 \
+    --hash=sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa \
+    --hash=sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3 \
+    --hash=sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d \
+    --hash=sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51 \
+    --hash=sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7 \
+    --hash=sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09 \
+    --hash=sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096 \
+    --hash=sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9 \
+    --hash=sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b \
+    --hash=sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5 \
+    --hash=sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678 \
+    --hash=sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea \
+    --hash=sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d \
+    --hash=sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49 \
+    --hash=sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc \
+    --hash=sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5 \
+    --hash=sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027 \
+    --hash=sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0 \
+    --hash=sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878 \
+    --hash=sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c \
+    --hash=sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa \
+    --hash=sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f \
+    --hash=sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6 \
+    --hash=sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2 \
+    --hash=sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf \
+    --hash=sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708 \
+    --hash=sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6 \
+    --hash=sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f \
+    --hash=sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd \
+    --hash=sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2 \
+    --hash=sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d \
+    --hash=sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7 \
+    --hash=sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f \
+    --hash=sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5 \
+    --hash=sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6 \
+    --hash=sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557 \
+    --hash=sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14 \
+    --hash=sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7 \
+    --hash=sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd \
+    --hash=sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c \
+    --hash=sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17 \
+    --hash=sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23 \
+    --hash=sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db \
+    --hash=sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6 \
+    --hash=sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d \
+    --hash=sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9 \
+    --hash=sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee \
+    --hash=sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6
+    # via aim
 werkzeug==3.0.4 \
     --hash=sha256:02c9eb92b7d6c06f31a782811505d2157837cea66aaede3e217c7c27c039476c \
     --hash=sha256:34f2371506b250df4d4f84bfe7b0921e4762525762bbd936614909fe25cd7306
diff --git a/tools/python/requirements.txt b/tools/python/requirements.txt
index 2b806f5..241425a 100644
--- a/tools/python/requirements.txt
+++ b/tools/python/requirements.txt
@@ -44,3 +44,6 @@
 
 # Experience buffer for reinforcement learning
 flashbax
+
+# Experiment tracking
+aim
diff --git a/tools/python/runtime_binary.sh b/tools/python/runtime_binary.sh
index 8498408..95283e0 100755
--- a/tools/python/runtime_binary.sh
+++ b/tools/python/runtime_binary.sh
@@ -27,6 +27,9 @@
     LD_LIBRARY_PATH+=":${path}/../gtk_runtime/lib/x86_64-linux-gnu"
     LD_LIBRARY_PATH+=":${path}/../gtk_runtime/usr/lib/x86_64-linux-gnu"
     LD_LIBRARY_PATH+=":${path}/../gtk_runtime/usr/lib"
+    if [[ -e "${path}/../pip_deps_nvidia_nccl_cu12" ]]; then
+      LD_LIBRARY_PATH+=":${path}/../pip_deps_nvidia_nccl_cu12/site-packages/nvidia/nccl/lib/"
+    fi
     export LD_LIBRARY_PATH
     break
   fi
diff --git a/tools/python/whl_overrides.json b/tools/python/whl_overrides.json
index 586a9d3..249b856 100644
--- a/tools/python/whl_overrides.json
+++ b/tools/python/whl_overrides.json
@@ -3,6 +3,38 @@
         "sha256": "526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/absl_py-2.1.0-py3-none-any.whl"
     },
+    "aim==3.24.0": {
+        "sha256": "1789bcc9edf69ae90bb18a6f3fa4862095b8020eadc7e5e0e93ba4f242250947",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/aim-3.24.0-cp39-cp39-manylinux_2_28_x86_64.whl"
+    },
+    "aim_ui==3.24.0": {
+        "sha256": "b62fbcb0ea4b036b99985d4c649220cdbe4523f05408dcb2324ffc7e4539f321",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/aim_ui-3.24.0-py3-none-any.whl"
+    },
+    "aimrecords==0.0.7": {
+        "sha256": "b9276890891c5fd68f817e20fc5d466a80c01e22fa468eaa979331448a75d601",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/aimrecords-0.0.7-py2.py3-none-any.whl"
+    },
+    "aimrocks==0.5.2": {
+        "sha256": "e6a79076d669dbf221442399da893baff6c1549031edbb5d556042d1b9b6525c",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/aimrocks-0.5.2-cp39-cp39-manylinux_2_24_x86_64.whl"
+    },
+    "aiofiles==24.1.0": {
+        "sha256": "b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/aiofiles-24.1.0-py3-none-any.whl"
+    },
+    "alembic==1.13.3": {
+        "sha256": "908e905976d15235fae59c9ac42c4c5b75cfcefe3d27c0fbf7ae15a37715d80e",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/alembic-1.13.3-py3-none-any.whl"
+    },
+    "annotated_types==0.7.0": {
+        "sha256": "1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/annotated_types-0.7.0-py3-none-any.whl"
+    },
+    "anyio==4.6.0": {
+        "sha256": "c7d2e9d63e31599eeb636c8c5c03a7e108d73b345f064f1c19fdc87b79036a9a",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/anyio-4.6.0-py3-none-any.whl"
+    },
     "array_record==0.5.1": {
         "sha256": "9922862216a9d3be76fdc27968af1ec0ea20f986329998ba45b0f01ee3e646fa",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/array_record-0.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -11,6 +43,10 @@
         "sha256": "c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/astunparse-1.6.3-py2.py3-none-any.whl"
     },
+    "base58==2.0.1": {
+        "sha256": "447adc750d6b642987ffc6d397ecd15a799852d5f6a1d308d384500243825058",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/base58-2.0.1-py3-none-any.whl"
+    },
     "blinker==1.8.2": {
         "sha256": "1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/blinker-1.8.2-py3-none-any.whl"
@@ -19,6 +55,18 @@
         "sha256": "c6f33817f866fc67fbeb5df79cd13a8bb592c05c591f3fd7f4f22b824f7afa01",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/bokeh-3.4.3-py3-none-any.whl"
     },
+    "boto3==1.35.27": {
+        "sha256": "3da139ca038032e92086e26d23833b557f0c257520162bfd3d6f580bf8032c86",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/boto3-1.35.27-py3-none-any.whl"
+    },
+    "botocore==1.35.27": {
+        "sha256": "c299c70b5330a8634e032883ce8a72c2c6d9fdbc985d8191199cb86b92e7cbbd",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/botocore-1.35.27-py3-none-any.whl"
+    },
+    "cachetools==5.5.0": {
+        "sha256": "02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cachetools-5.5.0-py3-none-any.whl"
+    },
     "casadi==3.6.6": {
         "sha256": "cc3594b348f306018b142638d0a8c4026f80e996b4e9798fc504899256a7b029",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/casadi-3.6.6-cp39-none-manylinux2014_x86_64.whl"
@@ -27,6 +75,10 @@
         "sha256": "922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/certifi-2024.8.30-py3-none-any.whl"
     },
+    "cffi==1.17.1": {
+        "sha256": "cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+    },
     "charset_normalizer==3.3.2": {
         "sha256": "b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -51,6 +103,10 @@
         "sha256": "68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/contourpy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
     },
+    "cryptography==43.0.1": {
+        "sha256": "511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cryptography-43.0.1-cp39-abi3-manylinux_2_28_x86_64.whl"
+    },
     "cycler==0.12.1": {
         "sha256": "85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cycler-0.12.1-py3-none-any.whl"
@@ -63,6 +119,18 @@
         "sha256": "6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/etils-1.5.2-py3-none-any.whl"
     },
+    "exceptiongroup==1.2.2": {
+        "sha256": "3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/exceptiongroup-1.2.2-py3-none-any.whl"
+    },
+    "fastapi==0.115.0": {
+        "sha256": "17ea427674467486e997206a5ab25760f6b09e069f099b96f5b55a32fb6f1631",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/fastapi-0.115.0-py3-none-any.whl"
+    },
+    "filelock==3.16.1": {
+        "sha256": "2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/filelock-3.16.1-py3-none-any.whl"
+    },
     "flashbax==0.1.2": {
         "sha256": "ac50b2580808ce63787da0ae240db14986e3404ade98a33335e6d7a5efe84859",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/flashbax-0.1.2-py3-none-any.whl"
@@ -103,10 +171,18 @@
         "sha256": "b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/google_pasta-0.2.0-py3-none-any.whl"
     },
+    "greenlet==3.1.1": {
+        "sha256": "63e4844797b975b9af3a3fb8f7866ff08775f5426925e1e0bbcfe7932059a12c",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/greenlet-3.1.1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl"
+    },
     "grpcio==1.66.1": {
         "sha256": "48b0d92d45ce3be2084b92fb5bae2f64c208fea8ceed7fccf6a7b524d3c4942e",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/grpcio-1.66.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
     },
+    "h11==0.14.0": {
+        "sha256": "e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/h11-0.14.0-py3-none-any.whl"
+    },
     "h5py==3.11.0": {
         "sha256": "67462d0669f8f5459529de179f7771bd697389fcb3faab54d63bf788599a48ea",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/h5py-3.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -155,6 +231,10 @@
         "sha256": "bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jinja2-3.1.4-py3-none-any.whl"
     },
+    "jmespath==1.0.1": {
+        "sha256": "02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jmespath-1.0.1-py3-none-any.whl"
+    },
     "keras==3.5.0": {
         "sha256": "d37a3c623935713473ceb25241b52bce9d1e0ff5b36e5d0f6f47ed55f8500c9a",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/keras-3.5.0-py3-none-any.whl"
@@ -167,6 +247,10 @@
         "sha256": "c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl"
     },
+    "mako==1.3.5": {
+        "sha256": "260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/Mako-1.3.5-py3-none-any.whl"
+    },
     "markdown==3.7": {
         "sha256": "7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/Markdown-3.7-py3-none-any.whl"
@@ -331,6 +415,18 @@
         "sha256": "5eb499c081ff03ffed8edaa115325edf46eda7f89b53793a4b70dfc72180cc31",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/pycairo-1.26.1-cp39-cp39-manylinux_2_34_x86_64.whl"
     },
+    "pycparser==2.22": {
+        "sha256": "c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/pycparser-2.22-py3-none-any.whl"
+    },
+    "pydantic==2.9.2": {
+        "sha256": "f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/pydantic-2.9.2-py3-none-any.whl"
+    },
+    "pydantic_core==2.23.4": {
+        "sha256": "1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+    },
     "pygments==2.18.0": {
         "sha256": "b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/pygments-2.18.0-py3-none-any.whl"
@@ -371,10 +467,18 @@
         "sha256": "70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/requests-2.32.3-py3-none-any.whl"
     },
+    "restrictedpython==7.2": {
+        "sha256": "139cb41da6e57521745a566d05825f7a09e6a884f7fa922568cff0a70b84ce6b",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/RestrictedPython-7.2-py3-none-any.whl"
+    },
     "rich==13.8.0": {
         "sha256": "2e85306a063b9492dffc86278197a60cbece75bcb766022f3436f567cae11bdc",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/rich-13.8.0-py3-none-any.whl"
     },
+    "s3transfer==0.10.2": {
+        "sha256": "eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/s3transfer-0.10.2-py3-none-any.whl"
+    },
     "scipy==1.13.1": {
         "sha256": "637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -391,6 +495,18 @@
         "sha256": "8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/six-1.16.0-py2.py3-none-any.whl"
     },
+    "sniffio==1.3.1": {
+        "sha256": "2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/sniffio-1.3.1-py3-none-any.whl"
+    },
+    "sqlalchemy==2.0.35": {
+        "sha256": "890da8cd1941fa3dab28c5bac3b9da8502e7e366f895b3b8e500896f12f94d11",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/SQLAlchemy-2.0.35-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+    },
+    "starlette==0.38.6": {
+        "sha256": "4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/starlette-0.38.6-py3-none-any.whl"
+    },
     "sympy==1.13.2": {
         "sha256": "c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/sympy-1.13.2-py3-none-any.whl"
@@ -463,9 +579,13 @@
         "sha256": "9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tzdata-2024.1-py2.py3-none-any.whl"
     },
-    "urllib3==2.2.2": {
-        "sha256": "a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472",
-        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/urllib3-2.2.2-py3-none-any.whl"
+    "urllib3==1.26.20": {
+        "sha256": "0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/urllib3-1.26.20-py2.py3-none-any.whl"
+    },
+    "uvicorn==0.30.6": {
+        "sha256": "65fd46fe3fda5bdc1b03b94eb634923ff18cd35b2f084813ea79d1f103f711b5",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/uvicorn-0.30.6-py3-none-any.whl"
     },
     "validators==0.34.0": {
         "sha256": "c804b476e3e6d3786fa07a30073a4ef694e617805eb1946ceee3fe5a9b8b1321",
@@ -475,6 +595,10 @@
         "sha256": "726eef8f8c634ac6584f86c9c53353a010d9f311f6c15a034f3800a7a891d941",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/watchdog-5.0.2-py3-none-manylinux2014_x86_64.whl"
     },
+    "websockets==13.1": {
+        "sha256": "6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+    },
     "werkzeug==3.0.4": {
         "sha256": "02c9eb92b7d6c06f31a782811505d2157837cea66aaede3e217c7c27c039476c",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/werkzeug-3.0.4-py3-none-any.whl"