Split Soft Actor-Critic problem out into a class

This sets us up to optimize multiple, different problems with the same
training code.
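
As a rough illustration (not part of this change; the class name,
matrices, and numbers below are made up), a new problem is now just
another subclass of physics.Problem that overrides xdot(), cost(), and
random_states(), and gets handed to train.train() the same way main.py
hands it TurretProblem:

    from functools import partial

    import jax
    from absl import flags

    from frc971.control_loops.swerve.velocity_controller import physics

    FLAGS = flags.FLAGS


    class DoubleIntegratorProblem(physics.Problem):
        """Hypothetical frictionless 1D position/velocity problem."""

        def __init__(self):
            super().__init__(num_states=2,
                             num_unwrapped_states=2,
                             num_outputs=1,
                             num_goals=2,
                             action_limit=10.0)

        @partial(jax.jit, static_argnums=[0])
        def xdot(self, X, U):
            # Xdot = [velocity, commanded acceleration]
            A = jax.numpy.array([[0., 1.], [0., 0.]])
            B = jax.numpy.array([[0.], [1.]])
            return A @ X + B @ U

        @partial(jax.jit, static_argnums=[0])
        def cost(self, X, U, goal):
            Q = jax.numpy.eye(2)
            R = 0.01 * jax.numpy.eye(1)
            return X.T @ Q @ X + U.T @ R @ U

        def random_states(self, rng, dimensions=None):
            # Uniformly sample initial position and velocity.
            return jax.random.uniform(rng,
                                      (dimensions or FLAGS.num_agents, 2),
                                      minval=-1.0,
                                      maxval=1.0)

    # Wired up the same way main.py wires up TurretProblem:
    #   state = train.train(FLAGS.workdir, DoubleIntegratorProblem())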

Change-Id: I6756d4ba5af411eba99825447611ead9de3c5dad
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 30f5630..f8e9137 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -94,13 +94,13 @@
     grid_X = jax.numpy.array(grid_X)
     grid_Y = jax.numpy.array(grid_Y)
     # Load the training state.
-    physics_constants = jax_dynamics.Coefficients()
+    problem = physics.TurretProblem()
 
     rng = jax.random.key(0)
     rng, init_rng = jax.random.split(rng)
 
-    state = create_train_state(init_rng, physics_constants,
-                               FLAGS.q_learning_rate, FLAGS.pi_learning_rate)
+    state = create_train_state(init_rng, problem, FLAGS.q_learning_rate,
+                               FLAGS.pi_learning_rate)
 
     state = restore_checkpoint(state, FLAGS.workdir)
     if step is not None and state.step == step:
@@ -121,15 +121,15 @@
     def compute_q1(X, Y):
         return state.q1_apply(
             state.params,
-            unwrap_angles(jax.numpy.array([X, Y])),
-            jax.numpy.array([0.]),
+            observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+            action=jax.numpy.array([0.]),
         )[0]
 
     def compute_q2(X, Y):
         return state.q2_apply(
             state.params,
-            unwrap_angles(jax.numpy.array([X, Y])),
-            jax.numpy.array([0.]),
+            observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+            action=jax.numpy.array([0.]),
         )[0]
 
     def lqr_cost(X, Y):
@@ -137,18 +137,18 @@
         return -x.T @ jax.numpy.array(P) @ x
 
     def compute_q(params, x, y):
-        X = unwrap_angles(jax.numpy.array([x, y]))
+        X = state.problem.unwrap_angles(jax.numpy.array([x, y]))
 
         return jax.numpy.minimum(
             state.q1_apply(
                 params,
-                X,
-                jax.numpy.array([0.]),
+                observation=X,
+                action=jax.numpy.array([0.]),
             )[0],
             state.q2_apply(
                 params,
-                X,
-                jax.numpy.array([0.]),
+                observation=X,
+                action=jax.numpy.array([0.]),
             )[0])
 
     cost_grid1 = jax.vmap(jax.vmap(compute_q1))(grid_X, grid_Y)
@@ -173,7 +173,7 @@
         x = jax.numpy.array([X, Y])
         U, _, _, _ = state.pi_apply(rng,
                                     state.params,
-                                    observation=unwrap_angles(x),
+                                    observation=state.problem.unwrap_angles(x),
                                     deterministic=True)
         return U[0]
 
@@ -189,12 +189,17 @@
 
         U, _, _, _ = state.pi_apply(rng,
                                     params,
-                                    observation=unwrap_angles(X),
+                                    observation=state.problem.unwrap_angles(X),
                                     deterministic=True)
         U_lqr = F @ (goal - X_lqr)
 
-        cost = jax.numpy.minimum(state.q1_apply(params, unwrap_angles(X), U),
-                                 state.q2_apply(params, unwrap_angles(X), U))
+        cost = jax.numpy.minimum(
+            state.q1_apply(params,
+                           observation=state.problem.unwrap_angles(X),
+                           action=U),
+            state.q2_apply(params,
+                           observation=state.problem.unwrap_angles(X),
+                           action=U))
 
         U_plot = data.U.at[i, :].set(U)
         U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
@@ -206,8 +211,8 @@
         X = A @ X + B @ U
         X_lqr = A @ X_lqr + B @ U_lqr
 
-        reward = data.reward - physics.state_cost(X, U, goal)
-        reward_lqr = data.reward_lqr - physics.state_cost(X_lqr, U_lqr, goal)
+        reward = data.reward - state.problem.cost(X, U, goal)
+        reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
 
         return X, X_lqr, data._replace(
             t=t,
@@ -230,10 +235,10 @@
     X, X_lqr, data, params = integrate(
         Data(
             t=jax.numpy.zeros((FLAGS.horizon, )),
-            X=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
-            X_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
-            U=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
-            U_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+            X=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+            X_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+            U=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
+            U_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
             cost=jax.numpy.zeros((FLAGS.horizon, 1)),
             cost_lqr=jax.numpy.zeros((FLAGS.horizon, 1)),
             q1_grid=cost_grid1,
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
index 7ef687c..a0d248d 100644
--- a/frc971/control_loops/swerve/velocity_controller/main.py
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -24,6 +24,7 @@
 import jax
 import tensorflow as tf
 from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve.velocity_controller import physics
 
 jax._src.deprecations.accelerate('tracer-hash')
 # Enable the compilation cache
@@ -63,8 +64,8 @@
         FLAGS.workdir,
     )
 
-    physics_constants = jax_dynamics.Coefficients()
-    state = train.train(FLAGS.workdir, physics_constants)
+    problem = physics.TurretProblem()
+    state = train.train(FLAGS.workdir, problem)
 
 
 if __name__ == '__main__':
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index ae67efe..cb6d9ba 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -36,8 +36,8 @@
 
 absl.flags.DEFINE_float(
     'final_q_learning_rate',
-    default=0.00002,
-    help='Training learning rate.',
+    default=0.001,
+    help='Fraction of --q_learning_rate to reduce to by the end of training.',
 )
 
 absl.flags.DEFINE_float(
@@ -48,8 +48,8 @@
 
 absl.flags.DEFINE_float(
     'final_pi_learning_rate',
-    default=0.00002,
-    help='Training learning rate.',
+    default=0.001,
+    help='Fraction of --pi_learning_rate to reduce to by the end of training.',
 )
 
 absl.flags.DEFINE_integer(
@@ -92,14 +92,14 @@
     activation: Callable = nn.activation.relu
 
     # Max action we can apply
-    action_limit: float = 30.0
+    action_limit: float = 1.0
 
     @nn.compact
     def __call__(self,
-                 observations: ArrayLike,
+                 observation: ArrayLike,
                  deterministic: bool = False,
                  rng: PRNGKey | None = None):
-        x = observations
+        x = observation
         # Apply the dense layers
         for i, hidden_size in enumerate(self.hidden_sizes):
             x = nn.Dense(
@@ -168,9 +168,9 @@
     activation: Callable = nn.activation.tanh
 
     @nn.compact
-    def __call__(self, observations, actions):
+    def __call__(self, observation: ArrayLike, action: ArrayLike):
         # Estimate Q with a simple multi layer dense network.
-        x = jax.numpy.hstack((observations, actions))
+        x = jax.numpy.hstack((observation, action))
         for i, hidden_size in enumerate(self.hidden_sizes):
             x = nn.Dense(
                 name=f'denselayer{i}',
@@ -185,8 +185,7 @@
 
 
 class TrainState(flax.struct.PyTreeNode):
-    physics_constants: jax_dynamics.CoefficientsType = flax.struct.field(
-        pytree_node=False)
+    problem: Problem = flax.struct.field(pytree_node=False)
 
     step: int | jax.Array = flax.struct.field(pytree_node=True)
     substep: int | jax.Array = flax.struct.field(pytree_node=True)
@@ -221,23 +220,30 @@
 
     def pi_apply(self,
                  rng: PRNGKey,
-                 params,
-                 observation,
+                 params: flax.core.FrozenDict[str, typing.Any],
+                 observation: ArrayLike,
                  deterministic: bool = False):
-        return self.pi_apply_fn({'params': params['pi']},
-                                physics.unwrap_angles(observation),
-                                deterministic=deterministic,
-                                rngs={'pi': rng})
+        return self.pi_apply_fn(
+            {'params': params['pi']},
+            observation=self.problem.unwrap_angles(observation),
+            deterministic=deterministic,
+            rngs={'pi': rng})
 
-    def q1_apply(self, params, observation, action):
-        return self.q1_apply_fn({'params': params['q1']},
-                                physics.unwrap_angles(observation), action)
+    def q1_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+                 observation: ArrayLike, action: ArrayLike):
+        return self.q1_apply_fn(
+            {'params': params['q1']},
+            observation=self.problem.unwrap_angles(observation),
+            action=action)
 
-    def q2_apply(self, params, observation, action):
-        return self.q2_apply_fn({'params': params['q2']},
-                                physics.unwrap_angles(observation), action)
+    def q2_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+                 observation: ArrayLike, action: ArrayLike):
+        return self.q2_apply_fn(
+            {'params': params['q2']},
+            observation=self.problem.unwrap_angles(observation),
+            action=action)
 
-    def pi_apply_gradients(self, step, grads):
+    def pi_apply_gradients(self, step: int, grads):
         updates, new_pi_opt_state = self.pi_tx.update(grads, self.pi_opt_state,
                                                       self.params)
         new_params = optax.apply_updates(self.params, updates)
@@ -249,7 +255,7 @@
             pi_opt_state=new_pi_opt_state,
         )
 
-    def q_apply_gradients(self, step, grads):
+    def q_apply_gradients(self, step: int, grads):
         updates, new_q_opt_state = self.q_tx.update(grads, self.q_opt_state,
                                                     self.params)
         new_params = optax.apply_updates(self.params, updates)
@@ -291,9 +297,8 @@
         )
 
     @classmethod
-    def create(cls, *, physics_constants: jax_dynamics.CoefficientsType,
-               params, pi_tx, q_tx, alpha_tx, pi_apply_fn, q1_apply_fn,
-               q2_apply_fn, **kwargs):
+    def create(cls, *, problem: Problem, params, pi_tx, q_tx, alpha_tx,
+               pi_apply_fn, q1_apply_fn, q2_apply_fn, **kwargs):
         """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
 
         pi_tx = optax.multi_transform(
@@ -343,7 +348,7 @@
             length=FLAGS.replay_size)
 
         result = cls(
-            physics_constants=physics_constants,
+            problem=problem,
             step=0,
             substep=0,
             params=params,
@@ -357,7 +362,7 @@
             q_opt_state=q_opt_state,
             alpha_tx=alpha_tx,
             alpha_opt_state=alpha_opt_state,
-            target_entropy=-physics.NUM_STATES,
+            target_entropy=-problem.num_states,
             mesh=mesh,
             sharding=sharding,
             replicated_sharding=replicated_sharding,
@@ -367,10 +372,12 @@
         return jax.device_put(result, replicated_sharding)
 
 
-def create_train_state(rng, physics_constants: jax_dynamics.CoefficientsType,
-                       q_learning_rate, pi_learning_rate):
+def create_train_state(rng: PRNGKey, problem: Problem, q_learning_rate,
+                       pi_learning_rate):
     """Creates initial `TrainState`."""
-    pi = SquashedGaussianMLPActor(action_space=physics.NUM_OUTPUTS)
+    pi = SquashedGaussianMLPActor(activation=nn.activation.gelu,
+                                  action_space=problem.num_outputs,
+                                  action_limit=problem.action_limit)
     # We want q1 and q2 to have different network architectures so they pick up differnet things.
     q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
     q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
@@ -381,17 +388,17 @@
 
         pi_params = pi.init(
             pi_rng,
-            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+            observation=jax.numpy.ones([problem.num_unwrapped_states]),
         )['params']
         q1_params = q1.init(
             q1_rng,
-            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
-            jax.numpy.ones([physics.NUM_OUTPUTS]),
+            observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            action=jax.numpy.ones([problem.num_outputs]),
         )['params']
         q2_params = q2.init(
             q2_rng,
-            jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
-            jax.numpy.ones([physics.NUM_OUTPUTS]),
+            observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            action=jax.numpy.ones([problem.num_outputs]),
         )['params']
 
         if FLAGS.alpha < 0.0:
@@ -411,7 +418,7 @@
     alpha_tx = optax.sgd(learning_rate=q_learning_rate)
 
     result = TrainState.create(
-        physics_constants=physics_constants,
+        problem=problem,
         params=init_params(rng),
         pi_tx=pi_tx,
         q_tx=q_tx,
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index 9454f14..d7bf7b4 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -3,67 +3,103 @@
 from frc971.control_loops.swerve import dynamics
 from absl import logging
 from frc971.control_loops.swerve import jax_dynamics
+from flax.typing import PRNGKey
 
 
-@partial(jax.jit, static_argnums=[0])
-def xdot_physics(physics_constants: jax_dynamics.CoefficientsType, X, U):
-    A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
-    B = jax.numpy.array([[0.], [56.08534375]])
+class Problem(object):
 
-    return A @ X + B @ U
+    def __init__(self, num_states: int, num_unwrapped_states: int,
+                 num_outputs: int, num_goals: int, action_limit: float):
+        self.num_states = num_states
+        self.num_unwrapped_states = num_unwrapped_states
+        self.num_outputs = num_outputs
+        self.num_goals = num_goals
+        self.action_limit = action_limit
 
+    def integrate_dynamics(self, X: jax.typing.ArrayLike,
+                           U: jax.typing.ArrayLike):
+        m = 2  # RK4 steps per interval
+        dt = 0.005 / m
 
-def unwrap_angles(X):
-    return X
+        def iteration(i, X):
+            weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0],
+                                       [1.0, 2.0, 2.0, 1.0]])
 
+            def rk_iteration(i, val):
+                kx_previous, weighted_sum = val
+                kx = self.xdot(X + dt * weights[0, i] * kx_previous, U)
+                weighted_sum += dt * weights[1, i] * kx / 6.0
+                return (kx, weighted_sum)
 
-@jax.jit
-def swerve_cost(X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
-                goal: jax.typing.ArrayLike):
-
-    Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
-    R = jax.numpy.array([[0.00694444]])
-
-    return (X).T @ Q @ (X) + U.T @ R @ U
-
-
-@partial(jax.jit, static_argnums=[0])
-def integrate_dynamics(physics_constants: jax_dynamics.CoefficientsType, X, U):
-    m = 2  # RK4 steps per interval
-    dt = 0.005 / m
-
-    def iteration(i, X):
-        weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0], [1.0, 2.0, 2.0, 1.0]])
-
-        def rk_iteration(i, val):
-            kx_previous, weighted_sum = val
-            kx = xdot_physics(physics_constants,
-                              X + dt * weights[0, i] * kx_previous, U)
-            weighted_sum += dt * weights[1, i] * kx / 6.0
-            return (kx, weighted_sum)
+            return jax.lax.fori_loop(lower=0,
+                                     upper=4,
+                                     body_fun=rk_iteration,
+                                     init_val=(X, X))[1]
 
         return jax.lax.fori_loop(lower=0,
-                                 upper=4,
-                                 body_fun=rk_iteration,
-                                 init_val=(X, X))[1]
+                                 upper=m,
+                                 body_fun=iteration,
+                                 init_val=X)
 
-    return jax.lax.fori_loop(lower=0, upper=m, body_fun=iteration, init_val=X)
+    def unwrap_angles(self, X: jax.typing.ArrayLike):
+        return X
+
+    @partial(jax.jit, static_argnums=[0])
+    def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+        raise NotImplemented("xdot not implemented")
+
+    @partial(jax.jit, static_argnums=[0])
+    def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+             goal: jax.typing.ArrayLike):
+        raise NotImplemented("cost not implemented")
+
+    @partial(jax.jit, static_argnums=[0])
+    def random_states(self, rng: PRNGKey, dimensions=None):
+        raise NotImplemented("random_states not implemented")
+
+    #@partial(jax.jit, static_argnums=[0])
+    def random_actions(self, rng: PRNGKey, dimensions=None):
+        """Produces a uniformly random action in the action space."""
+        return jax.random.uniform(
+            rng,
+            (dimensions or FLAGS.num_agents, self.num_outputs),
+            minval=-self.action_limit,
+            maxval=self.action_limit,
+        )
 
 
-state_cost = swerve_cost
-NUM_STATES = 2
-NUM_UNWRAPPED_STATES = 2
-NUM_OUTPUTS = 1
-ACTION_LIMIT = 10.0
+class TurretProblem(Problem):
 
+    def __init__(self):
+        super().__init__(num_states=2,
+                         num_unwrapped_states=2,
+                         num_outputs=1,
+                         num_goals=2,
+                         action_limit=30.0)
 
-def random_states(rng, dimensions=None):
-    rng1, rng2 = jax.random.split(rng)
+    @partial(jax.jit, static_argnums=[0])
+    def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+        A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+        B = jax.numpy.array([[0.], [56.08534375]])
 
-    return jax.numpy.hstack(
-        (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
-                            minval=-1.0,
-                            maxval=1.0),
-         jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
-                            minval=-10.0,
-                            maxval=10.0)))
+        return A @ X + B @ U
+
+    @partial(jax.jit, static_argnums=[0])
+    def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+             goal: jax.typing.ArrayLike):
+        Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
+        R = jax.numpy.array([[0.00694444]])
+
+        return X.T @ Q @ X + U.T @ R @ U
+
+    #@partial(jax.jit, static_argnums=[0])
+    def random_states(self, rng: PRNGKey, dimensions=None):
+        rng1, rng2 = jax.random.split(rng)
+
+        return jax.numpy.hstack(
+            (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
+                                minval=-1.0,
+                                maxval=1.0),
+             jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
+                                minval=-10.0,
+                                maxval=10.0)))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 5002eee..bbe2f5d 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -78,16 +78,6 @@
 )
 
 
-def random_actions(rng, dimensions=None):
-    """Produces a uniformly random action in the action space."""
-    return jax.random.uniform(
-        rng,
-        (dimensions or FLAGS.num_agents, physics.NUM_OUTPUTS),
-        minval=-physics.ACTION_LIMIT,
-        maxval=physics.ACTION_LIMIT,
-    )
-
-
 def save_checkpoint(state: TrainState, workdir: str):
     """Saves a checkpoint in the workdir."""
     # TODO(austin): use orbax directly.
@@ -138,8 +128,12 @@
                                              observation=observations2)
 
     # Compute target network Q values
-    q1_pi_target = state.q1_apply(state.target_params, observations2, action2)
-    q2_pi_target = state.q2_apply(state.target_params, observations2, action2)
+    q1_pi_target = state.q1_apply(state.target_params,
+                                  observation=observations2,
+                                  action=action2)
+    q2_pi_target = state.q2_apply(state.target_params,
+                                  observation=observations2,
+                                  action=action2)
     q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
 
     alpha = jax.numpy.exp(params['logalpha'])
@@ -149,8 +143,8 @@
                                            (q_pi_target - alpha * logp_pi2))
 
     # Compute the starting Q values from the Q network being optimized.
-    q1 = state.q1_apply(params, observations1, actions)
-    q2 = state.q2_apply(params, observations1, actions)
+    q1 = state.q1_apply(params, observation=observations1, action=actions)
+    q2 = state.q2_apply(params, observation=observations1, action=actions)
 
     # Mean squared error loss against Bellman backup
     q1_loss = ((q1 - bellman_backup)**2).mean()
@@ -181,8 +175,12 @@
     pi, logp_pi, _, _ = state.pi_apply(rng=rng,
                                        params=params,
                                        observation=observations1)
-    q1_pi = state.q1_apply(jax.lax.stop_gradient(params), observations1, pi)
-    q2_pi = state.q2_apply(jax.lax.stop_gradient(params), observations1, pi)
+    q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
+                           observation=observations1,
+                           action=pi)
+    q2_pi = state.q2_apply(jax.lax.stop_gradient(params),
+                           observation=observations1,
+                           action=pi)
 
     # And compute the Q of that action.
     q_pi = jax.numpy.minimum(q1_pi, q2_pi)
@@ -277,16 +275,24 @@
 
 
 @jax.jit
-def collect_experience(state: TrainState, replay_buffer_state, R: ArrayLike,
+def collect_experience(state: TrainState, replay_buffer_state,
                        step_rng: PRNGKey, step):
     """Collects experience by simulating."""
     pi_rng = jax.random.fold_in(step_rng, step)
     pi_rng, initialization_rng = jax.random.split(pi_rng)
+    pi_rng, goal_rng = jax.random.split(pi_rng)
 
     observation = jax.lax.with_sharding_constraint(
-        physics.random_states(initialization_rng, FLAGS.num_agents),
+        state.problem.random_states(initialization_rng, FLAGS.num_agents),
         state.sharding)
 
+    R = jax.numpy.hstack((
+        jax.random.uniform(goal_rng, (FLAGS.num_agents, 1),
+                           minval=-0.1,
+                           maxval=0.1),
+        jax.numpy.zeros((FLAGS.num_agents, 1)),
+    ))
+
     def loop(i, val):
         """Runs 1 step of our simulation."""
         observation, pi_rng, replay_buffer_state = val
@@ -295,7 +301,7 @@
 
         def true_fn(i):
             # We are at the beginning of the process, pick a random action.
-            return random_actions(action_rng, FLAGS.num_agents)
+            return state.problem.random_actions(action_rng, FLAGS.num_agents)
 
         def false_fn(i):
             # We are past the beginning of the process, use the trained network.
@@ -314,26 +320,22 @@
         )
 
         # Compute the destination state.
-        observation2 = jax.vmap(lambda o, pi: physics.integrate_dynamics(
-            state.physics_constants, o, pi),
-                                in_axes=(0, 0))(observation, pi_action)
+        observation2 = jax.vmap(
+            lambda o, pi: state.problem.integrate_dynamics(o, pi),
+            in_axes=(0, 0))(observation, pi_action)
 
         # Soft Actor-Critic is designed to maximize reward.  LQR minimizes
         # cost.  There is nothing which assumes anything about the sign of
         # the reward, so use the negative of the cost.
-        reward = (-jax.vmap(state_cost)(observation2, pi_action, R)).reshape(
-            (FLAGS.num_agents, 1))
+        reward = -jax.vmap(state.problem.cost)(
+            X=observation2, U=pi_action, goal=R)
 
-        logging.info('Observation shape: %s', observation.shape)
-        logging.info('Observation2 shape: %s', observation2.shape)
-        logging.info('Action shape: %s', pi_action.shape)
-        logging.info('Action shape: %s', pi_action.shape)
         replay_buffer_state = state.replay_buffer.add(
             replay_buffer_state, {
                 'observations1': observation,
                 'observations2': observation2,
                 'actions': pi_action,
-                'rewards': reward,
+                'rewards': reward.reshape((FLAGS.num_agents, 1)),
             })
 
         return observation2, pi_rng, replay_buffer_state
@@ -379,9 +381,7 @@
     return rng, state, q_loss, pi_loss, alpha_loss
 
 
-def train(
-    workdir: str, physics_constants: jax_dynamics.CoefficientsType
-) -> train_state.TrainState:
+def train(workdir: str, problem: Problem) -> train_state.TrainState:
     """Trains a Soft Actor-Critic controller."""
     rng = jax.random.key(0)
     rng, r_rng = jax.random.split(rng)
@@ -394,6 +394,14 @@
         'batch_size': FLAGS.batch_size,
         'horizon': FLAGS.horizon,
         'warmup_steps': FLAGS.warmup_steps,
+        'steps': FLAGS.steps,
+        'replay_size': FLAGS.replay_size,
+        'num_agents': FLAGS.num_agents,
+        'polyak': FLAGS.polyak,
+        'gamma': FLAGS.gamma,
+        'alpha': FLAGS.alpha,
+        'final_q_learning_rate': FLAGS.final_q_learning_rate,
+        'final_pi_learning_rate': FLAGS.final_pi_learning_rate,
     }
 
     # Setup TrainState
@@ -406,7 +414,7 @@
         final_learning_rate=FLAGS.final_pi_learning_rate)
     state = create_train_state(
         init_rng,
-        physics_constants,
+        problem,
         q_learning_rate=q_learning_rate,
         pi_learning_rate=pi_learning_rate,
     )
@@ -415,21 +423,13 @@
     state_sharding = nn.get_sharding(state, state.mesh)
     logging.info(state_sharding)
 
-    # TODO(austin): Let the goal change.
-    R = jax.numpy.hstack((
-        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=0.9, maxval=1.1),
-        jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=-0.1,
-                           maxval=0.1),
-        jax.numpy.zeros((FLAGS.num_agents, 1)),
-    ))
-
     replay_buffer_state = state.replay_buffer.init({
         'observations1':
-        jax.numpy.zeros((physics.NUM_STATES, )),
+        jax.numpy.zeros((problem.num_states, )),
         'observations2':
-        jax.numpy.zeros((physics.NUM_STATES, )),
+        jax.numpy.zeros((problem.num_states, )),
         'actions':
-        jax.numpy.zeros((physics.NUM_OUTPUTS, )),
+        jax.numpy.zeros((problem.num_outputs, )),
         'rewards':
         jax.numpy.zeros((1, )),
     })
@@ -438,19 +438,6 @@
                                                    state.mesh)
     logging.info(replay_buffer_state_sharding)
 
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['observations1'].shape,
-                 replay_buffer_state.experience['observations1'].sharding)
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['observations2'].shape,
-                 replay_buffer_state.experience['observations2'].sharding)
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['actions'].shape,
-                 replay_buffer_state.experience['actions'].sharding)
-    logging.info('Shape: %s, Sharding %s',
-                 replay_buffer_state.experience['rewards'].shape,
-                 replay_buffer_state.experience['rewards'].sharding)
-
     # Number of gradients to accumulate before doing decent.
     update_after = FLAGS.batch_size // FLAGS.num_agents
 
@@ -462,7 +449,6 @@
         state, replay_buffer_state = collect_experience(
             state,
             replay_buffer_state,
-            R,
             step_rng,
             step,
         )