Make maximum entropy term for Q optional
The SAC with experience paper suggests that this term may not always be
helpful. Make it a flag, on by default.
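
For reference, the only change is whether the entropy bonus appears in the
Q target. A minimal sketch of the two backups, assuming a hypothetical
helper name (bellman_target) and the argument names used in train.py:

    import jax

    def bellman_target(rewards, q_pi_target, logp_pi2, alpha, gamma,
                       maximum_entropy_q):
        # Standard SAC backup: r + gamma * (Q_target - alpha * log pi(a'|s')).
        if maximum_entropy_q:
            target = rewards + gamma * (q_pi_target - alpha * logp_pi2)
        else:
            # Entropy bonus dropped: plain Bellman target r + gamma * Q_target.
            target = rewards + gamma * q_pi_target
        # Stop gradients so the target is treated as a constant in the Q loss.
        return jax.lax.stop_gradient(target)

Since this is an absl boolean flag, it can be disabled at launch with
--nomaximum_entropy_q (or --maximum_entropy_q=false).
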
Change-Id: I7ab55ef0b83dea71cc3d43fb0235de98164a8d59
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 721f18c..a9b9078 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -76,6 +76,13 @@
     help='If true, explode on any NaNs found, and print them.',
 )
+absl.flags.DEFINE_bool(
+    'maximum_entropy_q',
+    default=True,
+    help=
+    'If false, do not add the maximum entropy term to the bellman backup for Q.',
+)
+
 def save_checkpoint(state: TrainState, workdir: str):
     """Saves a checkpoint in the workdir."""
@@ -142,8 +149,12 @@
     alpha = jax.numpy.exp(params['logalpha'])
     # Now we can compute the Bellman backup
-    bellman_backup = jax.lax.stop_gradient(rewards + FLAGS.gamma *
-                                           (q_pi_target - alpha * logp_pi2))
+    if FLAGS.maximum_entropy_q:
+        bellman_backup = jax.lax.stop_gradient(
+            rewards + FLAGS.gamma * (q_pi_target - alpha * logp_pi2))
+    else:
+        bellman_backup = jax.lax.stop_gradient(rewards +
+                                               FLAGS.gamma * q_pi_target)
     # Compute the starting Q values from the Q network being optimized.
     q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)