Fix goal from data not being passed to pi_apply in alpha loss gradient
Signed-off-by: justinT21 <jjturcot@gmail.com>
Change-Id: I6a66c96b347f1cd23a04d66bac0f59a0ee12db36
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index b1ad9a8..cf54bc4 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -216,8 +216,9 @@
data: ArrayLike):
"""Computes the Soft Actor-Critic loss for alpha."""
observations1 = data['observations1']
+ R = data['goals']
pi, logp_pi, _, _ = jax.lax.stop_gradient(
- state.pi_apply(rng=rng, params=params, observation=observations1))
+ state.pi_apply(rng=rng, params=params, R=R, observation=observations1))
return (-jax.numpy.exp(params['logalpha']) *
(logp_pi + state.target_entropy)).mean()