Add standard deviation plot to lqr_plot

This helps visualize where the policy (pi) is certain vs. uncertain.

Change-Id: Ib1d9d506a4cb6e3fb59aaf5aa2ed0fc653673c42
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 01ebf47..a63e100 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -47,9 +47,26 @@
 
 # Container for the data.
 Data = collections.namedtuple('Data', [
-    't', 'X', 'X_lqr', 'U', 'U_lqr', 'cost', 'cost_lqr', 'q1_grid', 'q2_grid',
-    'q_grid', 'target_q_grid', 'lqr_grid', 'pi_grid_U', 'lqr_grid_U', 'grid_X',
-    'grid_Y', 'reward', 'reward_lqr', 'step'
+    't',
+    'X',
+    'X_lqr',
+    'U',
+    'U_lqr',
+    'cost',
+    'cost_lqr',
+    'q1_grid',
+    'q2_grid',
+    'q_grid',
+    'target_q_grid',
+    'lqr_grid',
+    'pi_grid_U',
+    'lqr_grid_U',
+    'grid_X',
+    'grid_Y',
+    'reward',
+    'reward_lqr',
+    'step',
+    'stdev',
 ])
 
 FLAGS = absl.flags.FLAGS
@@ -168,6 +185,17 @@
                                  deterministic=True)
         return U[0] * problem.action_limit
 
+    def compute_pi_stdev(X, Y):
+        x = jax.numpy.array([X, Y])
+        _, _, std = state.pi_apply(rng,
+                                   state.params,
+                                   observation=state.problem.unwrap_angles(x),
+                                   R=goal,
+                                   deterministic=True)
+        return std[0]
+
+    std_grid = jax.vmap(jax.vmap(compute_pi_stdev))(grid_X, grid_Y)
+
     lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
     pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
 
@@ -251,6 +279,7 @@
             reward=0.0,
             reward_lqr=0.0,
             step=state.step,
+            stdev=std_grid,
         ), X, X_lqr, state.params)
 
     logging.info('Finished integrating, reward of %f, lqr reward of %f',
@@ -274,6 +303,7 @@
         lqr_grid_U=numpy.array(data.lqr_grid_U),
         grid_X=numpy.array(data.grid_X),
         grid_Y=numpy.array(data.grid_Y),
+        stdev=numpy.array(data.stdev),
         reward=float(data.reward),
         reward_lqr=float(data.reward_lqr),
         step=data.step,
@@ -326,9 +356,10 @@
 
         self.Ufig = pyplot.figure(figsize=pyplot.figaspect(0.5))
         self.Uax = [
-            self.Ufig.add_subplot(1, 3, 1, projection='3d'),
-            self.Ufig.add_subplot(1, 3, 2, projection='3d'),
-            self.Ufig.add_subplot(1, 3, 3, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 1, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 2, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 3, projection='3d'),
+            self.Ufig.add_subplot(2, 2, 4, projection='3d'),
         ]
 
         self.last_trajectory_step = 0
@@ -419,6 +450,7 @@
             (data.lqr_grid_U, 'lqr'),
             (data.pi_grid_U, 'pi'),
             ((data.lqr_grid_U - data.pi_grid_U), 'error'),
+            (data.stdev, 'stdev'),
         ]
 
         self.Usurf = [