Speed up writing experience significantly
The jax calculation to compute the reward needs to be JITed to be fast.
Add that, and add some logging to make it easy to measure.
Change-Id: I9e9c53cf6ac0180e05819ffaefc5834784fb6e93
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/experience_collector.py b/frc971/control_loops/swerve/experience_collector.py
index 591cba8..66b425f 100644
--- a/frc971/control_loops/swerve/experience_collector.py
+++ b/frc971/control_loops/swerve/experience_collector.py
@@ -20,6 +20,7 @@
from absl import flags
from absl import app
+from absl import logging
import pickle
import numpy
from frc971.control_loops.swerve import dynamics
@@ -64,6 +65,8 @@
R_goal[2, 0] = FLAGS.omega
solution = mpc.solve(p=numpy.vstack((X_initial, R_goal)))
+ sys.stderr.flush()
+ sys.stdout.flush()
# Solver doesn't solve for the last state. So we get N-1 states back.
experience = {
@@ -77,6 +80,7 @@
if not FLAGS.quiet:
print('x(0):', X_initial.transpose())
+ logging.info('Finished solving')
X_prior = X_initial.squeeze()
for j in range(mpc.N - 1):
if not FLAGS.quiet:
@@ -86,14 +90,17 @@
X_prior = mpc.unpack_x(solution, j + 1)
experience['observations2'][j, :] = X_prior
experience['actions'][j, :] = mpc.unpack_u(solution, j)
+ experience['goals'][j, :] = R_goal[:, 0]
+
+ logging.info('Finished all but reward')
+ for j in range(mpc.N - 1):
experience['rewards'][j, :] = problem.reward(
X=X_prior,
U=mpc.unpack_u(solution, j),
goal=R_goal[:, 0],
)
- experience['goals'][j, :] = R_goal[:, 0]
- sys.stderr.flush()
- sys.stdout.flush()
+ sys.stderr.flush()
+ sys.stdout.flush()
return experience
@@ -178,9 +185,10 @@
for i in range(FLAGS.num_solutions):
rng, rng_init = jax.random.split(rng)
experience = collect_experience(problem, mpc, rng_init)
+ logging.info('Solved problem %d', i)
save_experience(problem, mpc, experience, i)
- logging.info('Solved problem %d', i)
+ logging.info('Wrote problem %d', i)
if __name__ == '__main__':
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 19a48de..58d5fcf 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -354,6 +354,7 @@
])
+@jax.jit
def mpc_cost(coefficients: CoefficientsType, X, U, goal):
J = 0