Log entropy as well when training
This lets us see how effective the alpha tuning is at driving the policy entropy toward its target.
Change-Id: I94a3b2d87d18a76744a41dbec0725cc8b1b31f42
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
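
For context, the entropy comes back as an auxiliary output of the alpha loss via
jax.value_and_grad(..., has_aux=True). A minimal, self-contained sketch of that
pattern (the loss below is a generic SAC-style temperature loss for illustration,
not this repo's compute_loss_alpha, and TARGET_ENTROPY is a placeholder):

    import jax
    import jax.numpy as jnp

    TARGET_ENTROPY = -1.0  # placeholder target, not taken from this repo

    def loss_alpha(logalpha, logp):
        # Generic SAC temperature loss; the measured policy entropy
        # (-mean(log pi)) rides along as the auxiliary output.
        entropy = -jnp.mean(logp)
        loss = -jnp.exp(logalpha) * jnp.mean(logp + TARGET_ENTROPY)
        return loss, entropy

    grad_fn = jax.value_and_grad(loss_alpha, has_aux=True)
    (alpha_loss, entropy), alpha_grads = grad_fn(jnp.zeros(()),
                                                 jnp.array([-1.3, -0.7]))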
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index f7a59c2..618ab28 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -231,14 +231,15 @@
def bound_compute_loss_alpha(rng, data):
return compute_loss_alpha(state, rng, params, data)
- return jax.vmap(bound_compute_loss_alpha)(
+ loss, entropy = jax.vmap(bound_compute_loss_alpha)(
jax.random.split(rng, FLAGS.num_agents),
data,
- ).mean()
+ )
+ return (loss.mean(), entropy.mean())
@jax.jit
-def train_step(state: TrainState, data, action_data, update_rng: PRNGKey,
+def train_step(state: TrainState, data, update_rng: PRNGKey,
step: int) -> TrainState:
"""Updates the parameters for Q, Pi, target Q, and alpha."""
update_rng, q_grad_rng = jax.random.split(update_rng)
@@ -253,16 +254,11 @@
state = state.q_apply_gradients(step=step, grads=q_grads)
- update_rng, pi_grad_rng = jax.random.split(update_rng)
-
# Update pi
+ update_rng, pi_grad_rng = jax.random.split(update_rng)
pi_grad_fn = jax.value_and_grad(lambda params: compute_batched_loss_pi(
- state, pi_grad_rng, params, action_data))
+ state, pi_grad_rng, params, data))
pi_loss, pi_grads = pi_grad_fn(state.params)
-
- print_nan(step, pi_loss)
- print_nan(step, pi_grads)
-
state = state.pi_apply_gradients(step=step, grads=pi_grads)
update_rng, alpha_grad_rng = jax.random.split(update_rng)
@@ -271,15 +267,18 @@
# Update alpha
alpha_grad_fn = jax.value_and_grad(
lambda params: compute_batched_loss_alpha(state, alpha_grad_rng,
- params, data))
- alpha_loss, alpha_grads = alpha_grad_fn(state.params)
+ params, data),
+ has_aux=True,
+ )
+ (alpha_loss, entropy), alpha_grads = alpha_grad_fn(state.params)
print_nan(step, alpha_loss)
print_nan(step, alpha_grads)
state = state.alpha_apply_gradients(step=step, grads=alpha_grads)
else:
+ entropy = 0.0
alpha_loss = 0.0
- return state, q_loss, pi_loss, alpha_loss
+ return state, q_loss, pi_loss, alpha_loss, entropy
@jax.jit
@@ -360,10 +359,8 @@
step: int):
rng, sample_rng = jax.random.split(rng)
- action_data = state.replay_buffer.sample(replay_buffer_state, sample_rng)
-
def update_iteration(i, val):
- rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = val
+ rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = val
rng, sample_rng, update_rng = jax.random.split(rng, 3)
batch = state.replay_buffer.sample(replay_buffer_state, sample_rng)
@@ -371,22 +368,18 @@
print_nan(i, replay_buffer_state)
print_nan(i, batch)
- state, q_loss, pi_loss, alpha_loss = train_step(
- state,
- data=batch.experience,
- action_data=batch.experience,
- update_rng=update_rng,
- step=i)
+ state, q_loss, pi_loss, alpha_loss, entropy = train_step(
+ state, data=batch.experience, update_rng=update_rng, step=i)
- return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data
+ return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy
- rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = jax.lax.fori_loop(
+ rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = jax.lax.fori_loop(
step, step + FLAGS.horizon + 1, update_iteration,
- (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, action_data))
+ (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, 0))
state = state.target_apply_gradients(step=state.step)
- return rng, state, q_loss, pi_loss, alpha_loss
+ return rng, state, q_loss, pi_loss, alpha_loss, entropy
def train(workdir: str, problem: Problem) -> train_state.TrainState:
@@ -467,24 +460,25 @@
)
def nop(rng, state, replay_buffer_state, step):
- return rng, state.update_step(step=step), 0.0, 0.0, 0.0
+ return rng, state.update_step(step=step), 0.0, 0.0, 0.0, 0.0
# Train
- rng, state, q_loss, pi_loss, alpha_loss = jax.lax.cond(
+ rng, state, q_loss, pi_loss, alpha_loss, entropy = jax.lax.cond(
step >= update_after, update_gradients, nop, rng, state,
replay_buffer_state, step)
- return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss
+ return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy
+ last_time = time.time()
for step in range(0, FLAGS.steps, FLAGS.horizon):
- state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss = train_loop(
+ state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy = train_loop(
state, replay_buffer_state, rng, step)
if FLAGS.debug_nan and has_nan(state.params):
logging.fatal('Nan params, aborting')
logging.info(
- 'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s',
+ 'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s, entropy=%s, random=%s',
step,
q_loss,
pi_loss,
@@ -492,6 +486,8 @@
q_learning_rate(step),
pi_learning_rate(step),
jax.numpy.exp(state.params['logalpha']),
+ entropy,
+ step <= FLAGS.random_sample_steps,
)
run.track(
@@ -499,12 +495,14 @@
'q_loss': float(q_loss),
'pi_loss': float(pi_loss),
'alpha_loss': float(alpha_loss),
- 'alpha': float(jax.numpy.exp(state.params['logalpha']))
+ 'alpha': float(jax.numpy.exp(state.params['logalpha'])),
+ 'entropy': entropy,
},
step=step)
- if step % 1000 == 0 and step > update_after:
+ if time.time() > last_time + 3.0 and step > update_after:
# TODO(austin): Simulate a rollout and accumulate the reward. How good are we doing?
save_checkpoint(state, workdir)
+ last_time = time.time()
return state
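
A note on the plumbing above: jax.lax.cond branches and the jax.lax.fori_loop
carry must return pytrees with matching structure and dtypes, which is why
entropy is threaded through nop, the loop carry, and every return tuple. A small
illustrative sketch (not this repo's code) of that constraint:

    import jax
    import jax.numpy as jnp

    def body(i, carry):
        loss, entropy = carry
        # Each iteration produces a fresh (float) loss and entropy.
        return loss + 1.0, entropy + 0.5

    # The initial carry uses floats so its dtypes line up with what body returns.
    loss, entropy = jax.lax.fori_loop(0, 10, body, (0.0, 0.0))

    def update(x):
        return x * 2.0, jnp.float32(0.25)  # (value, entropy)

    def nop(x):
        # The do-nothing branch still returns an entropy placeholder so both
        # branches of lax.cond produce identical output structures.
        return x, jnp.float32(0.0)

    value, entropy = jax.lax.cond(True, update, nop, jnp.float32(1.0))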