Reorganize batched training functions in train.py
The batched wrappers were interleaved with the actual loss logic,
which made that logic harder to follow. Group them together instead.
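
For reference, the batched functions being moved all follow the same
pattern: split the RNG key so each agent gets its own stream, vmap a
per-agent loss over the leading (agent) axis of the data, and average
into one scalar. A minimal, self-contained sketch of that pattern is
below; the names (per_agent_loss, batched_loss, NUM_AGENTS) are
illustrative stand-ins, not the actual train.py API:

    import jax
    import jax.numpy as jnp

    NUM_AGENTS = 4  # stand-in for FLAGS.num_agents


    def per_agent_loss(rng, data):
        # Stand-in for a per-agent loss (compute_loss_q / _pi / _alpha):
        # any scalar loss of one agent's batch and its own RNG key.
        noise = jax.random.normal(rng, data.shape)
        return ((data - noise) ** 2).mean()


    @jax.jit
    def batched_loss(rng, data):
        # Give each agent an independent RNG stream, vmap the per-agent
        # loss across the leading (agent) axis, then reduce to a single
        # scalar by averaging across agents.
        return jax.vmap(per_agent_loss)(
            jax.random.split(rng, NUM_AGENTS),
            data,
        ).mean()


    # Usage: data carries a leading axis of NUM_AGENTS.
    print(batched_loss(jax.random.PRNGKey(0), jnp.ones((NUM_AGENTS, 8))))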
Change-Id: Ic1a45bec3db2d23406301713d40c076148a3937d
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 2c3f56c..f7a59c2 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -156,19 +156,6 @@
@jax.jit
-def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
- data: ArrayLike):
-
- def bound_compute_loss_q(rng, data):
- return compute_loss_q(state, rng, params, data)
-
- return jax.vmap(bound_compute_loss_q)(
- jax.random.split(rng, FLAGS.num_agents),
- data,
- ).mean()
-
-
-@jax.jit
def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
"""Computes the Soft Actor-Critic loss for pi."""
observations1 = data['observations1']
@@ -199,6 +186,32 @@
@jax.jit
+def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+ """Computes the Soft Actor-Critic loss for alpha."""
+ observations1 = data['observations1']
+ R = data['goals']
+ pi, logp_pi, _ = jax.lax.stop_gradient(
+ state.pi_apply(rng=rng, params=params, observation=observations1, R=R))
+
+ return (-jax.numpy.exp(params['logalpha']) *
+ (logp_pi + state.target_entropy)).mean(), logp_pi.mean()
+
+
+@jax.jit
+def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+
+ def bound_compute_loss_q(rng, data):
+ return compute_loss_q(state, rng, params, data)
+
+ return jax.vmap(bound_compute_loss_q)(
+ jax.random.split(rng, FLAGS.num_agents),
+ data,
+ ).mean()
+
+
+@jax.jit
def compute_batched_loss_pi(state: TrainState, rng: PRNGKey, params,
data: ArrayLike):
@@ -212,19 +225,6 @@
@jax.jit
-def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
- data: ArrayLike):
- """Computes the Soft Actor-Critic loss for alpha."""
- observations1 = data['observations1']
- R = data['goals']
- pi, logp_pi, _, _ = jax.lax.stop_gradient(
- state.pi_apply(rng=rng, params=params, R=R, observation=observations1))
-
- return (-jax.numpy.exp(params['logalpha']) *
- (logp_pi + state.target_entropy)).mean()
-
-
-@jax.jit
def compute_batched_loss_alpha(state: TrainState, rng: PRNGKey, params,
data: ArrayLike):