Add solver from dreamer v3

This apparently converges faster.  It is handy to have other options to
try, so let's add it.

Change-Id: If7eb1192cdf939308300a04fb06aa758e8eee3b2
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index 6aa3e47..f115bbb 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -92,6 +92,12 @@
     help='If true, use rmsnorm instead of layer norm.',
 )
 
+absl.flags.DEFINE_boolean(
+    'dreamer_solver',
+    default=False,
+    help='If true, use the solver from dreamer v3 instead of adam.',
+)
+
 HIDDEN_WEIGHTS = 256
 
 LOG_STD_MIN = -20
@@ -450,9 +456,14 @@
             'logalpha': logalpha,
         }
 
-    pi_tx = optax.adam(learning_rate=pi_learning_rate)
-    q_tx = optax.adam(learning_rate=q_learning_rate)
-    alpha_tx = optax.adam(learning_rate=alpha_learning_rate)
+    # Pick the optimizer factory once; both take a learning_rate kwarg.
+    if FLAGS.dreamer_solver:
+        make_solver = create_dreamer_solver
+    else:
+        make_solver = optax.adam
+    pi_tx = make_solver(learning_rate=pi_learning_rate)
+    q_tx = make_solver(learning_rate=q_learning_rate)
+    alpha_tx = make_solver(learning_rate=alpha_learning_rate)
 
     result = TrainState.create(
         problem=problem,
@@ -468,6 +479,89 @@
     return result
 
 
+# Solver from dreamer v3.
+# TODO(austin): How many of these pieces are actually in optax already?
+def scale_by_rms(beta=0.999, eps=1e-8):  # Scale by 1/sqrt(EMA of squared grads), like Adam's 2nd moment.
+
+    def init_fn(params):  # State: (step counter, per-param EMA of squared grads).
+        nu = jax.tree_util.tree_map(
+            lambda t: jax.numpy.zeros_like(t, jax.numpy.float32), params)
+        step = jax.numpy.zeros((), jax.numpy.int32)
+        return (step, nu)
+
+    def update_fn(updates, state, params=None):
+        step, nu = state
+        step = optax.safe_int32_increment(step)  # Saturating increment; no int32 wraparound.
+        nu = jax.tree_util.tree_map(
+            lambda v, u: beta * v + (1 - beta) * (u * u), nu, updates)
+        nu_hat = optax.bias_correction(nu, beta, step)  # Debias the EMA, as in Adam.
+        updates = jax.tree_util.tree_map(
+            lambda u, v: u / (jax.numpy.sqrt(v) + eps), updates, nu_hat)
+        return updates, (step, nu)
+
+    return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_agc(clip=0.03, pmin=1e-3):  # Adaptive gradient clipping: cap update norm relative to param norm.
+
+    def init_fn(params):
+        return ()  # Stateless transform.
+
+    def update_fn(updates, state, params=None):  # NOTE: params is required here despite the None default.
+
+        def fn(param, update):
+            unorm = jax.numpy.linalg.norm(update.flatten(), 2)
+            pnorm = jax.numpy.linalg.norm(param.flatten(), 2)
+            upper = clip * jax.numpy.maximum(pmin, pnorm)  # pmin gives near-zero params a floor.
+            return update * (1 / jax.numpy.maximum(1.0, unorm / upper))
+
+        updates = jax.tree_util.tree_map(fn, params, updates)
+        return updates, ()
+
+    return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_momentum(beta=0.9, nesterov=False):  # Bias-corrected first-moment EMA, optionally Nesterov.
+
+    def init_fn(params):  # State: (step counter, per-param momentum buffer).
+        mu = jax.tree_util.tree_map(
+            lambda t: jax.numpy.zeros_like(t, jax.numpy.float32), params)
+        step = jax.numpy.zeros((), jax.numpy.int32)
+        return (step, mu)
+
+    def update_fn(updates, state, params=None):
+        step, mu = state
+        step = optax.safe_int32_increment(step)
+        mu = optax.update_moment(updates, mu, beta, 1)  # EMA of the gradient.
+        if nesterov:
+            mu_nesterov = optax.update_moment(updates, mu, beta, 1)  # Look-ahead: blend the update in again.
+            mu_hat = optax.bias_correction(mu_nesterov, beta, step)
+        else:
+            mu_hat = optax.bias_correction(mu, beta, step)
+        return mu_hat, (step, mu)  # Emit debiased momentum; store the raw EMA.
+
+    return optax.GradientTransformation(init_fn, update_fn)
+
+
+def create_dreamer_solver(
+    learning_rate,
+    agc: float = 0.3,
+    pmin: float = 1e-3,
+    beta1: float = 0.9,
+    beta2: float = 0.999,
+    eps: float = 1e-20,
+    nesterov: bool = False,
+) -> optax.GradientTransformation:
+    # From dreamer v3: AGC -> RMS scaling -> momentum -> learning rate.
+    return optax.chain(
+        # Adaptive gradient clipping.
+        scale_by_agc(agc, pmin),
+        scale_by_rms(beta2, eps),
+        scale_by_momentum(beta1, nesterov),
+        optax.scale_by_learning_rate(learning_rate),
+    )
+
+
 def create_learning_rate_fn(
     base_learning_rate: float,
     final_learning_rate: float,