Add optional skip layer connections and rmsnorm

Some literature has found these helpful, so we might as well make them easy
to try.
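
For reference, a rough standalone sketch of what the two options do. The
module name SkipMLP, the use_skip/use_rmsnorm fields, and the toy shapes
below are illustrative only, not the real flags or network in model.py
(nn.RMSNorm needs a reasonably recent Flax release):

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    class SkipMLP(nn.Module):
        hidden_sizes: tuple = (256, 256)
        use_skip: bool = False
        use_rmsnorm: bool = False

        @nn.compact
        def __call__(self, inputs):
            x = inputs
            for i, hidden_size in enumerate(self.hidden_sizes):
                # D2RL-style skip: re-concatenate the raw inputs at every
                # hidden layer after the first.
                if self.use_skip and i != 0:
                    x = jnp.hstack((x, inputs))
                x = nn.Dense(features=hidden_size, name=f'denselayer{i}')(x)
                # RMSNorm rescales by the root-mean-square of the activations
                # only; LayerNorm also subtracts the mean.
                if self.use_rmsnorm:
                    x = nn.RMSNorm(name=f'rmsnorm{i}')(x)
                else:
                    x = nn.LayerNorm(name=f'layernorm{i}')(x)
                x = nn.relu(x)
            return nn.Dense(features=1, name='output')(x)

    params = SkipMLP(use_skip=True, use_rmsnorm=True).init(
        jax.random.PRNGKey(0), jnp.ones((1, 8)))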

Change-Id: If8e11fd9eec0576405e0cd59457b504181b6fe3e
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index acbac08..ebf47f7 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -73,6 +73,18 @@
     help='Batch size for learning Q and pi',
 )
 
+absl.flags.DEFINE_boolean(
+    'skip_layer',
+    default=False,
+    help='If true, add skip layer connections to the Q network.',
+)
+
+absl.flags.DEFINE_boolean(
+    'rmsnorm',
+    default=False,
+    help='If true, use rmsnorm instead of layer norm.',
+)
+
 HIDDEN_WEIGHTS = 256
 
 LOG_STD_MIN = -20
@@ -171,12 +183,20 @@
         # Estimate Q with a simple multi layer dense network.
         x = jax.numpy.hstack((observation, R, action))
         for i, hidden_size in enumerate(self.hidden_sizes):
+            # Add D2RL-style skip layer connections if requested.
+            if FLAGS.skip_layer and i != 0:
+                x = jax.numpy.hstack((x, observation, R, action))
+
             x = nn.Dense(
                 name=f'denselayer{i}',
                 features=hidden_size,
             )(x)
-            # Layernorm also improves stability.
-            x = nn.LayerNorm(name=f'layernorm{i}')(x)
+
+            if FLAGS.rmsnorm:
+                x = nn.RMSNorm(name=f'rmsnorm{i}')(x)
+            else:
+                # Layernorm also improves stability.
+                x = nn.LayerNorm(name=f'layernorm{i}')(x)
             x = self.activation(x)
 
         x = nn.Dense(