Switch optimizers to Adam and expose alpha's learning rate
Adam seems to work better than SGD here, and we want to update the entropy
coefficient (alpha) faster than the other parameters, so split its learning
rate out into a separate flag.
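
For reference, a minimal sketch of the resulting per-group optimizer wiring
(illustrative only: the parameter-tree layout and learning-rate values below
are assumptions for the example, not the actual TrainState setup):

    import jax.numpy as jnp
    import optax

    # Hypothetical parameter groups; the real ones live in TrainState.
    params = {
        'pi': {'w': jnp.zeros((4, 2))},
        'q1': {'w': jnp.zeros((4, 1))},
        'logalpha': jnp.zeros(()),
    }

    # One Adam transform per group; the entropy temperature (alpha) gets
    # its own, faster learning rate instead of reusing q_learning_rate.
    pi_tx = optax.adam(learning_rate=0.002)
    q_tx = optax.adam(learning_rate=0.002)
    alpha_tx = optax.adam(learning_rate=0.004)

    pi_opt_state = pi_tx.init(params['pi'])
    q_opt_state = q_tx.init(params['q1'])
    alpha_opt_state = alpha_tx.init(params['logalpha'])

    # Apply a (placeholder) gradient to just logalpha.
    grad = jnp.ones(())
    updates, alpha_opt_state = alpha_tx.update(grad, alpha_opt_state,
                                               params['logalpha'])
    params['logalpha'] = optax.apply_updates(params['logalpha'], updates)

Keeping the transforms separate means each group's Adam moment estimates are
tracked independently, so tuning alpha's rate doesn't perturb pi or q.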
Change-Id: Idba5686955c32d635f15be2b588bf764e2f1031c
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 2db9842..2bc65ff 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -87,8 +87,13 @@
rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
- state = create_train_state(init_rng, problem, FLAGS.q_learning_rate,
- FLAGS.pi_learning_rate)
+ state = create_train_state(
+ init_rng,
+ problem,
+ q_learning_rate=FLAGS.q_learning_rate,
+ pi_learning_rate=FLAGS.pi_learning_rate,
+ alpha_learning_rate=FLAGS.alpha_learning_rate,
+ )
state = restore_checkpoint(state, FLAGS.workdir)
if step is not None and state.step == step:
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index 97db4c5..7d62b0e 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -32,6 +32,12 @@
)
absl.flags.DEFINE_float(
+ 'alpha_learning_rate',
+ default=0.004,
+ help='Training learning rate for entropy.',
+)
+
+absl.flags.DEFINE_float(
'q_learning_rate',
default=0.002,
help='Training learning rate.',
@@ -372,7 +378,7 @@
def create_train_state(rng: PRNGKey, problem: Problem, q_learning_rate,
- pi_learning_rate):
+ pi_learning_rate, alpha_learning_rate):
"""Creates initial `TrainState`."""
pi = SquashedGaussianMLPActor(activation=nn.activation.gelu,
action_space=problem.num_outputs,
@@ -415,9 +421,9 @@
'logalpha': logalpha,
}
- pi_tx = optax.sgd(learning_rate=pi_learning_rate)
- q_tx = optax.sgd(learning_rate=q_learning_rate)
- alpha_tx = optax.sgd(learning_rate=q_learning_rate)
+ pi_tx = optax.adam(learning_rate=pi_learning_rate)
+ q_tx = optax.adam(learning_rate=q_learning_rate)
+ alpha_tx = optax.adam(learning_rate=alpha_learning_rate)
result = TrainState.create(
problem=problem,
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 3bdc18c..ed0cdeb 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -399,6 +399,7 @@
run['hparams'] = {
'q_learning_rate': FLAGS.q_learning_rate,
'pi_learning_rate': FLAGS.pi_learning_rate,
+ 'alpha_learning_rate': FLAGS.alpha_learning_rate,
'batch_size': FLAGS.batch_size,
'horizon': FLAGS.horizon,
'warmup_steps': FLAGS.warmup_steps,
@@ -425,6 +426,7 @@
problem,
q_learning_rate=q_learning_rate,
pi_learning_rate=pi_learning_rate,
+ alpha_learning_rate=FLAGS.alpha_learning_rate,
)
state = restore_checkpoint(state, workdir)