Speed up writing experience significantly

The JAX calculation that computes the reward needs to be JIT-compiled to be
fast.  Decorate mpc_cost with @jax.jit, and add some logging to make the time
spent in each stage easy to measure.

Change-Id: I9e9c53cf6ac0180e05819ffaefc5834784fb6e93
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/experience_collector.py b/frc971/control_loops/swerve/experience_collector.py
index 591cba8..66b425f 100644
--- a/frc971/control_loops/swerve/experience_collector.py
+++ b/frc971/control_loops/swerve/experience_collector.py
@@ -20,6 +20,7 @@
 
 from absl import flags
 from absl import app
+from absl import logging
 import pickle
 import numpy
 from frc971.control_loops.swerve import dynamics
@@ -64,6 +65,8 @@
     R_goal[2, 0] = FLAGS.omega
 
     solution = mpc.solve(p=numpy.vstack((X_initial, R_goal)))
+    sys.stderr.flush()
+    sys.stdout.flush()
 
     # Solver doesn't solve for the last state.  So we get N-1 states back.
     experience = {
@@ -77,6 +80,7 @@
     if not FLAGS.quiet:
         print('x(0):', X_initial.transpose())
 
+    logging.info('Finished solving')
     X_prior = X_initial.squeeze()
     for j in range(mpc.N - 1):
         if not FLAGS.quiet:
@@ -86,14 +90,17 @@
         X_prior = mpc.unpack_x(solution, j + 1)
         experience['observations2'][j, :] = X_prior
         experience['actions'][j, :] = mpc.unpack_u(solution, j)
+        experience['goals'][j, :] = R_goal[:, 0]
+
+    logging.info('Finished all but reward')
+    for j in range(mpc.N - 1):
         experience['rewards'][j, :] = problem.reward(
-            X=X_prior,
+            X=mpc.unpack_x(solution, j + 1),  # per-step state; X_prior is stale here
             U=mpc.unpack_u(solution, j),
             goal=R_goal[:, 0],
         )
-        experience['goals'][j, :] = R_goal[:, 0]
-        sys.stderr.flush()
-        sys.stdout.flush()
+    sys.stderr.flush()
+    sys.stdout.flush()
 
     return experience
 
@@ -178,9 +185,10 @@
     for i in range(FLAGS.num_solutions):
         rng, rng_init = jax.random.split(rng)
         experience = collect_experience(problem, mpc, rng_init)
+        logging.info('Solved problem %d', i)
 
         save_experience(problem, mpc, experience, i)
-        logging.info('Solved problem %d', i)
+        logging.info('Wrote problem %d', i)
 
 
 if __name__ == '__main__':
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 19a48de..58d5fcf 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -354,6 +354,7 @@
     ])
 
 
+@jax.jit
 def mpc_cost(coefficients: CoefficientsType, X, U, goal):
     J = 0