Speed up writing experience significantly

The JAX calculation that computes the reward needs to be JIT-compiled to be fast.
Wrap it with @jax.jit, and add some logging so the speedup is easy to measure.
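
For reference, here is a minimal sketch of the pattern. It is not the code in this
change. quadratic_cost is a hypothetical stand-in for mpc_cost, whose body is not
shown here. The timing uses Python's standard logging module, which is an
assumption about how the measurement could be done. The first jitted call
includes tracing and XLA compilation, so only later calls show the steady-state
speed.

```python
import logging
import time

import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)


def quadratic_cost(X, U, goal):
    # Hypothetical stand-in for a cost/reward function; NOT the real mpc_cost.
    return jnp.sum((X - goal) ** 2) + 0.1 * jnp.sum(U ** 2)


# jax.jit traces the function once per input shape/dtype and compiles it with XLA.
jitted_cost = jax.jit(quadratic_cost)


def timed(name, fn, *args):
    """Call fn, wait for the result, and log how long it took."""
    start = time.perf_counter()
    result = fn(*args)
    # JAX dispatches work asynchronously; block so the timing is meaningful.
    result.block_until_ready()
    elapsed = time.perf_counter() - start
    logging.info("%s took %.6f s", name, elapsed)
    return result, elapsed


X = jnp.ones(8)
U = jnp.zeros(4)
goal = jnp.zeros(8)

eager_result, _ = timed("eager", quadratic_cost, X, U, goal)
jit_result, _ = timed("jit (first call, includes compile)", jitted_cost, X, U, goal)
cached_result, _ = timed("jit (cached)", jitted_cost, X, U, goal)
```

Only the cached jitted call reflects the cost paid inside a hot loop. Functions
passed to jax.jit must be traceable, so Python control flow that branches on
array values would need jax.lax primitives instead.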

Change-Id: I9e9c53cf6ac0180e05819ffaefc5834784fb6e93
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 19a48de..58d5fcf 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -354,6 +354,7 @@
     ])
 
 
+@jax.jit
 def mpc_cost(coefficients: CoefficientsType, X, U, goal):
     J = 0