Add solver from dreamer v3
This apparently converges faster than adam. It is handy to have other options
to try, so let's add it behind the --dreamer_solver flag (off by default).
Change-Id: If7eb1192cdf939308300a04fb06aa758e8eee3b2
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index 6aa3e47..f115bbb 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -92,6 +92,12 @@
help='If true, use rmsnorm instead of layer norm.',
)
+absl.flags.DEFINE_boolean(
+ 'dreamer_solver',
+ default=False,
+ help='If true, use the solver from dreamer v3 instead of adam.',
+)
+
HIDDEN_WEIGHTS = 256
LOG_STD_MIN = -20
@@ -450,9 +456,14 @@
'logalpha': logalpha,
}
- pi_tx = optax.adam(learning_rate=pi_learning_rate)
- q_tx = optax.adam(learning_rate=q_learning_rate)
- alpha_tx = optax.adam(learning_rate=alpha_learning_rate)
+ if FLAGS.dreamer_solver:
+ pi_tx = create_dreamer_solver(learning_rate=pi_learning_rate)
+ q_tx = create_dreamer_solver(learning_rate=q_learning_rate)
+ alpha_tx = create_dreamer_solver(learning_rate=alpha_learning_rate)
+ else:
+ pi_tx = optax.adam(learning_rate=pi_learning_rate)
+ q_tx = optax.adam(learning_rate=q_learning_rate)
+ alpha_tx = optax.adam(learning_rate=alpha_learning_rate)
result = TrainState.create(
problem=problem,
@@ -468,6 +479,89 @@
return result
+# Solver from dreamer v3: adaptive gradient clipping, RMS scaling, momentum, then learning rate.
+# TODO(austin): How many of these pieces are actually in optax already?
+def scale_by_rms(beta=0.999, eps=1e-8):
+
+ def init_fn(params):
+ nu = jax.tree_util.tree_map(
+ lambda t: jax.numpy.zeros_like(t, jax.numpy.float32), params)
+ step = jax.numpy.zeros((), jax.numpy.int32)
+ return (step, nu)
+
+ def update_fn(updates, state, params=None):
+ step, nu = state
+ step = optax.safe_int32_increment(step)
+ nu = jax.tree_util.tree_map(
+ lambda v, u: beta * v + (1 - beta) * (u * u), nu, updates)
+ nu_hat = optax.bias_correction(nu, beta, step)
+ updates = jax.tree_util.tree_map(
+ lambda u, v: u / (jax.numpy.sqrt(v) + eps), updates, nu_hat)
+ return updates, (step, nu)
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_agc(clip=0.03, pmin=1e-3):
+
+ def init_fn(params):
+ return ()
+
+ def update_fn(updates, state, params=None):
+
+ def fn(param, update):
+ unorm = jax.numpy.linalg.norm(update.flatten(), 2)
+ pnorm = jax.numpy.linalg.norm(param.flatten(), 2)
+ upper = clip * jax.numpy.maximum(pmin, pnorm)
+ return update * (1 / jax.numpy.maximum(1.0, unorm / upper))
+
+ updates = jax.tree_util.tree_map(fn, params, updates)
+ return updates, ()
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_momentum(beta=0.9, nesterov=False):
+
+ def init_fn(params):
+ mu = jax.tree_util.tree_map(
+ lambda t: jax.numpy.zeros_like(t, jax.numpy.float32), params)
+ step = jax.numpy.zeros((), jax.numpy.int32)
+ return (step, mu)
+
+ def update_fn(updates, state, params=None):
+ step, mu = state
+ step = optax.safe_int32_increment(step)
+ mu = optax.update_moment(updates, mu, beta, 1)
+ if nesterov:
+ mu_nesterov = optax.update_moment(updates, mu, beta, 1)
+ mu_hat = optax.bias_correction(mu_nesterov, beta, step)
+ else:
+ mu_hat = optax.bias_correction(mu, beta, step)
+ return mu_hat, (step, mu)
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+
+def create_dreamer_solver(
+ learning_rate,
+ agc: float = 0.3,
+ pmin: float = 1e-3,
+ beta1: float = 0.9,
+ beta2: float = 0.999,
+ eps: float = 1e-20,
+ nesterov: bool = False,
+) -> optax.base.GradientTransformation:
+ # From dreamer v3.
+ return optax.chain(
+ # Adaptive gradient clipping.
+ scale_by_agc(agc, pmin),
+ scale_by_rms(beta2, eps),
+ scale_by_momentum(beta1, nesterov),
+ optax.scale_by_learning_rate(learning_rate),
+ )
+
+
def create_learning_rate_fn(
base_learning_rate: float,
final_learning_rate: float,