Provide a goal to our SAC implementation

Feed the goal (R) into the policy and both Q networks as an extra input, and compute the cost relative to the goal rather than the origin. This sets us up to drive to multiple destinations cleanly.
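
For reference, a minimal sketch of the new calling convention. The `problem` and `state` objects are assumed to already be constructed as in model.py/physics.py; the shapes mirror the init calls in this patch:

    # Minimal sketch, not part of this change: assumes `problem` (a Problem)
    # and `state` (a TrainState) have already been built as in model.py.
    import jax

    rng = jax.random.PRNGKey(0)
    observation = jax.numpy.ones([problem.num_unwrapped_states])
    R = jax.numpy.ones([problem.num_goals])  # goal the policy should drive to

    # The policy and both Q functions now take R alongside the observation.
    U, logp_pi, _, _ = state.pi_apply(rng=rng,
                                      params=state.params,
                                      observation=observation,
                                      R=R,
                                      deterministic=True)
    q1 = state.q1_apply(state.params, observation=observation, R=R, action=U)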

Change-Id: Ie54cd037b710776fe8ae14ae395a9776d37f0296
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index f8e9137..8df956d 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -122,6 +122,7 @@
         return state.q1_apply(
             state.params,
             observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+            R=goal,
             action=jax.numpy.array([0.]),
         )[0]
 
@@ -129,11 +130,12 @@
         return state.q2_apply(
             state.params,
             observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+            R=goal,
             action=jax.numpy.array([0.]),
         )[0]
 
     def lqr_cost(X, Y):
-        x = jax.numpy.array([X, Y])
+        x = jax.numpy.array([X, Y]) - goal
         return -x.T @ jax.numpy.array(P) @ x
 
     def compute_q(params, x, y):
@@ -143,11 +145,13 @@
             state.q1_apply(
                 params,
                 observation=X,
+                R=goal,
                 action=jax.numpy.array([0.]),
             )[0],
             state.q2_apply(
                 params,
                 observation=X,
+                R=goal,
                 action=jax.numpy.array([0.]),
             )[0])
 
@@ -160,10 +164,6 @@
                                                                      grid_Y)
     lqr_cost_grid = jax.vmap(jax.vmap(lqr_cost))(grid_X, grid_Y)
 
-    # TODO(austin): Stuff to figure out:
-    # 3: Make it converge faster.  Use both GPUs better?
-    # 4: Can we feed in a reference position and get it to learn to stabilize there?
-
     # Now compute the two controller surfaces.
     def compute_lqr_U(X, Y):
         x = jax.numpy.array([X, Y])
@@ -174,6 +174,7 @@
         U, _, _, _ = state.pi_apply(rng,
                                     state.params,
                                     observation=state.problem.unwrap_angles(x),
+                                    R=goal,
                                     deterministic=True)
         return U[0]
 
@@ -190,15 +191,18 @@
         U, _, _, _ = state.pi_apply(rng,
                                     params,
                                     observation=state.problem.unwrap_angles(X),
+                                    R=goal,
                                     deterministic=True)
         U_lqr = F @ (goal - X_lqr)
 
         cost = jax.numpy.minimum(
             state.q1_apply(params,
                            observation=state.problem.unwrap_angles(X),
+                           R=goal,
                            action=U),
             state.q2_apply(params,
                            observation=state.problem.unwrap_angles(X),
+                           R=goal,
                            action=U))
 
         U_plot = data.U.at[i, :].set(U)
@@ -289,24 +293,32 @@
         self.fig0, self.axs0 = pylab.subplots(3)
         self.fig0.supxlabel('Seconds')
 
-        self.x, = self.axs0[0].plot([], [], label="x")
-        self.v, = self.axs0[0].plot([], [], label="v")
-        self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
-        self.v_lqr, = self.axs0[0].plot([], [], label="v_lqr")
+        self.axs_velocity = self.axs0[0].twinx()
 
-        self.axs0[0].set_ylabel('Velocity')
+        self.x, = self.axs0[0].plot([], [], label="x")
+        self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
+
+        self.axs0[0].set_ylabel('Position')
         self.axs0[0].legend()
+        self.axs0[0].grid()
+
+        self.v, = self.axs_velocity.plot([], [], label="v", color='C2')
+        self.v_lqr, = self.axs_velocity.plot([], [], label="v_lqr", color='C3')
+        self.axs_velocity.set_ylabel('Velocity')
+        self.axs_velocity.legend()
 
         self.uaxis, = self.axs0[1].plot([], [], label="U")
         self.uaxis_lqr, = self.axs0[1].plot([], [], label="U_lqr")
 
         self.axs0[1].set_ylabel('Amps')
         self.axs0[1].legend()
+        self.axs0[1].grid()
 
         self.costaxis, = self.axs0[2].plot([], [], label="cost")
         self.costlqraxis, = self.axs0[2].plot([], [], label="cost lqr")
         self.axs0[2].set_ylabel('Cost')
         self.axs0[2].legend()
+        self.axs0[2].grid()
 
         self.costfig = pyplot.figure(figsize=pyplot.figaspect(0.5))
         self.cost3dax = [
@@ -349,6 +361,9 @@
         self.axs0[0].relim()
         self.axs0[0].autoscale_view()
 
+        self.axs_velocity.relim()
+        self.axs_velocity.autoscale_view()
+
         self.axs0[1].relim()
         self.axs0[1].autoscale_view()
 
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index cb6d9ba..0d8a410 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -97,9 +97,10 @@
     @nn.compact
     def __call__(self,
                  observation: ArrayLike,
+                 R: ArrayLike,
                  deterministic: bool = False,
                  rng: PRNGKey | None = None):
-        x = observation
+        x = jax.numpy.hstack((observation, R))
         # Apply the dense layers
         for i, hidden_size in enumerate(self.hidden_sizes):
             x = nn.Dense(
@@ -168,9 +169,10 @@
     activation: Callable = nn.activation.tanh
 
     @nn.compact
-    def __call__(self, observation: ArrayLike, action: ArrayLike):
+    def __call__(self, observation: ArrayLike, R: ArrayLike,
+                 action: ArrayLike):
         # Estimate Q with a simple multi layer dense network.
-        x = jax.numpy.hstack((observation, action))
+        x = jax.numpy.hstack((observation, R, action))
         for i, hidden_size in enumerate(self.hidden_sizes):
             x = nn.Dense(
                 name=f'denselayer{i}',
@@ -222,25 +224,29 @@
                  rng: PRNGKey,
                  params: flax.core.FrozenDict[str, typing.Any],
                  observation: ArrayLike,
+                 R: ArrayLike,
                  deterministic: bool = False):
         return self.pi_apply_fn(
             {'params': params['pi']},
             observation=self.problem.unwrap_angles(observation),
+            R=R,
             deterministic=deterministic,
             rngs={'pi': rng})
 
     def q1_apply(self, params: flax.core.FrozenDict[str, typing.Any],
-                 observation: ArrayLike, action: ArrayLike):
+                 observation: ArrayLike, R: ArrayLike, action: ArrayLike):
         return self.q1_apply_fn(
             {'params': params['q1']},
             observation=self.problem.unwrap_angles(observation),
+            R=R,
             action=action)
 
     def q2_apply(self, params: flax.core.FrozenDict[str, typing.Any],
-                 observation: ArrayLike, action: ArrayLike):
+                 observation: ArrayLike, R: ArrayLike, action: ArrayLike):
         return self.q2_apply_fn(
             {'params': params['q2']},
             observation=self.problem.unwrap_angles(observation),
+            R=R,
             action=action)
 
     def pi_apply_gradients(self, step: int, grads):
@@ -389,15 +395,18 @@
         pi_params = pi.init(
             pi_rng,
             observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            R=jax.numpy.ones([problem.num_goals]),
         )['params']
         q1_params = q1.init(
             q1_rng,
             observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            R=jax.numpy.ones([problem.num_goals]),
             action=jax.numpy.ones([problem.num_outputs]),
         )['params']
         q2_params = q2.init(
             q2_rng,
             observation=jax.numpy.ones([problem.num_unwrapped_states]),
+            R=jax.numpy.ones([problem.num_goals]),
             action=jax.numpy.ones([problem.num_outputs]),
         )['params']
 
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index d7bf7b4..450b692 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -44,20 +44,16 @@
     def unwrap_angles(self, X: jax.typing.ArrayLike):
         return X
 
-    @partial(jax.jit, static_argnums=[0])
     def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
         raise NotImplemented("xdot not implemented")
 
-    @partial(jax.jit, static_argnums=[0])
     def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
              goal: jax.typing.ArrayLike):
         raise NotImplemented("cost not implemented")
 
-    @partial(jax.jit, static_argnums=[0])
     def random_states(self, rng: PRNGKey, dimensions=None):
         raise NotImplemented("random_states not implemented")
 
-    #@partial(jax.jit, static_argnums=[0])
     def random_actions(self, rng: PRNGKey, dimensions=None):
         """Produces a uniformly random action in the action space."""
         return jax.random.uniform(
@@ -67,6 +63,10 @@
             maxval=self.action_limit,
         )
 
+    def random_goals(self, rng: PRNGKey, dimensions=None):
+        """Produces a random goal in the goal space."""
+        raise NotImplementedError("random_goals not implemented")
+
 
 class TurretProblem(Problem):
 
@@ -77,22 +77,19 @@
                          num_goals=2,
                          action_limit=30.0)
 
-    @partial(jax.jit, static_argnums=[0])
     def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
         A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
         B = jax.numpy.array([[0.], [56.08534375]])
 
         return A @ X + B @ U
 
-    @partial(jax.jit, static_argnums=[0])
     def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
              goal: jax.typing.ArrayLike):
         Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
         R = jax.numpy.array([[0.00694444]])
 
-        return X.T @ Q @ X + U.T @ R @ U
+        return (X - goal).T @ Q @ (X - goal) + U.T @ R @ U
 
-    #@partial(jax.jit, static_argnums=[0])
     def random_states(self, rng: PRNGKey, dimensions=None):
         rng1, rng2 = jax.random.split(rng)
 
@@ -103,3 +100,12 @@
              jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
                                 minval=-10.0,
                                 maxval=10.0)))
+
+    def random_goals(self, rng: PRNGKey, dimensions=None):
+        """Produces a random goal in the goal space."""
+        return jax.numpy.hstack((
+            jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+                               minval=-0.1,
+                               maxval=0.1),
+            jax.numpy.zeros((dimensions or FLAGS.num_agents, 1)),
+        ))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index bbe2f5d..f5ba092 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -121,18 +121,22 @@
     actions = data['actions']
     rewards = data['rewards']
     observations2 = data['observations2']
+    R = data['goals']
 
     # Compute the ending actions from the current network.
     action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
                                              params=params,
-                                             observation=observations2)
+                                             observation=observations2,
+                                             R=R)
 
     # Compute target network Q values
     q1_pi_target = state.q1_apply(state.target_params,
                                   observation=observations2,
+                                  R=R,
                                   action=action2)
     q2_pi_target = state.q2_apply(state.target_params,
                                   observation=observations2,
+                                  R=R,
                                   action=action2)
     q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
 
@@ -143,8 +147,8 @@
                                            (q_pi_target - alpha * logp_pi2))
 
     # Compute the starting Q values from the Q network being optimized.
-    q1 = state.q1_apply(params, observation=observations1, action=actions)
-    q2 = state.q2_apply(params, observation=observations1, action=actions)
+    q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)
+    q2 = state.q2_apply(params, observation=observations1, R=R, action=actions)
 
     # Mean squared error loss against Bellman backup
     q1_loss = ((q1 - bellman_backup)**2).mean()
@@ -169,17 +173,21 @@
 def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
     """Computes the Soft Actor-Critic loss for pi."""
     observations1 = data['observations1']
+    R = data['goals']
     # TODO(austin): We've got differentiable policy and differentiable physics.  Can we use those here?  Have Q learn the future, not the current step?
 
     # Compute the action
     pi, logp_pi, _, _ = state.pi_apply(rng=rng,
                                        params=params,
-                                       observation=observations1)
+                                       observation=observations1,
+                                       R=R)
     q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
                            observation=observations1,
+                           R=R,
                            action=pi)
     q2_pi = state.q2_apply(jax.lax.stop_gradient(params),
                            observation=observations1,
+                           R=R,
                            action=pi)
 
     # And compute the Q of that action.
@@ -286,12 +294,7 @@
         state.problem.random_states(initialization_rng, FLAGS.num_agents),
         state.sharding)
 
-    R = jax.numpy.hstack((
-        jax.random.uniform(goal_rng, (FLAGS.num_agents, 1),
-                           minval=-0.1,
-                           maxval=0.1),
-        jax.numpy.zeros((FLAGS.num_agents, 1)),
-    ))
+    R = state.problem.random_goals(goal_rng, FLAGS.num_agents)
 
     def loop(i, val):
         """Runs 1 step of our simulation."""
@@ -309,6 +312,7 @@
                 rng=action_rng,
                 params=state.params,
                 observation=observation,
+                R=R,
                 deterministic=False)
             return pi_action
 
@@ -336,6 +340,7 @@
                 'observations2': observation2,
                 'actions': pi_action,
                 'rewards': reward.reshape((FLAGS.num_agents, 1)),
+                'goals': R,
             })
 
         return observation2, pi_rng, replay_buffer_state
@@ -432,6 +437,8 @@
         jax.numpy.zeros((problem.num_outputs, )),
         'rewards':
         jax.numpy.zeros((1, )),
+        'goals':
+        jax.numpy.zeros((problem.num_goals, )),
     })
 
     replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,