Merge changes I04cede20,Ie54cd037,I6756d4ba,I6a01afa5 into main

* changes:
  Prep lqr_plot for plotting swerve
  Provide a goal to our SAC implementation
  Split Soft Actor-Critic problem out into a class
  Add plotly and tensorflow_probability
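
The net effect of these changes is that the actor and both critics are now
conditioned on a goal R in addition to the observation.  Below is a minimal
sketch of the new call pattern, assuming state is a TrainState built by
model.create_train_state() and restored from a checkpoint as lqr_plot.py does;
the goal value is illustrative only:

    import jax
    from frc971.control_loops.swerve.velocity_controller import physics

    problem = physics.TurretProblem()
    rng = jax.random.key(0)

    observation = jax.numpy.zeros((problem.num_states, ))
    R = jax.numpy.array([0.1, 0.0])  # hypothetical position goal

    # Sample an action from the goal-conditioned policy.
    U, logp_pi, _, _ = state.pi_apply(rng,
                                      state.params,
                                      observation=observation,
                                      R=R,
                                      deterministic=True)

    # Evaluate one of the goal-conditioned critics at that action.
    q1 = state.q1_apply(state.params, observation=observation, R=R, action=U)
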
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 30f5630..bd6674c 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -75,18 +75,6 @@
     return checkpoints.restore_checkpoint(workdir, state)
 
 
-# Hard-coded simulation parameters for the turret.
-dt = 0.005
-A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
-B = numpy.matrix([[0.00065992], [0.25610763]])
-
-Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
-R = numpy.matrix([[0.00694444]])
-
-# Compute the optimal LQR cost + controller.
-F, P = controls.dlqr(A, B, Q, R, optimal_cost_function=True)
-
-
 def generate_data(step=None):
     grid_X = numpy.arange(-1, 1, 0.1)
     grid_Y = numpy.arange(-10, 10, 0.1)
@@ -94,21 +82,18 @@
     grid_X = jax.numpy.array(grid_X)
     grid_Y = jax.numpy.array(grid_Y)
     # Load the training state.
-    physics_constants = jax_dynamics.Coefficients()
+    problem = physics.TurretProblem()
 
     rng = jax.random.key(0)
     rng, init_rng = jax.random.split(rng)
 
-    state = create_train_state(init_rng, physics_constants,
-                               FLAGS.q_learning_rate, FLAGS.pi_learning_rate)
+    state = create_train_state(init_rng, problem, FLAGS.q_learning_rate,
+                               FLAGS.pi_learning_rate)
 
     state = restore_checkpoint(state, FLAGS.workdir)
     if step is not None and state.step == step:
         return None
 
-    print('F:', F)
-    print('P:', P)
-
     X = jax.numpy.array([1.0, 0.0])
     X_lqr = X.copy()
     goal = jax.numpy.array([0.0, 0.0])
@@ -121,34 +106,38 @@
     def compute_q1(X, Y):
         return state.q1_apply(
             state.params,
-            unwrap_angles(jax.numpy.array([X, Y])),
-            jax.numpy.array([0.]),
+            observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+            R=goal,
+            action=jax.numpy.array([0.]),
         )[0]
 
     def compute_q2(X, Y):
         return state.q2_apply(
             state.params,
-            unwrap_angles(jax.numpy.array([X, Y])),
-            jax.numpy.array([0.]),
+            observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+            R=goal,
+            action=jax.numpy.array([0.]),
         )[0]
 
     def lqr_cost(X, Y):
-        x = jax.numpy.array([X, Y])
-        return -x.T @ jax.numpy.array(P) @ x
+        x = jax.numpy.array([X, Y]) - goal
+        return -x.T @ jax.numpy.array(problem.P) @ x
 
     def compute_q(params, x, y):
-        X = unwrap_angles(jax.numpy.array([x, y]))
+        X = state.problem.unwrap_angles(jax.numpy.array([x, y]))
 
         return jax.numpy.minimum(
             state.q1_apply(
                 params,
-                X,
-                jax.numpy.array([0.]),
+                observation=X,
+                R=goal,
+                action=jax.numpy.array([0.]),
             )[0],
             state.q2_apply(
                 params,
-                X,
-                jax.numpy.array([0.]),
+                observation=X,
+                R=goal,
+                action=jax.numpy.array([0.]),
             )[0])
 
     cost_grid1 = jax.vmap(jax.vmap(compute_q1))(grid_X, grid_Y)
@@ -160,20 +149,17 @@
                                                                      grid_Y)
     lqr_cost_grid = jax.vmap(jax.vmap(lqr_cost))(grid_X, grid_Y)
 
-    # TODO(austin): Stuff to figure out:
-    # 3: Make it converge faster.  Use both GPUs better?
-    # 4: Can we feed in a reference position and get it to learn to stabilize there?
-
     # Now compute the two controller surfaces.
     def compute_lqr_U(X, Y):
         x = jax.numpy.array([X, Y])
-        return (-jax.numpy.array(F.reshape((2, ))) @ x)[0]
+        return (-jax.numpy.array(problem.F.reshape((2, ))) @ x)[0]
 
     def compute_pi_U(X, Y):
         x = jax.numpy.array([X, Y])
         U, _, _, _ = state.pi_apply(rng,
                                     state.params,
-                                    observation=unwrap_angles(x),
+                                    observation=state.problem.unwrap_angles(x),
+                                    R=goal,
                                     deterministic=True)
         return U[0]
 
@@ -185,16 +171,24 @@
     # Now simulate the robot, accumulating up things to plot.
     def loop(i, val):
         X, X_lqr, data, params = val
-        t = data.t.at[i].set(i * dt)
+        t = data.t.at[i].set(i * problem.dt)
 
         U, _, _, _ = state.pi_apply(rng,
                                     params,
-                                    observation=unwrap_angles(X),
+                                    observation=state.problem.unwrap_angles(X),
+                                    R=goal,
                                     deterministic=True)
-        U_lqr = F @ (goal - X_lqr)
+        U_lqr = problem.F @ (goal - X_lqr)
 
-        cost = jax.numpy.minimum(state.q1_apply(params, unwrap_angles(X), U),
-                                 state.q2_apply(params, unwrap_angles(X), U))
+        cost = jax.numpy.minimum(
+            state.q1_apply(params,
+                           observation=state.problem.unwrap_angles(X),
+                           R=goal,
+                           action=U),
+            state.q2_apply(params,
+                           observation=state.problem.unwrap_angles(X),
+                           R=goal,
+                           action=U))
 
         U_plot = data.U.at[i, :].set(U)
         U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
@@ -203,11 +197,11 @@
         cost_plot = data.cost.at[i, :].set(cost)
         cost_lqr_plot = data.cost_lqr.at[i, :].set(lqr_cost(*X_lqr))
 
-        X = A @ X + B @ U
-        X_lqr = A @ X_lqr + B @ U_lqr
+        X = problem.A @ X + problem.B @ U
+        X_lqr = problem.A @ X_lqr + problem.B @ U_lqr
 
-        reward = data.reward - physics.state_cost(X, U, goal)
-        reward_lqr = data.reward_lqr - physics.state_cost(X_lqr, U_lqr, goal)
+        reward = data.reward - state.problem.cost(X, U, goal)
+        reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
 
         return X, X_lqr, data._replace(
             t=t,
@@ -230,10 +224,10 @@
     X, X_lqr, data, params = integrate(
         Data(
             t=jax.numpy.zeros((FLAGS.horizon, )),
-            X=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
-            X_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
-            U=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
-            U_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+            X=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+            X_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+            U=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
+            U_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
             cost=jax.numpy.zeros((FLAGS.horizon, 1)),
             cost_lqr=jax.numpy.zeros((FLAGS.horizon, 1)),
             q1_grid=cost_grid1,
@@ -284,24 +278,32 @@
         self.fig0, self.axs0 = pylab.subplots(3)
         self.fig0.supxlabel('Seconds')
 
-        self.x, = self.axs0[0].plot([], [], label="x")
-        self.v, = self.axs0[0].plot([], [], label="v")
-        self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
-        self.v_lqr, = self.axs0[0].plot([], [], label="v_lqr")
+        self.axs_velocity = self.axs0[0].twinx()
 
-        self.axs0[0].set_ylabel('Velocity')
+        self.x, = self.axs0[0].plot([], [], label="x")
+        self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
+
+        self.axs0[0].set_ylabel('Position')
         self.axs0[0].legend()
+        self.axs0[0].grid()
+
+        self.v, = self.axs_velocity.plot([], [], label="v", color='C2')
+        self.v_lqr, = self.axs_velocity.plot([], [], label="v_lqr", color='C3')
+        self.axs_velocity.set_ylabel('Velocity')
+        self.axs_velocity.legend()
 
         self.uaxis, = self.axs0[1].plot([], [], label="U")
         self.uaxis_lqr, = self.axs0[1].plot([], [], label="U_lqr")
 
         self.axs0[1].set_ylabel('Amps')
         self.axs0[1].legend()
+        self.axs0[1].grid()
 
         self.costaxis, = self.axs0[2].plot([], [], label="cost")
         self.costlqraxis, = self.axs0[2].plot([], [], label="cost lqr")
         self.axs0[2].set_ylabel('Cost')
         self.axs0[2].legend()
+        self.axs0[2].grid()
 
         self.costfig = pyplot.figure(figsize=pyplot.figaspect(0.5))
         self.cost3dax = [
@@ -344,6 +346,9 @@
         self.axs0[0].relim()
         self.axs0[0].autoscale_view()
 
+        self.axs_velocity.relim()
+        self.axs_velocity.autoscale_view()
+
         self.axs0[1].relim()
         self.axs0[1].autoscale_view()
 
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
index 7ef687c..a0d248d 100644
--- a/frc971/control_loops/swerve/velocity_controller/main.py
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -24,6 +24,7 @@
 import jax
 import tensorflow as tf
 from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve.velocity_controller import physics
 
 jax._src.deprecations.accelerate('tracer-hash')
 # Enable the compilation cache
@@ -63,8 +64,8 @@
         FLAGS.workdir,
     )
 
-    physics_constants = jax_dynamics.Coefficients()
-    state = train.train(FLAGS.workdir, physics_constants)
+    problem = physics.TurretProblem()
+    state = train.train(FLAGS.workdir, problem)
 
 
 if __name__ == '__main__':
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index ae67efe..0d8a410 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -36,8 +36,8 @@
 
 absl.flags.DEFINE_float(
     'final_q_learning_rate',
-    default=0.00002,
-    help='Training learning rate.',
+    default=0.001,
+    help='Fraction of --q_learning_rate to reduce to by the end.',
 )
 
 absl.flags.DEFINE_float(
@@ -48,8 +48,8 @@
 
 absl.flags.DEFINE_float(
     'final_pi_learning_rate',
-    default=0.00002,
-    help='Training learning rate.',
+    default=0.001,
+    help='Fraction of --pi_learning_rate to reduce to by the end.',
 )
 
 absl.flags.DEFINE_integer(
@@ -92,14 +92,15 @@
     activation: Callable = nn.activation.relu
 
     # Max action we can apply
-    action_limit: float = 30.0
+    action_limit: float = 1
 
     @nn.compact
     def __call__(self,
-                 observations: ArrayLike,
+                 observation: ArrayLike,
+                 R: ArrayLike,
                  deterministic: bool = False,
                  rng: PRNGKey | None = None):
-        x = observations
+        x = jax.numpy.hstack((observation, R))
         # Apply the dense layers
         for i, hidden_size in enumerate(self.hidden_sizes):
             x = nn.Dense(
@@ -168,9 +169,10 @@
     activation: Callable = nn.activation.tanh
 
     @nn.compact
-    def __call__(self, observations, actions):
+    def __call__(self, observation: ArrayLike, R: ArrayLike,
+                 action: ArrayLike):
         # Estimate Q with a simple multi layer dense network.
-        x = jax.numpy.hstack((observations, actions))
+        x = jax.numpy.hstack((observation, R, action))
         for i, hidden_size in enumerate(self.hidden_sizes):
             x = nn.Dense(
                 name=f'denselayer{i}',
@@ -185,8 +187,7 @@
 
 
 class TrainState(flax.struct.PyTreeNode):
-    physics_constants: jax_dynamics.CoefficientsType = flax.struct.field(
-        pytree_node=False)
+    problem: Problem = flax.struct.field(pytree_node=False)
 
     step: int | jax.Array = flax.struct.field(pytree_node=True)
     substep: int | jax.Array = flax.struct.field(pytree_node=True)
@@ -221,23 +222,34 @@
 
     def pi_apply(self,
                  rng: PRNGKey,
-                 params,
-                 observation,
+                 params: flax.core.FrozenDict[str, typing.Any],
+                 observation: ArrayLike,
+                 R: ArrayLike,
                  deterministic: bool = False):
-        return self.pi_apply_fn({'params': params['pi']},
-                                physics.unwrap_angles(observation),
-                                deterministic=deterministic,
-                                rngs={'pi': rng})
+        return self.pi_apply_fn(
+            {'params': params['pi']},
+            observation=self.problem.unwrap_angles(observation),
+            R=R,
+            deterministic=deterministic,
+            rngs={'pi': rng})
 
-    def q1_apply(self, params, observation, action):
-        return self.q1_apply_fn({'params': params['q1']},
-                                physics.unwrap_angles(observation), action)
+    def q1_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+                 observation: ArrayLike, R: ArrayLike, action: ArrayLike):
+        return self.q1_apply_fn(
+            {'params': params['q1']},
+            observation=self.problem.unwrap_angles(observation),
+            R=R,
+            action=action)
 
-    def q2_apply(self, params, observation, action):
-        return self.q2_apply_fn({'params': params['q2']},
-                                physics.unwrap_angles(observation), action)
+    def q2_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+                 observation: ArrayLike, R: ArrayLike, action: ArrayLike):
+        return self.q2_apply_fn(
+            {'params': params['q2']},
+            observation=self.problem.unwrap_angles(observation),
+            R=R,
+            action=action)
 
-    def pi_apply_gradients(self, step, grads):
+    def pi_apply_gradients(self, step: int, grads):
         updates, new_pi_opt_state = self.pi_tx.update(grads, self.pi_opt_state,
                                                       self.params)
         new_params = optax.apply_updates(self.params, updates)
@@ -249,7 +261,7 @@
             pi_opt_state=new_pi_opt_state,
         )
 
-    def q_apply_gradients(self, step, grads):
+    def q_apply_gradients(self, step: int, grads):
         updates, new_q_opt_state = self.q_tx.update(grads, self.q_opt_state,
                                                     self.params)
         new_params = optax.apply_updates(self.params, updates)
@@ -291,9 +303,8 @@
         )
 
     @classmethod
-    def create(cls, *, physics_constants: jax_dynamics.CoefficientsType,
-               params, pi_tx, q_tx, alpha_tx, pi_apply_fn, q1_apply_fn,
-               q2_apply_fn, **kwargs):
+    def create(cls, *, problem: Problem, params, pi_tx, q_tx, alpha_tx,
+               pi_apply_fn, q1_apply_fn, q2_apply_fn, **kwargs):
         """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
 
         pi_tx = optax.multi_transform(
@@ -343,7 +354,7 @@
             length=FLAGS.replay_size)
 
         result = cls(
-            physics_constants=physics_constants,
+            problem=problem,
             step=0,
             substep=0,
             params=params,
@@ -357,7 +368,7 @@
             q_opt_state=q_opt_state,
             alpha_tx=alpha_tx,
             alpha_opt_state=alpha_opt_state,
-            target_entropy=-physics.NUM_STATES,
+            target_entropy=-problem.num_states,
             mesh=mesh,
             sharding=sharding,
             replicated_sharding=replicated_sharding,
@@ -367,10 +378,12 @@
         return jax.device_put(result, replicated_sharding)
 
 
-def create_train_state(rng, physics_constants: jax_dynamics.CoefficientsType,
-                       q_learning_rate, pi_learning_rate):
+def create_train_state(rng: PRNGKey, problem: Problem, q_learning_rate,
+                       pi_learning_rate):
     """Creates initial `TrainState`."""
-    pi = SquashedGaussianMLPActor(action_space=physics.NUM_OUTPUTS)
+    pi = SquashedGaussianMLPActor(activation=nn.activation.gelu,
+                                  action_space=problem.num_outputs,
+                                  action_limit=problem.action_limit)
     # We want q1 and q2 to have different network architectures so they pick up different things.
     q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
     q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
@@ -381,17 +394,20 @@
 
         pi_params = pi.init(
             pi_rng,
-            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+            observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            R=jax.numpy.ones([problem.num_goals]),
         )['params']
         q1_params = q1.init(
             q1_rng,
-            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
-            jax.numpy.ones([physics.NUM_OUTPUTS]),
+            observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            R=jax.numpy.ones([problem.num_goals]),
+            action=jax.numpy.ones([problem.num_outputs]),
         )['params']
         q2_params = q2.init(
             q2_rng,
-            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
-            jax.numpy.ones([physics.NUM_OUTPUTS]),
+            observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            R=jax.numpy.ones([problem.num_goals]),
+            action=jax.numpy.ones([problem.num_outputs]),
         )['params']
 
         if FLAGS.alpha < 0.0:
@@ -411,7 +427,7 @@
     alpha_tx = optax.sgd(learning_rate=q_learning_rate)
 
     result = TrainState.create(
-        physics_constants=physics_constants,
+        problem=problem,
         params=init_params(rng),
         pi_tx=pi_tx,
         q_tx=q_tx,
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index 9454f14..d6515fe 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -1,69 +1,122 @@
-import jax
+import jax, numpy
 from functools import partial
-from frc971.control_loops.swerve import dynamics
 from absl import logging
-from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve import dynamics, jax_dynamics
+from frc971.control_loops.python import controls
+from flax.typing import PRNGKey
 
 
-@partial(jax.jit, static_argnums=[0])
-def xdot_physics(physics_constants: jax_dynamics.CoefficientsType, X, U):
-    A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
-    B = jax.numpy.array([[0.], [56.08534375]])
+class Problem(object):
 
-    return A @ X + B @ U
+    def __init__(self, num_states: int, num_unwrapped_states: int,
+                 num_outputs: int, num_goals: int, action_limit: float):
+        self.num_states = num_states
+        self.num_unwrapped_states = num_unwrapped_states
+        self.num_outputs = num_outputs
+        self.num_goals = num_goals
+        self.action_limit = action_limit
+        self.dt = 0.005
 
+    def integrate_dynamics(self, X: jax.typing.ArrayLike,
+                           U: jax.typing.ArrayLike):
+        m = 2  # RK4 steps per interval
+        dt = self.dt / m
 
-def unwrap_angles(X):
-    return X
+        def iteration(i, X):
+            weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0],
+                                       [1.0, 2.0, 2.0, 1.0]])
 
+            def rk_iteration(i, val):
+                kx_previous, weighted_sum = val
+                kx = self.xdot(X + dt * weights[0, i] * kx_previous, U)
+                weighted_sum += dt * weights[1, i] * kx / 6.0
+                return (kx, weighted_sum)
 
-@jax.jit
-def swerve_cost(X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
-                goal: jax.typing.ArrayLike):
-
-    Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
-    R = jax.numpy.array([[0.00694444]])
-
-    return (X).T @ Q @ (X) + U.T @ R @ U
-
-
-@partial(jax.jit, static_argnums=[0])
-def integrate_dynamics(physics_constants: jax_dynamics.CoefficientsType, X, U):
-    m = 2  # RK4 steps per interval
-    dt = 0.005 / m
-
-    def iteration(i, X):
-        weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0], [1.0, 2.0, 2.0, 1.0]])
-
-        def rk_iteration(i, val):
-            kx_previous, weighted_sum = val
-            kx = xdot_physics(physics_constants,
-                              X + dt * weights[0, i] * kx_previous, U)
-            weighted_sum += dt * weights[1, i] * kx / 6.0
-            return (kx, weighted_sum)
+            return jax.lax.fori_loop(lower=0,
+                                     upper=4,
+                                     body_fun=rk_iteration,
+                                     init_val=(X, X))[1]
 
         return jax.lax.fori_loop(lower=0,
-                                 upper=4,
-                                 body_fun=rk_iteration,
-                                 init_val=(X, X))[1]
+                                 upper=m,
+                                 body_fun=iteration,
+                                 init_val=X)
 
-    return jax.lax.fori_loop(lower=0, upper=m, body_fun=iteration, init_val=X)
+    def unwrap_angles(self, X: jax.typing.ArrayLike):
+        return X
+
+    def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+        raise NotImplemented("xdot not implemented")
+
+    def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+             goal: jax.typing.ArrayLike):
+        raise NotImplemented("cost not implemented")
+
+    def random_states(self, rng: PRNGKey, dimensions=None):
+        raise NotImplemented("random_states not implemented")
+
+    def random_actions(self, rng: PRNGKey, dimensions=None):
+        """Produces a uniformly random action in the action space."""
+        return jax.random.uniform(
+            rng,
+            (dimensions or FLAGS.num_agents, self.num_outputs),
+            minval=-self.action_limit,
+            maxval=self.action_limit,
+        )
+
+    def random_goals(self, rng: PRNGKey, dimensions=None):
+        """Produces a random goal in the goal space."""
+        raise NotImplemented("random_states not implemented")
 
 
-state_cost = swerve_cost
-NUM_STATES = 2
-NUM_UNWRAPPED_STATES = 2
-NUM_OUTPUTS = 1
-ACTION_LIMIT = 10.0
+class TurretProblem(Problem):
 
+    def __init__(self):
+        super().__init__(num_states=2,
+                         num_unwrapped_states=2,
+                         num_outputs=1,
+                         num_goals=2,
+                         action_limit=30.0)
+        self.A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
+        self.B = numpy.matrix([[0.00065992], [0.25610763]])
 
-def random_states(rng, dimensions=None):
-    rng1, rng2 = jax.random.split(rng)
+        self.Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
+        self.R = numpy.matrix([[0.00694444]])
 
-    return jax.numpy.hstack(
-        (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
-                            minval=-1.0,
-                            maxval=1.0),
-         jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
-                            minval=-10.0,
-                            maxval=10.0)))
+        # Compute the optimal LQR cost + controller.
+        self.F, self.P = controls.dlqr(self.A,
+                                       self.B,
+                                       self.Q,
+                                       self.R,
+                                       optimal_cost_function=True)
+
+    def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+        A_continuous = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+        B_continuous = jax.numpy.array([[0.], [56.08534375]])
+
+        return A_continuous @ X + B_continuous @ U
+
+    def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+             goal: jax.typing.ArrayLike):
+        return (X - goal).T @ jax.numpy.array(
+            self.Q) @ (X - goal) + U.T @ jax.numpy.array(self.R) @ U
+
+    def random_states(self, rng: PRNGKey, dimensions=None):
+        rng1, rng2 = jax.random.split(rng)
+
+        return jax.numpy.hstack(
+            (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
+                                minval=-1.0,
+                                maxval=1.0),
+             jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
+                                minval=-10.0,
+                                maxval=10.0)))
+
+    def random_goals(self, rng: PRNGKey, dimensions=None):
+        """Produces a random goal in the goal space."""
+        return jax.numpy.hstack((
+            jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+                               minval=-0.1,
+                               maxval=0.1),
+            jax.numpy.zeros((dimensions or FLAGS.num_agents, 1)),
+        ))
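
As a usage note on the class above: TurretProblem bundles the discrete-time
model (A, B, Q, R), the LQR solution (F, P) and the RK4 integrator that
lqr_plot.py and train.py now call through the Problem interface.  A minimal
sketch, with illustrative state and goal values:

    import jax
    from frc971.control_loops.swerve.velocity_controller import physics

    problem = physics.TurretProblem()

    X = jax.numpy.array([1.0, 0.0])
    U = jax.numpy.array([0.0])
    goal = jax.numpy.array([0.0, 0.0])

    # Step the continuous dynamics forward one problem.dt using RK4.
    X_next = problem.integrate_dynamics(X, U)

    # Quadratic cost about the goal; the SAC reward is the negative of this.
    cost = problem.cost(X_next, U, goal)

    # The LQR gain computed in the constructor, applied to the goal error.
    U_lqr = problem.F @ (goal - X_next)
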
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 5002eee..f5ba092 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -78,16 +78,6 @@
 )
 
 
-def random_actions(rng, dimensions=None):
-    """Produces a uniformly random action in the action space."""
-    return jax.random.uniform(
-        rng,
-        (dimensions or FLAGS.num_agents, physics.NUM_OUTPUTS),
-        minval=-physics.ACTION_LIMIT,
-        maxval=physics.ACTION_LIMIT,
-    )
-
-
 def save_checkpoint(state: TrainState, workdir: str):
     """Saves a checkpoint in the workdir."""
     # TODO(austin): use orbax directly.
@@ -131,15 +121,23 @@
     actions = data['actions']
     rewards = data['rewards']
     observations2 = data['observations2']
+    R = data['goals']
 
     # Compute the ending actions from the current network.
     action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
                                              params=params,
-                                             observation=observations2)
+                                             observation=observations2,
+                                             R=R)
 
     # Compute target network Q values
-    q1_pi_target = state.q1_apply(state.target_params, observations2, action2)
-    q2_pi_target = state.q2_apply(state.target_params, observations2, action2)
+    q1_pi_target = state.q1_apply(state.target_params,
+                                  observation=observations2,
+                                  R=R,
+                                  action=action2)
+    q2_pi_target = state.q2_apply(state.target_params,
+                                  observation=observations2,
+                                  R=R,
+                                  action=action2)
     q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
 
     alpha = jax.numpy.exp(params['logalpha'])
@@ -149,8 +147,8 @@
                                            (q_pi_target - alpha * logp_pi2))
 
     # Compute the starting Q values from the Q network being optimized.
-    q1 = state.q1_apply(params, observations1, actions)
-    q2 = state.q2_apply(params, observations1, actions)
+    q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)
+    q2 = state.q2_apply(params, observation=observations1, R=R, action=actions)
 
     # Mean squared error loss against Bellman backup
     q1_loss = ((q1 - bellman_backup)**2).mean()
@@ -175,14 +173,22 @@
 def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
     """Computes the Soft Actor-Critic loss for pi."""
     observations1 = data['observations1']
+    R = data['goals']
     # TODO(austin): We've got differentiable policy and differentiable physics.  Can we use those here?  Have Q learn the future, not the current step?
 
     # Compute the action
     pi, logp_pi, _, _ = state.pi_apply(rng=rng,
                                        params=params,
-                                       observation=observations1)
-    q1_pi = state.q1_apply(jax.lax.stop_gradient(params), observations1, pi)
-    q2_pi = state.q2_apply(jax.lax.stop_gradient(params), observations1, pi)
+                                       observation=observations1,
+                                       R=R)
+    q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
+                           observation=observations1,
+                           R=R,
+                           action=pi)
+    q2_pi = state.q2_apply(jax.lax.stop_gradient(params),
+                           observation=observations1,
+                           R=R,
+                           action=pi)
 
     # And compute the Q of that action.
     q_pi = jax.numpy.minimum(q1_pi, q2_pi)
@@ -277,16 +283,19 @@
 
 
 @jax.jit
-def collect_experience(state: TrainState, replay_buffer_state, R: ArrayLike,
+def collect_experience(state: TrainState, replay_buffer_state,
                        step_rng: PRNGKey, step):
     """Collects experience by simulating."""
     pi_rng = jax.random.fold_in(step_rng, step)
     pi_rng, initialization_rng = jax.random.split(pi_rng)
+    pi_rng, goal_rng = jax.random.split(pi_rng)
 
     observation = jax.lax.with_sharding_constraint(
-        physics.random_states(initialization_rng, FLAGS.num_agents),
+        state.problem.random_states(initialization_rng, FLAGS.num_agents),
         state.sharding)
 
+    R = state.problem.random_goals(goal_rng, FLAGS.num_agents)
+
     def loop(i, val):
         """Runs 1 step of our simulation."""
         observation, pi_rng, replay_buffer_state = val
@@ -295,7 +304,7 @@
 
         def true_fn(i):
             # We are at the beginning of the process, pick a random action.
-            return random_actions(action_rng, FLAGS.num_agents)
+            return state.problem.random_actions(action_rng, FLAGS.num_agents)
 
         def false_fn(i):
             # We are past the beginning of the process, use the trained network.
@@ -303,6 +312,7 @@
                 rng=action_rng,
                 params=state.params,
                 observation=observation,
+                R=R,
                 deterministic=False)
             return pi_action
 
@@ -314,26 +324,23 @@
         )
 
         # Compute the destination state.
-        observation2 = jax.vmap(lambda o, pi: physics.integrate_dynamics(
-            state.physics_constants, o, pi),
-                                in_axes=(0, 0))(observation, pi_action)
+        observation2 = jax.vmap(
+            lambda o, pi: state.problem.integrate_dynamics(o, pi),
+            in_axes=(0, 0))(observation, pi_action)
 
         # Soft Actor-Critic is designed to maximize reward.  LQR minimizes
         # cost.  There is nothing which assumes anything about the sign of
         # the reward, so use the negative of the cost.
-        reward = (-jax.vmap(state_cost)(observation2, pi_action, R)).reshape(
-            (FLAGS.num_agents, 1))
+        reward = -jax.vmap(state.problem.cost)(
+            X=observation2, U=pi_action, goal=R)
 
-        logging.info('Observation shape: %s', observation.shape)
-        logging.info('Observation2 shape: %s', observation2.shape)
-        logging.info('Action shape: %s', pi_action.shape)
-        logging.info('Action shape: %s', pi_action.shape)
         replay_buffer_state = state.replay_buffer.add(
             replay_buffer_state, {
                 'observations1': observation,
                 'observations2': observation2,
                 'actions': pi_action,
-                'rewards': reward,
+                'rewards': reward.reshape((FLAGS.num_agents, 1)),
+                'goals': R,
             })
 
         return observation2, pi_rng, replay_buffer_state
@@ -379,9 +386,7 @@
     return rng, state, q_loss, pi_loss, alpha_loss
 
 
-def train(
-    workdir: str, physics_constants: jax_dynamics.CoefficientsType
-) -> train_state.TrainState:
+def train(workdir: str, problem: Problem) -> train_state.TrainState:
     """Trains a Soft Actor-Critic controller."""
     rng = jax.random.key(0)
     rng, r_rng = jax.random.split(rng)
@@ -394,6 +399,14 @@
         'batch_size': FLAGS.batch_size,
         'horizon': FLAGS.horizon,
         'warmup_steps': FLAGS.warmup_steps,
+        'steps': FLAGS.steps,
+        'replay_size': FLAGS.replay_size,
+        'num_agents': FLAGS.num_agents,
+        'polyak': FLAGS.polyak,
+        'gamma': FLAGS.gamma,
+        'alpha': FLAGS.alpha,
+        'final_q_learning_rate': FLAGS.final_q_learning_rate,
+        'final_pi_learning_rate': FLAGS.final_pi_learning_rate,
     }
 
     # Setup TrainState
@@ -406,7 +419,7 @@
         final_learning_rate=FLAGS.final_pi_learning_rate)
     state = create_train_state(
         init_rng,
-        physics_constants,
+        problem,
         q_learning_rate=q_learning_rate,
         pi_learning_rate=pi_learning_rate,
     )
@@ -415,42 +428,23 @@
     state_sharding = nn.get_sharding(state, state.mesh)
     logging.info(state_sharding)
 
-    # TODO(austin): Let the goal change.
-    R = jax.numpy.hstack((
-        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=0.9, maxval=1.1),
-        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=-0.1,
-                           maxval=0.1),
-        jax.numpy.zeros((FLAGS.num_agents, 1)),
-    ))
-
     replay_buffer_state = state.replay_buffer.init({
         'observations1':
-        jax.numpy.zeros((physics.NUM_STATES, )),
+        jax.numpy.zeros((problem.num_states, )),
         'observations2':
-        jax.numpy.zeros((physics.NUM_STATES, )),
+        jax.numpy.zeros((problem.num_states, )),
         'actions':
-        jax.numpy.zeros((physics.NUM_OUTPUTS, )),
+        jax.numpy.zeros((problem.num_outputs, )),
         'rewards':
         jax.numpy.zeros((1, )),
+        'goals':
+        jax.numpy.zeros((problem.num_states, )),
     })
 
     replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
                                                    state.mesh)
     logging.info(replay_buffer_state_sharding)
 
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['observations1'].shape,
-                 replay_buffer_state.experience['observations1'].sharding)
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['observations2'].shape,
-                 replay_buffer_state.experience['observations2'].sharding)
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['actions'].shape,
-                 replay_buffer_state.experience['actions'].sharding)
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['rewards'].shape,
-                 replay_buffer_state.experience['rewards'].sharding)
-
     # Number of gradients to accumulate before doing descent.
     update_after = FLAGS.batch_size // FLAGS.num_agents
 
@@ -462,7 +456,6 @@
         state, replay_buffer_state = collect_experience(
             state,
             replay_buffer_state,
-            R,
             step_rng,
             step,
         )
diff --git a/tools/python/requirements.lock.txt b/tools/python/requirements.lock.txt
index 4fb225e..d465241 100644
--- a/tools/python/requirements.lock.txt
+++ b/tools/python/requirements.lock.txt
@@ -20,6 +20,7 @@
     #   tensorflow
     #   tensorflow-datasets
     #   tensorflow-metadata
+    #   tensorflow-probability
 aim==3.24.0 \
     --hash=sha256:1266fac2453f2d356d2e39c98d15b46ba7a71ecab14cdaa6eaa6d390f5add898 \
     --hash=sha256:1322659dcecb701c264f5067375c74811faf3514bb316e1ca1e53b9de6b62766 \
@@ -400,6 +401,10 @@
     #   mkdocs
     #   tensorflow-datasets
     #   uvicorn
+cloudpickle==3.1.0 \
+    --hash=sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b \
+    --hash=sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e
+    # via tensorflow-probability
 clu==0.0.12 \
     --hash=sha256:0d183e7d25f7dd0700444510a264e24700e2f068bdabd199ed22866f7e54edba \
     --hash=sha256:f71eaa1afbd30f57f7709257ba7e1feb8ad5c1c3dcae3606672a138678bb3ce4
@@ -510,6 +515,10 @@
     --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \
     --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c
     # via matplotlib
+decorator==5.1.1 \
+    --hash=sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330 \
+    --hash=sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186
+    # via tensorflow-probability
 dm-tree==0.1.8 \
     --hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \
     --hash=sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760 \
@@ -557,7 +566,9 @@
     --hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \
     --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \
     --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d
-    # via tensorflow-datasets
+    # via
+    #   tensorflow-datasets
+    #   tensorflow-probability
 etils[enp,epath,epy,etree]==1.5.2 \
     --hash=sha256:6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b \
     --hash=sha256:ba6a3e1aff95c769130776aa176c11540637f5dd881f3b79172a5149b6b1c446
@@ -649,7 +660,9 @@
 gast==0.6.0 \
     --hash=sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54 \
     --hash=sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb
-    # via tensorflow
+    # via
+    #   tensorflow
+    #   tensorflow-probability
 ghp-import==2.1.0 \
     --hash=sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619 \
     --hash=sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343
@@ -1335,6 +1348,7 @@
     #   tensorboard
     #   tensorflow
     #   tensorflow-datasets
+    #   tensorflow-probability
     #   tensorstore
 nvidia-cublas-cu12==12.6.1.4 \
     --hash=sha256:5dd125ece5469dbdceebe2e9536ad8fc4abd38aa394a7ace42fc8a930a1e81e3 \
@@ -1534,6 +1548,7 @@
     #   keras
     #   matplotlib
     #   mkdocs
+    #   plotly
     #   tensorboard
     #   tensorflow
 pandas==2.2.2 \
@@ -1666,6 +1681,10 @@
     # via
     #   mkdocs-get-deps
     #   yapf
+plotly==5.24.1 \
+    --hash=sha256:dbc8ac8339d248a4bcc36e08a5659bacfe1b079390b8953533f4eb22169b4bae \
+    --hash=sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089
+    # via -r tools/python/requirements.txt
 promise==2.3 \
     --hash=sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0
     # via tensorflow-datasets
@@ -2079,6 +2098,7 @@
     #   python-dateutil
     #   tensorboard
     #   tensorflow
+    #   tensorflow-probability
 sniffio==1.3.1 \
     --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \
     --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc
@@ -2148,6 +2168,10 @@
     --hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
     --hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
     # via -r tools/python/requirements.txt
+tenacity==9.0.0 \
+    --hash=sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b \
+    --hash=sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539
+    # via plotly
 tensorboard==2.17.1 \
     --hash=sha256:253701a224000eeca01eee6f7e978aea7b408f60b91eb0babdb04e78947b773e
     # via tensorflow
@@ -2173,7 +2197,9 @@
     --hash=sha256:e8d26d6c24ccfb139db1306599257ca8f5cfe254ef2d023bfb667f374a17a64d \
     --hash=sha256:ee18b4fcd627c5e872eabb25092af6c808b6ec77948662c88fc5c89a60eb0211 \
     --hash=sha256:ef615c133cf4d592a073feda634ccbeb521a554be57de74f8c318d38febbeab5
-    # via -r tools/python/requirements.txt
+    # via
+    #   -r tools/python/requirements.txt
+    #   tf-keras
 tensorflow-datasets==4.9.3 \
     --hash=sha256:09cd60eccab0d5a9d15f53e76ee0f1b530ee5aa3665e42be621a4810d9fa5db6 \
     --hash=sha256:90390077dde2c9e4e240754ddfc5bb50b482946d421c8a34677c3afdb0463427
@@ -2199,6 +2225,9 @@
 tensorflow-metadata==1.15.0 \
     --hash=sha256:cb84d8e159128aeae7b3f6013ccd7969c69d2e6d1a7b255dbfa6f5344d962986
     # via tensorflow-datasets
+tensorflow-probability==0.24.0 \
+    --hash=sha256:8c1774683e38359dbcaf3697e79b7e6a4e69b9c7b3679e78ee18f43e59e5759b
+    # via -r tools/python/requirements.txt
 tensorstore==0.1.64 \
     --hash=sha256:1a78aedbddccc09ea283b145496da03dbc7eb8693ae4e01074ed791d72b7eac2 \
     --hash=sha256:24a4cebaf9d0e75d494342948f68edc971d6bb90e23192ddf8d98397fb1ff3cb \
@@ -2231,6 +2260,10 @@
     # via
     #   tensorflow
     #   tensorflow-datasets
+tf-keras==2.17.0 \
+    --hash=sha256:cc97717e4dc08487f327b0740a984043a9e0123c7a4e21206711669d3ec41c88 \
+    --hash=sha256:fda97c18da30da0f72a5a7e80f3eee343b09f4c206dad6c57c944fb2cd18560e
+    # via -r tools/python/requirements.txt
 toml==0.10.2 \
     --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \
     --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f
diff --git a/tools/python/requirements.txt b/tools/python/requirements.txt
index 241425a..342d8ab 100644
--- a/tools/python/requirements.txt
+++ b/tools/python/requirements.txt
@@ -41,9 +41,12 @@
 
 tensorflow
 tensorflow_datasets
+tensorflow_probability
+tf_keras
 
 # Experience buffer for reinforcement learning
 flashbax
 
 # Experiment tracking
 aim
+plotly
diff --git a/tools/python/whl_overrides.json b/tools/python/whl_overrides.json
index 249b856..c27c8ac 100644
--- a/tools/python/whl_overrides.json
+++ b/tools/python/whl_overrides.json
@@ -91,6 +91,10 @@
         "sha256": "ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/click-8.1.7-py3-none-any.whl"
     },
+    "cloudpickle==3.1.0": {
+        "sha256": "fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cloudpickle-3.1.0-py3-none-any.whl"
+    },
     "clu==0.0.12": {
         "sha256": "0d183e7d25f7dd0700444510a264e24700e2f068bdabd199ed22866f7e54edba",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/clu-0.0.12-py3-none-any.whl"
@@ -111,6 +115,10 @@
         "sha256": "85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cycler-0.12.1-py3-none-any.whl"
     },
+    "decorator==5.1.1": {
+        "sha256": "b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/decorator-5.1.1-py3-none-any.whl"
+    },
     "dm_tree==0.1.8": {
         "sha256": "181c35521d480d0365f39300542cb6cd7fd2b77351bb43d7acfda15aef63b317",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/dm_tree-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -399,6 +407,10 @@
         "sha256": "2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/platformdirs-4.2.2-py3-none-any.whl"
     },
+    "plotly==5.24.1": {
+        "sha256": "f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/plotly-5.24.1-py3-none-any.whl"
+    },
     "promise==2.3": {
         "sha256": "d10acd69e1aed4de5840e3915edf51c877dfc7c7ae440fd081019edbf62820a4",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/promise-2.3-py3-none-any.whl"
@@ -515,6 +527,10 @@
         "sha256": "024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tabulate-0.9.0-py3-none-any.whl"
     },
+    "tenacity==9.0.0": {
+        "sha256": "93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tenacity-9.0.0-py3-none-any.whl"
+    },
     "tensorboard==2.17.1": {
         "sha256": "253701a224000eeca01eee6f7e978aea7b408f60b91eb0babdb04e78947b773e",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorboard-2.17.1-py3-none-any.whl"
@@ -539,6 +555,10 @@
         "sha256": "cb84d8e159128aeae7b3f6013ccd7969c69d2e6d1a7b255dbfa6f5344d962986",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorflow_metadata-1.15.0-py3-none-any.whl"
     },
+    "tensorflow_probability==0.24.0": {
+        "sha256": "8c1774683e38359dbcaf3697e79b7e6a4e69b9c7b3679e78ee18f43e59e5759b",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorflow_probability-0.24.0-py2.py3-none-any.whl"
+    },
     "tensorstore==0.1.64": {
         "sha256": "72f76231ce12bfd266358a096e9c6000a2d86c1f4f24c3891c29b2edfffc5df4",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorstore-0.1.64-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -547,6 +567,10 @@
         "sha256": "9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/termcolor-2.4.0-py3-none-any.whl"
     },
+    "tf_keras==2.17.0": {
+        "sha256": "cc97717e4dc08487f327b0740a984043a9e0123c7a4e21206711669d3ec41c88",
+        "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tf_keras-2.17.0-py3-none-any.whl"
+    },
     "toml==0.10.2": {
         "sha256": "806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
         "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/toml-0.10.2-py2.py3-none-any.whl"