Log entropy as well when training

This lets us see how effectively alpha is being tuned.

Change-Id: I94a3b2d87d18a76744a41dbec0725cc8b1b31f42
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
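
The mechanical core of the change is threading an auxiliary entropy metric out of
the alpha loss with jax.value_and_grad(..., has_aux=True) and averaging it across
the vmapped agents. A minimal, standalone sketch of that pattern (illustrative
names and a SAC-style temperature loss, not the frc971 code):

    import jax
    import jax.numpy as jnp

    def loss_with_entropy(logalpha, entropy_batch, target_entropy):
        # SAC-style temperature loss: alpha * (entropy - target_entropy),
        # with the batch-mean entropy returned as an auxiliary metric.
        alpha = jnp.exp(logalpha)
        loss = (alpha * (entropy_batch - target_entropy)).mean()
        return loss, entropy_batch.mean()

    # has_aux=True makes value_and_grad return ((loss, aux), grads), so the
    # entropy can be logged without a second forward pass.
    grad_fn = jax.value_and_grad(loss_with_entropy, has_aux=True)
    (loss, entropy), grads = grad_fn(
        jnp.array(0.0),              # logalpha (differentiated argument)
        jnp.array([1.2, 0.9, 1.1]),  # per-sample policy entropies
        1.0,                         # target entropy
    )

The aggregated entropy then rides through the fori_loop carry and ends up in both
the log line and the tracked metrics.
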
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index f7a59c2..618ab28 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -231,14 +231,15 @@
     def bound_compute_loss_alpha(rng, data):
         return compute_loss_alpha(state, rng, params, data)
 
-    return jax.vmap(bound_compute_loss_alpha)(
+    loss, entropy = jax.vmap(bound_compute_loss_alpha)(
         jax.random.split(rng, FLAGS.num_agents),
         data,
-    ).mean()
+    )
+    return (loss.mean(), entropy.mean())
 
 
 @jax.jit
-def train_step(state: TrainState, data, action_data, update_rng: PRNGKey,
+def train_step(state: TrainState, data, update_rng: PRNGKey,
                step: int) -> TrainState:
     """Updates the parameters for Q, Pi, target Q, and alpha."""
     update_rng, q_grad_rng = jax.random.split(update_rng)
@@ -253,16 +254,11 @@
 
     state = state.q_apply_gradients(step=step, grads=q_grads)
 
-    update_rng, pi_grad_rng = jax.random.split(update_rng)
-
     # Update pi
+    update_rng, pi_grad_rng = jax.random.split(update_rng)
     pi_grad_fn = jax.value_and_grad(lambda params: compute_batched_loss_pi(
-        state, pi_grad_rng, params, action_data))
+        state, pi_grad_rng, params, data))
     pi_loss, pi_grads = pi_grad_fn(state.params)
-
-    print_nan(step, pi_loss)
-    print_nan(step, pi_grads)
-
     state = state.pi_apply_gradients(step=step, grads=pi_grads)
 
     update_rng, alpha_grad_rng = jax.random.split(update_rng)
@@ -271,15 +267,18 @@
         # Update alpha
         alpha_grad_fn = jax.value_and_grad(
             lambda params: compute_batched_loss_alpha(state, alpha_grad_rng,
-                                                      params, data))
-        alpha_loss, alpha_grads = alpha_grad_fn(state.params)
+                                                      params, data),
+            has_aux=True,
+        )
+        (alpha_loss, entropy), alpha_grads = alpha_grad_fn(state.params)
         print_nan(step, alpha_loss)
         print_nan(step, alpha_grads)
         state = state.alpha_apply_gradients(step=step, grads=alpha_grads)
     else:
+        entropy = 0.0
         alpha_loss = 0.0
 
-    return state, q_loss, pi_loss, alpha_loss
+    return state, q_loss, pi_loss, alpha_loss, entropy
 
 
 @jax.jit
@@ -360,10 +359,8 @@
                      step: int):
     rng, sample_rng = jax.random.split(rng)
 
-    action_data = state.replay_buffer.sample(replay_buffer_state, sample_rng)
-
     def update_iteration(i, val):
-        rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = val
+        rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = val
         rng, sample_rng, update_rng = jax.random.split(rng, 3)
 
         batch = state.replay_buffer.sample(replay_buffer_state, sample_rng)
@@ -371,22 +368,18 @@
         print_nan(i, replay_buffer_state)
         print_nan(i, batch)
 
-        state, q_loss, pi_loss, alpha_loss = train_step(
-            state,
-            data=batch.experience,
-            action_data=batch.experience,
-            update_rng=update_rng,
-            step=i)
+        state, q_loss, pi_loss, alpha_loss, entropy = train_step(
+            state, data=batch.experience, update_rng=update_rng, step=i)
 
-        return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data
+        return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy
 
-    rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = jax.lax.fori_loop(
+    rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = jax.lax.fori_loop(
         step, step + FLAGS.horizon + 1, update_iteration,
-        (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, action_data))
+        (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, 0.0))
 
     state = state.target_apply_gradients(step=state.step)
 
-    return rng, state, q_loss, pi_loss, alpha_loss
+    return rng, state, q_loss, pi_loss, alpha_loss, entropy
 
 
 def train(workdir: str, problem: Problem) -> train_state.TrainState:
@@ -467,24 +460,25 @@
         )
 
         def nop(rng, state, replay_buffer_state, step):
-            return rng, state.update_step(step=step), 0.0, 0.0, 0.0
+            return rng, state.update_step(step=step), 0.0, 0.0, 0.0, 0.0
 
         # Train
-        rng, state, q_loss, pi_loss, alpha_loss = jax.lax.cond(
+        rng, state, q_loss, pi_loss, alpha_loss, entropy = jax.lax.cond(
             step >= update_after, update_gradients, nop, rng, state,
             replay_buffer_state, step)
 
-        return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss
+        return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy
 
+    last_time = time.time()
     for step in range(0, FLAGS.steps, FLAGS.horizon):
-        state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss = train_loop(
+        state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy = train_loop(
             state, replay_buffer_state, rng, step)
 
         if FLAGS.debug_nan and has_nan(state.params):
             logging.fatal('Nan params, aborting')
 
         logging.info(
-            'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s',
+            'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s, entropy=%s, random=%s',
             step,
             q_loss,
             pi_loss,
@@ -492,6 +486,8 @@
             q_learning_rate(step),
             pi_learning_rate(step),
             jax.numpy.exp(state.params['logalpha']),
+            entropy,
+            step <= FLAGS.random_sample_steps,
         )
 
         run.track(
@@ -499,12 +495,14 @@
                 'q_loss': float(q_loss),
                 'pi_loss': float(pi_loss),
                 'alpha_loss': float(alpha_loss),
-                'alpha': float(jax.numpy.exp(state.params['logalpha']))
+                'alpha': float(jax.numpy.exp(state.params['logalpha'])),
+                'entropy': float(entropy),
             },
             step=step)
 
-        if step % 1000 == 0 and step > update_after:
+        if time.time() > last_time + 3.0 and step > update_after:
             # TODO(austin): Simulate a rollout and accumulate the reward.  How good are we doing?
             save_checkpoint(state, workdir)
+            last_time = time.time()
 
     return state
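
A side note on the last hunk: checkpointing is now throttled by wall time (at most
once every ~3 seconds after updates begin) rather than by step count. The same
guard, pulled out into a standalone sketch with a hypothetical helper class (the
real loop keeps the check inline):

    import time

    class CheckpointThrottle:
        """Calls a save callback at most once per min_interval_s of wall time."""

        def __init__(self, min_interval_s: float = 3.0):
            self.min_interval_s = min_interval_s
            self.last_time = time.time()

        def maybe_save(self, step: int, update_after: int, save_fn) -> bool:
            # Only save once gradient updates have started and enough wall
            # time has passed since the previous checkpoint.
            if step > update_after and time.time() > self.last_time + self.min_interval_s:
                save_fn()
                self.last_time = time.time()
                return True
            return False

Usage would look like throttle.maybe_save(step, update_after, lambda:
save_checkpoint(state, workdir)), which keeps checkpoint frequency roughly constant
regardless of how fast training steps run.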