Add optional skip layer connections and rmsnorm
Some of the literature has found these helpful (e.g. D2RL for the skip
connections), so we might as well make them easy to try behind flags.
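
For reference, a rough standalone sketch of the resulting structure. This
is an illustrative flax module, not the actual model.py code; the class
name, field names, and shapes below are made up:

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    class SketchQFunction(nn.Module):
        """Toy Q network showing the two options added by this change."""
        hidden_sizes: tuple = (256, 256)
        skip_layer: bool = False
        rmsnorm: bool = False

        @nn.compact
        def __call__(self, observation, R, action):
            inputs = jnp.hstack((observation, R, action))
            x = inputs
            for i, hidden_size in enumerate(self.hidden_sizes):
                # D2RL-style skip connection: every hidden layer after the
                # first sees the raw inputs again, not just the previous
                # layer's activations.
                if self.skip_layer and i != 0:
                    x = jnp.hstack((x, inputs))
                x = nn.Dense(features=hidden_size, name=f'denselayer{i}')(x)
                # Either normalization is only there to stabilize training.
                if self.rmsnorm:
                    x = nn.RMSNorm(name=f'rmsnorm{i}')(x)
                else:
                    x = nn.LayerNorm(name=f'layernorm{i}')(x)
                x = nn.relu(x)
            return nn.Dense(features=1, name='q')(x)

    # Example usage with dummy shapes.
    q = SketchQFunction(skip_layer=True, rmsnorm=True)
    params = q.init(jax.random.PRNGKey(0), jnp.ones((1, 8)), jnp.ones((1, 3)),
                    jnp.ones((1, 4)))

With absl, both options end up as ordinary boolean flags on whatever binary
imports model.py (--skip_layer / --rmsnorm) and default to off.
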
Change-Id: If8e11fd9eec0576405e0cd59457b504181b6fe3e
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index acbac08..ebf47f7 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -73,6 +73,18 @@
     help='Batch size for learning Q and pi',
 )
 
+absl.flags.DEFINE_boolean(
+    'skip_layer',
+    default=False,
+    help='If true, add skip layer connections to the Q network.',
+)
+
+absl.flags.DEFINE_boolean(
+    'rmsnorm',
+    default=False,
+    help='If true, use rmsnorm instead of layer norm.',
+)
+
 HIDDEN_WEIGHTS = 256
 
 LOG_STD_MIN = -20
@@ -171,12 +183,20 @@
         # Estimate Q with a simple multi layer dense network.
         x = jax.numpy.hstack((observation, R, action))
         for i, hidden_size in enumerate(self.hidden_sizes):
+            # D2RL-style skip connection: also feed the inputs to this layer.
+            if FLAGS.skip_layer and i != 0:
+                x = jax.numpy.hstack((x, observation, R, action))
+
             x = nn.Dense(
                 name=f'denselayer{i}',
                 features=hidden_size,
             )(x)
-            # Layernorm also improves stability.
-            x = nn.LayerNorm(name=f'layernorm{i}')(x)
+
+            if FLAGS.rmsnorm:
+                x = nn.RMSNorm(name=f'rmsnorm{i}')(x)
+            else:
+                # Layernorm also improves stability.
+                x = nn.LayerNorm(name=f'layernorm{i}')(x)
             x = self.activation(x)
 
         x = nn.Dense(
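
For context, RMSNorm differs from LayerNorm only in that it skips the mean
subtraction and the learned bias, rescaling by the root mean square instead
of the standard deviation. A minimal sketch of the computation, ignoring the
learned scale (the epsilon value is an assumption):

    import jax.numpy as jnp

    def rms_norm(x, eps=1e-6):
        # Divide by the root mean square over the feature axis; unlike
        # LayerNorm there is no mean subtraction and no learned bias.
        mean_square = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        return x / jnp.sqrt(mean_square + eps)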