Reorganize batched training functions in train.py

Interleaving them with the per-agent loss functions made the actual
logic harder to follow, so group the batched wrappers together
instead.
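
For reference, the batched wrappers all share the same
vmap-over-agents shape. A minimal, self-contained sketch of that
pattern follows; the toy loss, NUM_AGENTS, and array shapes are
illustrative stand-ins for compute_loss_q/pi/alpha and
FLAGS.num_agents, not code from train.py:

    import jax
    import jax.numpy as jnp

    NUM_AGENTS = 8  # stand-in for FLAGS.num_agents


    def per_agent_loss(rng, data):
        # Stand-in for a single-agent loss such as compute_loss_q:
        # any scalar function of one agent's data and RNG key.
        noise = jax.random.normal(rng, data.shape)
        return jnp.mean((data - noise) ** 2)


    @jax.jit
    def batched_loss(rng, data):
        # One independent key per agent, mapped together with the
        # leading (agent) axis of data; average the per-agent scalars.
        return jax.vmap(per_agent_loss)(
            jax.random.split(rng, NUM_AGENTS),
            data,
        ).mean()


    print(batched_loss(jax.random.PRNGKey(0), jnp.ones((NUM_AGENTS, 16))))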

Change-Id: Ic1a45bec3db2d23406301713d40c076148a3937d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 2c3f56c..f7a59c2 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -156,19 +156,6 @@
 
 
 @jax.jit
-def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
-                           data: ArrayLike):
-
-    def bound_compute_loss_q(rng, data):
-        return compute_loss_q(state, rng, params, data)
-
-    return jax.vmap(bound_compute_loss_q)(
-        jax.random.split(rng, FLAGS.num_agents),
-        data,
-    ).mean()
-
-
-@jax.jit
 def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
     """Computes the Soft Actor-Critic loss for pi."""
     observations1 = data['observations1']
@@ -199,6 +186,32 @@
 
 
 @jax.jit
+def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
+                       data: ArrayLike):
+    """Computes the Soft Actor-Critic loss for alpha."""
+    observations1 = data['observations1']
+    R = data['goals']
+    pi, logp_pi, _ = jax.lax.stop_gradient(
+        state.pi_apply(rng=rng, params=params, observation=observations1, R=R))
+
+    return (-jax.numpy.exp(params['logalpha']) *
+            (logp_pi + state.target_entropy)).mean(), logp_pi.mean()
+
+
+@jax.jit
+def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
+                           data: ArrayLike):
+
+    def bound_compute_loss_q(rng, data):
+        return compute_loss_q(state, rng, params, data)
+
+    return jax.vmap(bound_compute_loss_q)(
+        jax.random.split(rng, FLAGS.num_agents),
+        data,
+    ).mean()
+
+
+@jax.jit
 def compute_batched_loss_pi(state: TrainState, rng: PRNGKey, params,
                             data: ArrayLike):
 
@@ -212,19 +225,6 @@
 
 
 @jax.jit
-def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
-                       data: ArrayLike):
-    """Computes the Soft Actor-Critic loss for alpha."""
-    observations1 = data['observations1']
-    R = data['goals']
-    pi, logp_pi, _, _ = jax.lax.stop_gradient(
-        state.pi_apply(rng=rng, params=params, R=R, observation=observations1))
-
-    return (-jax.numpy.exp(params['logalpha']) *
-            (logp_pi + state.target_entropy)).mean()
-
-
-@jax.jit
 def compute_batched_loss_alpha(state: TrainState, rng: PRNGKey, params,
                                data: ArrayLike):