Add standard deviation plot to lqr_plot
This helps visualize where it is certain vs uncertain
Change-Id: Ib1d9d506a4cb6e3fb59aaf5aa2ed0fc653673c42
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 01ebf47..a63e100 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -47,9 +47,26 @@
# Container for the data.
Data = collections.namedtuple('Data', [
- 't', 'X', 'X_lqr', 'U', 'U_lqr', 'cost', 'cost_lqr', 'q1_grid', 'q2_grid',
- 'q_grid', 'target_q_grid', 'lqr_grid', 'pi_grid_U', 'lqr_grid_U', 'grid_X',
- 'grid_Y', 'reward', 'reward_lqr', 'step'
+ 't',
+ 'X',
+ 'X_lqr',
+ 'U',
+ 'U_lqr',
+ 'cost',
+ 'cost_lqr',
+ 'q1_grid',
+ 'q2_grid',
+ 'q_grid',
+ 'target_q_grid',
+ 'lqr_grid',
+ 'pi_grid_U',
+ 'lqr_grid_U',
+ 'grid_X',
+ 'grid_Y',
+ 'reward',
+ 'reward_lqr',
+ 'step',
+ 'stdev',
])
FLAGS = absl.flags.FLAGS
@@ -168,6 +185,17 @@
deterministic=True)
return U[0] * problem.action_limit
+ def compute_pi_stdev(X, Y):
+ x = jax.numpy.array([X, Y])
+ _, _, std = state.pi_apply(rng,
+ state.params,
+ observation=state.problem.unwrap_angles(x),
+ R=goal,
+ deterministic=True)
+ return std[0]
+
+ std_grid = jax.vmap(jax.vmap(compute_pi_stdev))(grid_X, grid_Y)
+
lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
@@ -251,6 +279,7 @@
reward=0.0,
reward_lqr=0.0,
step=state.step,
+ stdev=std_grid,
), X, X_lqr, state.params)
logging.info('Finished integrating, reward of %f, lqr reward of %f',
@@ -274,6 +303,7 @@
lqr_grid_U=numpy.array(data.lqr_grid_U),
grid_X=numpy.array(data.grid_X),
grid_Y=numpy.array(data.grid_Y),
+ stdev=numpy.array(data.stdev),
reward=float(data.reward),
reward_lqr=float(data.reward_lqr),
step=data.step,
@@ -326,9 +356,10 @@
self.Ufig = pyplot.figure(figsize=pyplot.figaspect(0.5))
self.Uax = [
- self.Ufig.add_subplot(1, 3, 1, projection='3d'),
- self.Ufig.add_subplot(1, 3, 2, projection='3d'),
- self.Ufig.add_subplot(1, 3, 3, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 1, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 2, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 3, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 4, projection='3d'),
]
self.last_trajectory_step = 0
@@ -419,6 +450,7 @@
(data.lqr_grid_U, 'lqr'),
(data.pi_grid_U, 'pi'),
((data.lqr_grid_U - data.pi_grid_U), 'error'),
+ (data.stdev, 'stdev'),
]
self.Usurf = [