Make maximum entropy term for Q optional

The SAC with experience paper suggests that this term may not always be
helpful.  Make it a flag, on by default.
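
For reference, this is the backup being toggled, shown as a minimal
sketch (variable names mirror the ones in train.py; the standalone
helper itself is illustrative, not the actual code):

    import jax

    def bellman_backup(rewards, gamma, q_pi_target, alpha, logp_pi2,
                       maximum_entropy_q=True):
        # Standard SAC target:
        #   r + gamma * (Q_target(s', a') - alpha * log pi(a'|s')).
        # With the flag off, the entropy bonus is dropped from the Q
        # target; the policy loss keeps its own entropy term either way.
        if maximum_entropy_q:
            target = q_pi_target - alpha * logp_pi2
        else:
            target = q_pi_target
        # Gradients do not flow into the target networks through the backup.
        return jax.lax.stop_gradient(rewards + gamma * target)

Since the flag is an absl boolean, the entropy term can be dropped at
runtime with --nomaximum_entropy_q (or --maximum_entropy_q=false).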

Change-Id: I7ab55ef0b83dea71cc3d43fb0235de98164a8d59
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 721f18c..a9b9078 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -76,6 +76,13 @@
     help='If true, explode on any NaNs found, and print them.',
 )
 
+absl.flags.DEFINE_bool(
+    'maximum_entropy_q',
+    default=True,
+    help=
+    'If false, do not add the maximum entropy term to the bellman backup for Q.',
+)
+
 
 def save_checkpoint(state: TrainState, workdir: str):
     """Saves a checkpoint in the workdir."""
@@ -142,8 +149,12 @@
     alpha = jax.numpy.exp(params['logalpha'])
 
     # Now we can compute the Bellman backup
-    bellman_backup = jax.lax.stop_gradient(rewards + FLAGS.gamma *
-                                           (q_pi_target - alpha * logp_pi2))
+    if FLAGS.maximum_entropy_q:
+        bellman_backup = jax.lax.stop_gradient(
+            rewards + FLAGS.gamma * (q_pi_target - alpha * logp_pi2))
+    else:
+        bellman_backup = jax.lax.stop_gradient(rewards +
+                                               FLAGS.gamma * q_pi_target)
 
     # Compute the starting Q values from the Q network being optimized.
     q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)