Add papers we used to the code
Signed-off-by: justinT21 <jjturcot@gmail.com>
Change-Id: I9a8e5a700f411f0100071161f2fc4491faeb7ad0
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index b8ec123..7f01eef 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -202,7 +202,8 @@
# Estimate Q with a simple multi layer dense network.
x = jax.numpy.hstack((observation, R, action))
for i, hidden_size in enumerate(self.hidden_sizes):
- # Add d2rl skip layer connections if requested
+ # Add D2RL skip layer connections if requested.
+ # Idea from D2RL: https://arxiv.org/pdf/2010.09163.
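+ # D2RL concatenates the raw network inputs (observation, R, action) back onto every hidden layer's input so deeper critics keep direct access to the state.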
if FLAGS.skip_layer and i != 0:
x = jax.numpy.hstack((x, observation, R, action))
@@ -212,9 +213,11 @@
)(x)
if FLAGS.rmsnorm:
+ # Idea from DreamerV3: https://arxiv.org/pdf/2301.04104v2.
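+ # RMSNorm rescales by the root-mean-square of the activations without mean-centering, so it is slightly cheaper than LayerNorm.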
x = nn.RMSNorm(name=f'rmsnorm{i}')(x)
else:
# Layernorm also improves stability.
+ # Idea from RLPD: https://arxiv.org/pdf/2302.02948.
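+ # RLPD found that normalizing inside the critic helps keep the Q values from diverging.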
x = nn.LayerNorm(name=f'layernorm{i}')(x)
x = self.activation(x)
@@ -425,6 +428,7 @@
action_space=problem.num_outputs,
action_limit=problem.action_limit)
# We want q1 and q2 to have different network architectures so they pick up different things.
+ # SiLU is used in DreamerV3, so we use it here: https://arxiv.org/pdf/2301.04104v2.
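+ # SiLU(x) = x * sigmoid(x), also known as swish.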
q1 = MLPQFunction(activation=nn.activation.silu, hidden_sizes=[128, 256])
q2 = MLPQFunction(activation=nn.activation.silu, hidden_sizes=[256, 128])
@@ -485,7 +489,7 @@
return result
-# Solver from dreamer v3.
+# Solver from DreamerV3: https://arxiv.org/pdf/2301.04104v2.
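+# Judging by the name and defaults, this presumably scales each update by 1 / sqrt(EMA of squared gradients), similar to optax.scale_by_rms.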
# TODO(austin): How many of these pieces are actually in optax already?
def scale_by_rms(beta=0.999, eps=1e-8):
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 9c5abe9..34d904a 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -1,3 +1,6 @@
+# Machine learning based on Soft Actor-Critic (SAC), which was initially proposed in https://arxiv.org/pdf/1801.01290.
+# Our implementation is heavily based on OpenAI's Spinning Up reference implementation: https://spinningup.openai.com/en/latest/algorithms/sac.html.
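+# SAC is an off-policy actor-critic algorithm which maximizes a trade-off between expected return and the entropy of the policy.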
+
import absl
import time
import collections
@@ -149,6 +152,7 @@
alpha = jax.numpy.exp(params['logalpha'])
# Now we can compute the Bellman backup
+ # Maximum-entropy SAC is based on https://arxiv.org/pdf/1812.05905.
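+ # i.e. y = r + gamma * (Q_target(o', a') - alpha * log pi(a' | o')), where the -alpha * log pi term is the entropy bonus.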
if FLAGS.maximum_entropy_q:
bellman_backup = jax.lax.stop_gradient(
rewards + FLAGS.gamma * (q_pi_target - alpha * logp_pi2))
@@ -340,9 +344,6 @@
lambda o, pi: state.problem.integrate_dynamics(o, pi),
in_axes=(0, 0))(observation, pi_action)
- # Soft Actor-Critic is designed to maximize reward. LQR minimizes
- # cost. There is nothing which assumes anything about the sign of
- # the reward, so use the negative of the cost.
reward = jax.vmap(state.problem.reward)(X=observation2,
U=pi_action,
goal=R)