Switch to using tensorflow_probability

This is technically less accurate, but it leaves a lot less code to
validate while we are having trouble getting this to converge.
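
Roughly, the policy head now samples from a tanh-squashed Gaussian via
tensorflow_probability's JAX substrate.  A minimal sketch, where mu
and std stand in for the actor network outputs:

    import jax
    from tensorflow_probability.substrates import jax as tfp

    tfd = tfp.distributions
    tfb = tfp.bijectors

    # Placeholder network outputs; in model.py these come from the
    # actor MLP.
    mu = jax.numpy.zeros((2,))
    std = jax.numpy.ones((2,))

    # Squash a Gaussian through tanh so samples land in (-1, 1).
    pi_distribution = tfd.TransformedDistribution(
        distribution=tfd.Normal(loc=mu, scale=std),
        bijector=tfb.Tanh(),
    )

    rng = jax.random.PRNGKey(0)
    pi_action = pi_distribution.sample(seed=rng)
    # log_prob applies the tanh change-of-variables correction for us,
    # replacing the hand-rolled appendix C formula from the SAC paper
    # (this is the part that is less numerically careful).
    logp_pi = pi_distribution.log_prob(pi_action)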

To better adhere to the "unit Gaussian" assumption that neural
networks like, make the physics account for the action limit instead
of the actor network: the actor now emits actions normalized to
[-1, 1], and the dynamics scale them back into real units.  We can
undo this once things work if we find it doesn't help convergence.
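
A minimal sketch of the new convention, with a made-up action_limit
(the real dynamics and limit live in physics.py):

    import jax

    # Hypothetical value, for illustration only.
    action_limit = 10.0

    def xdot(X, normalized_U):
        """Continuous dynamics taking an action normalized to [-1, 1]."""
        A_continuous = jax.numpy.array([[0., 1.], [0., -36.85154548]])
        B_continuous = jax.numpy.array([[0.], [56.08534375]])
        # The physics, not the actor, scales the action back to real
        # units.
        U = normalized_U * action_limit
        return A_continuous @ X + B_continuous @ U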

Also, switch from talking about "cost" to "reward" to more closely
match other implementations.
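
In other words, the old quadratic cost just flips sign, something
like:

    def reward(X, U, goal, Q, R):
        # Negative of the old LQR-style quadratic cost.
        error = X - goal
        return -(error.T @ Q @ error) - U.T @ R @ U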

Change-Id: I1b8df0a72946020b3e3da5cf46b3787e38cb94c9
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
index 7d376bd..c9bec2a 100644
--- a/frc971/control_loops/swerve/velocity_controller/BUILD
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -44,6 +44,8 @@
         "@pip//matplotlib",
         "@pip//numpy",
         "@pip//tensorflow",
+        "@pip//tensorflow_probability",
+        "@pip//tf_keras",
     ],
 )
 
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index bd6674c..2db9842 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -156,12 +156,12 @@
 
     def compute_pi_U(X, Y):
         x = jax.numpy.array([X, Y])
-        U, _, _, _ = state.pi_apply(rng,
-                                    state.params,
-                                    observation=state.problem.unwrap_angles(x),
-                                    R=goal,
-                                    deterministic=True)
-        return U[0]
+        U, _, _ = state.pi_apply(rng,
+                                 state.params,
+                                 observation=state.problem.unwrap_angles(x),
+                                 R=goal,
+                                 deterministic=True)
+        return U[0] * problem.action_limit
 
     lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
     pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
@@ -173,22 +173,25 @@
         X, X_lqr, data, params = val
         t = data.t.at[i].set(i * problem.dt)
 
-        U, _, _, _ = state.pi_apply(rng,
-                                    params,
-                                    observation=state.problem.unwrap_angles(X),
-                                    R=goal,
-                                    deterministic=True)
+        normalized_U, _, _ = state.pi_apply(
+            rng,
+            params,
+            observation=state.problem.unwrap_angles(X),
+            R=goal,
+            deterministic=True)
         U_lqr = problem.F @ (goal - X_lqr)
 
         cost = jax.numpy.minimum(
             state.q1_apply(params,
                            observation=state.problem.unwrap_angles(X),
                            R=goal,
-                           action=U),
+                           action=normalized_U),
             state.q2_apply(params,
                            observation=state.problem.unwrap_angles(X),
                            R=goal,
-                           action=U))
+                           action=normalized_U))
+
+        U = normalized_U * problem.action_limit
 
         U_plot = data.U.at[i, :].set(U)
         U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
@@ -200,8 +203,9 @@
         X = problem.A @ X + problem.B @ U
         X_lqr = problem.A @ X_lqr + problem.B @ U_lqr
 
-        reward = data.reward - state.problem.cost(X, U, goal)
-        reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
+        reward = data.reward + state.problem.reward(X, normalized_U, goal)
+        reward_lqr = data.reward_lqr + state.problem.reward(
+            X_lqr, U_lqr / problem.action_limit, goal)
 
         return X, X_lqr, data._replace(
             t=t,
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index 1394463..97db4c5 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -14,9 +14,12 @@
 from jax.experimental import mesh_utils
 from jax.sharding import Mesh, PartitionSpec, NamedSharding
 from frc971.control_loops.swerve import jax_dynamics
-from frc971.control_loops.swerve import dynamics
 from frc971.control_loops.swerve.velocity_controller import physics
 from frc971.control_loops.swerve.velocity_controller import experience_buffer
+from tensorflow_probability.substrates import jax as tfp
+
+tfd = tfp.distributions
+tfb = tfp.bijectors
 
 from flax.typing import PRNGKey
 
@@ -127,36 +130,25 @@
         if rng is None:
             rng = self.make_rng('pi')
 
-        # Grab a random sample
-        random_sample = jax.random.normal(rng, shape=std.shape)
+        pi_distribution = tfd.TransformedDistribution(
+            distribution=tfd.Normal(loc=mu, scale=std),
+            bijector=tfb.Tanh(),
+        )
 
         if deterministic:
             # We are testing the optimal policy, just use the mean.
-            pi_action = mu
+            pi_action = flax.linen.activation.tanh(mu)
         else:
-            # Use the reparameterization trick.  Adjust the unit gausian with
-            # something we can solve for to get the desired noise.
-            pi_action = random_sample * std + mu
+            pi_action = pi_distribution.sample(shape=(1, ), seed=rng)
 
-        logp_pi = gaussian_likelihood(random_sample, log_std)
-        # Adjustment to log prob
-        # NOTE: This formula is a little bit magic. To get an understanding of where it
-        # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
-        # appendix C. This is a more numerically-stable equivalent to Eq 21.
-        delta = (2.0 * (jax.numpy.log(2.0) - pi_action -
-                        flax.linen.softplus(-2.0 * pi_action)))
+        logp_pi = pi_distribution.log_prob(pi_action)
 
-        if len(delta.shape) > 1:
-            delta = jax.numpy.sum(delta, keepdims=True, axis=1)
+        if len(logp_pi.shape) > 1:
+            logp_pi = jax.numpy.sum(logp_pi, keepdims=True, axis=1)
         else:
-            delta = jax.numpy.sum(delta, keepdims=True)
+            logp_pi = jax.numpy.sum(logp_pi, keepdims=True)
 
-        logp_pi = logp_pi - delta
-
-        # Now, saturate the action to the limit using tanh
-        pi_action = self.action_limit * flax.linen.activation.tanh(pi_action)
-
-        return pi_action, logp_pi, self.action_limit * std, random_sample
+        return pi_action, logp_pi, self.action_limit * std
 
 
 class MLPQFunction(nn.Module):
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index b0da0b4..e70d871 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -55,13 +55,17 @@
     def random_states(self, rng: PRNGKey, dimensions=None):
         raise NotImplemented("random_states not implemented")
 
-    def random_actions(self, rng: PRNGKey, dimensions=None):
+    def random_actions(self,
+                       rng: PRNGKey,
+                       X: jax.typing.ArrayLike,
+                       goal: jax.typing.ArrayLike,
+                       dimensions=None):
         """Produces a uniformly random action in the action space."""
         return jax.random.uniform(
             rng,
             (dimensions or FLAGS.num_agents, self.num_outputs),
-            minval=-self.action_limit,
-            maxval=self.action_limit,
+            minval=-1.0,
+            maxval=1.0,
         )
 
     def random_goals(self, rng: PRNGKey, dimensions=None):
@@ -94,12 +98,13 @@
         A_continuous = jax.numpy.array([[0., 1.], [0., -36.85154548]])
         B_continuous = jax.numpy.array([[0.], [56.08534375]])
 
+        U = U * self.action_limit
         return A_continuous @ X + B_continuous @ U
 
-    def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
-             goal: jax.typing.ArrayLike):
-        return (X - goal).T @ jax.numpy.array(
-            self.Q) @ (X - goal) + U.T @ jax.numpy.array(self.R) @ U
+    def reward(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+               goal: jax.typing.ArrayLike):
+        return -(X - goal).T @ jax.numpy.array(
+            self.Q) @ (X - goal) - U.T @ jax.numpy.array(self.R) @ U
 
     def random_states(self, rng: PRNGKey, dimensions=None):
         rng1, rng2 = jax.random.split(rng)
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index cf54bc4..3bdc18c 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -123,10 +123,10 @@
     R = data['goals']
 
     # Compute the ending actions from the current network.
-    action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
-                                             params=params,
-                                             observation=observations2,
-                                             R=R)
+    action2, logp_pi2, _ = state.pi_apply(rng=rng,
+                                          params=params,
+                                          observation=observations2,
+                                          R=R)
 
     # Compute target network Q values
     q1_pi_target = state.q1_apply(state.target_params,
@@ -176,10 +176,10 @@
     # TODO(austin): We've got differentiable policy and differentiable physics.  Can we use those here?  Have Q learn the future, not the current step?
 
     # Compute the action
-    pi, logp_pi, _, _ = state.pi_apply(rng=rng,
-                                       params=params,
-                                       observation=observations1,
-                                       R=R)
+    pi, logp_pi, _ = state.pi_apply(rng=rng,
+                                    params=params,
+                                    observation=observations1,
+                                    R=R)
     q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
                            observation=observations1,
                            R=R,
@@ -304,16 +304,18 @@
 
         def true_fn(i):
             # We are at the beginning of the process, pick a random action.
-            return state.problem.random_actions(action_rng, FLAGS.num_agents)
+            return state.problem.random_actions(action_rng,
+                                                X=observation,
+                                                goal=R,
+                                                dimensions=FLAGS.num_agents)
 
         def false_fn(i):
             # We are past the beginning of the process, use the trained network.
-            pi_action, logp_pi, std, random_sample = state.pi_apply(
-                rng=action_rng,
-                params=state.params,
-                observation=observation,
-                R=R,
-                deterministic=False)
+            pi_action, _, _ = state.pi_apply(rng=action_rng,
+                                             params=state.params,
+                                             observation=observation,
+                                             R=R,
+                                             deterministic=False)
             return pi_action
 
         pi_action = jax.lax.cond(
@@ -331,8 +333,9 @@
         # Soft Actor-Critic is designed to maximize reward.  LQR minimizes
         # cost.  There is nothing which assumes anything about the sign of
         # the reward, so use the negative of the cost.
-        reward = -jax.vmap(state.problem.cost)(
-            X=observation2, U=pi_action, goal=R)
+        reward = jax.vmap(state.problem.reward)(X=observation2,
+                                                U=pi_action,
+                                                goal=R)
 
         replay_buffer_state = state.replay_buffer.add(
             replay_buffer_state, {