Add a parallelized experience collector using the MPC

This sets us up to collect experience from the MPC to use as training
data, which should help the learned policy converge to the actual MPC
solution.  We want to try RLPD (Reinforcement Learning with Prior Data)
instead of just SAC (Soft Actor-Critic).
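
A rough sketch of the intended consumer, for context (not part of this
change; the pickle keys match save_experience() below, while the file
count and buffer layout are assumptions):

    import pickle

    import numpy

    # Gather the transitions written by experience_collector.py into flat
    # arrays, ready to seed an offline replay buffer for RLPD.
    observations, actions, rewards = [], [], []
    for i in range(100):  # Matches the default --num_solutions.
        with open(f'experience_{i}.pkl', 'rb') as f:
            experience = pickle.load(f)
        observations.append(experience['observations1'])
        actions.append(experience['actions'])
        rewards.append(experience['rewards'])
    offline_data = {
        'observations': numpy.concatenate(observations),
        'actions': numpy.concatenate(actions),
        'rewards': numpy.concatenate(rewards),
    }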

Change-Id: Ia54e26b80f16e2e2a92284b152b09304b566d3b6
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 2c3ad50..7a8c1ea 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -292,6 +292,33 @@
     ],
 )
 
+py_binary(
+    name = "experience_collector",
+    srcs = [
+        "experience_collector.py",
+    ],
+    deps = [
+        ":casadi_velocity_mpc_lib",
+        ":jax_dynamics",
+        "//frc971/control_loops/swerve/velocity_controller:physics",
+        "@pip//absl_py",
+        "@pip//matplotlib",
+        "@pip//numpy",
+        "@pip//pygobject",
+        "@pip//scipy",
+        "@pip//tensorflow",
+    ],
+)
+
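+# Typical invocation of the parallel collector (the flags below are defined
+# in the scripts themselves):
+#   bazel run //frc971/control_loops/swerve:multi_experience_collector -- \
+#       --outdir=/tmp/swerve --num_actors=20 --num_solutions=100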
+py_binary(
+    name = "multi_experience_collector",
+    srcs = ["multi_experience_collector.py"],
+    data = [":experience_collector"],
+    deps = [
+        "@pip//absl_py",
+    ],
+)
+
 py_library(
     name = "physics_test_utils",
     srcs = [
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc.py b/frc971/control_loops/swerve/casadi_velocity_mpc.py
index d62f7ed..54e33f1 100644
--- a/frc971/control_loops/swerve/casadi_velocity_mpc.py
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc.py
@@ -24,11 +24,6 @@
 flags.DEFINE_bool('pickle', False, 'Write optimization results.')
 flags.DEFINE_string('outputdir', None, 'Directory to write problem results to')
 
-# Full print level on ipopt. Piping to a file and using a filter or search method is suggested
-# grad_x prints out the gradient at each iteration in the following sequence: U0, X1, U1, etc.
-flags.DEFINE_bool('full_debug', False,
-                  'If true, turn on all the debugging in the solver.')
-
 
 class Solver(object):
 
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
index f464268..e358422 100644
--- a/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
@@ -3,11 +3,19 @@
 from frc971.control_loops.swerve import dynamics
 import casadi
 import numpy
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+# Enables the full print level in ipopt.  Piping the output to a file and
+# filtering or searching it is suggested.  grad_x prints the gradient at each
+# iteration in the sequence: U0, X1, U1, etc.
+flags.DEFINE_bool('full_debug', False,
+                  'If true, turn on all the debugging in the solver.')
 
 
 class MPC(object):
 
-    def __init__(self, solver='fatrop', jit=True):
+    def __init__(self, solver='fatrop', jit=True, N=200):
         self.fdot = dynamics.swerve_full_dynamics(
             casadi.SX.sym("X", dynamics.NUM_STATES, 1),
             casadi.SX.sym("U", 8, 1))
@@ -47,7 +55,7 @@
         self.next_X = self.make_physics()
         self.cost = self.make_cost()
 
-        self.N = 200
+        self.N = N
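+        # The solver only produces controls for the first N - 1 points in
+        # the horizon, so N - 1 transitions come back per solve.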
 
         # Start with an empty nonlinear program.
         self.w = []
diff --git a/frc971/control_loops/swerve/experience_collector.py b/frc971/control_loops/swerve/experience_collector.py
new file mode 100644
index 0000000..591cba8
--- /dev/null
+++ b/frc971/control_loops/swerve/experience_collector.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+import os
+import sys
+
+# Setup XLA first.
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach XLA where to find CUDA.
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Uncomment to force a fixed number of host platform devices.
+    #'--xla_force_host_platform_device_count=6',
+    # Uncomment to dump XLA to /tmp/foo to aid debugging, optionally
+    # restricting the dump to the sharding passes.
+    #'--xla_dump_to=/tmp/foo',
+    #'--xla_dump_hlo_pass_re=spmd|propagation',
+    #'--xla_gpu_enable_command_buffer=',
+])
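+# Run JAX on the CPU and keep its GPU memory reservation small; many
+# collector processes may run in parallel.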
+os.environ['JAX_PLATFORMS'] = 'cpu'
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
+
+from absl import app
+from absl import flags
+from absl import logging
+import pickle
+import numpy
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve.casadi_velocity_mpc_lib import MPC
+import jax
+import tensorflow as tf
+from frc971.control_loops.swerve.velocity_controller.physics import SwerveProblem
+from frc971.control_loops.swerve import jax_dynamics
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_bool('compileonly', False,
+                  'If true, compile the MPC solver and exit without solving.')
+
+flags.DEFINE_float('vx', 1.0, 'Goal velocity in m/s in x')
+flags.DEFINE_float('vy', 0.0, 'Goal velocity in m/s in y')
+flags.DEFINE_float('omega', 0.0, 'Goal angular velocity in rad/s')
+flags.DEFINE_integer('seed', 0, 'Seed for random initial state.')
+
+flags.DEFINE_bool('save_plots', True,
+                  'If true, save plots for each run as well.')
+flags.DEFINE_string('outputdir', None, 'Directory to write problem results to')
+flags.DEFINE_bool('quiet', False, 'If true, print a lot less')
+
+flags.DEFINE_integer('num_solutions', 100,
+                     'Number of random problems to solve.')
+flags.DEFINE_integer('horizon', 200, 'Horizon to solve for')
+
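+# matplotlib is only needed when saving plots; tolerate it being missing so
+# the collector can run headless with --nosave_plots.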
+try:
+    from matplotlib import pylab
+except ModuleNotFoundError:
+    pass
+
+
+def collect_experience(problem, mpc, rng):
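+    """Solves one MPC problem from a random initial state.
+
+    Returns a dict of N - 1 transitions: observations1/observations2 hold
+    x(t)/x(t+1), along with the action, reward, and goal for each step.
+    """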
+    X_initial = numpy.array(problem.random_states(rng,
+                                                  dimensions=1)).transpose()
+
+    R_goal = numpy.zeros((3, 1))
+    R_goal[0, 0] = FLAGS.vx
+    R_goal[1, 0] = FLAGS.vy
+    R_goal[2, 0] = FLAGS.omega
+
+    solution = mpc.solve(p=numpy.vstack((X_initial, R_goal)))
+
+    # The solver doesn't produce a control for the last state, so we get
+    # N - 1 (state, action, next state) transitions back.
+    experience = {
+        'observations1': numpy.zeros((mpc.N - 1, problem.num_states)),
+        'observations2': numpy.zeros((mpc.N - 1, problem.num_states)),
+        'actions': numpy.zeros((mpc.N - 1, problem.num_outputs)),
+        'rewards': numpy.zeros((mpc.N - 1, 1)),
+        'goals': numpy.zeros((mpc.N - 1, problem.num_goals)),
+    }
+
+    if not FLAGS.quiet:
+        print('x(0):', X_initial.transpose())
+
+    X_prior = X_initial.squeeze()
+    for j in range(mpc.N - 1):
+        if not FLAGS.quiet:
+            print(f'u({j}): ', mpc.unpack_u(solution, j))
+            print(f'x({j+1}): ', mpc.unpack_x(solution, j + 1))
+        experience['observations1'][j, :] = X_prior
+        X_prior = mpc.unpack_x(solution, j + 1)
+        experience['observations2'][j, :] = X_prior
+        experience['actions'][j, :] = mpc.unpack_u(solution, j)
+        experience['rewards'][j, :] = problem.reward(
+            X=X_prior,
+            U=mpc.unpack_u(solution, j),
+            goal=R_goal[:, 0],
+        )
+        experience['goals'][j, :] = R_goal[:, 0]
+        sys.stderr.flush()
+        sys.stdout.flush()
+
+    return experience
+
+
+def save_experience(problem, mpc, experience, experience_number):
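+    """Pickles one experience dict and optionally writes state/steer plots."""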
+    with open(f'experience_{experience_number}.pkl', 'wb') as f:
+        pickle.dump(experience, f)
+
+    if not FLAGS.save_plots:
+        return
+
+    fig0, axs0 = pylab.subplots(3)
+    fig1, axs1 = pylab.subplots(2)
+
+    axs0[0].clear()
+    axs0[1].clear()
+    axs0[2].clear()
+
+    t = problem.dt * numpy.arange(mpc.N - 1)
+
+    X_plot = experience['observations1']
+    U_plot = experience['actions']
+
+    axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_VX], label="vx")
+    axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_VY], label="vy")
+    axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_OMEGA], label="omega")
+    axs0[0].legend()
+
+    axs0[1].plot(t, U_plot[:, 0], label="Is0")
+    axs0[1].plot(t, U_plot[:, 2], label="Is1")
+    axs0[1].plot(t, U_plot[:, 4], label="Is2")
+    axs0[1].plot(t, U_plot[:, 6], label="Is3")
+    axs0[1].legend()
+
+    axs0[2].plot(t, U_plot[:, 1], label="Id0")
+    axs0[2].plot(t, U_plot[:, 3], label="Id1")
+    axs0[2].plot(t, U_plot[:, 5], label="Id2")
+    axs0[2].plot(t, U_plot[:, 7], label="Id3")
+    axs0[2].legend()
+
+    axs1[0].clear()
+    axs1[1].clear()
+
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS0], label='steer0')
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS1], label='steer1')
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS2], label='steer2')
+    axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS3], label='steer3')
+    axs1[0].legend()
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS0],
+                 label='steer_velocity0')
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS1],
+                 label='steer_velocity1')
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS2],
+                 label='steer_velocity2')
+    axs1[1].plot(t,
+                 X_plot[:, dynamics.VELOCITY_STATE_OMEGAS3],
+                 label='steer_velocity3')
+    axs1[1].legend()
+
+    fig0.savefig(f'state_{experience_number}.svg')
+    fig1.savefig(f'steer_{experience_number}.svg')
+
+
+def main(argv):
+    if FLAGS.outputdir:
+        os.chdir(FLAGS.outputdir)
+
+    # Hide any GPUs from TensorFlow. Otherwise it might reserve memory.
+    tf.config.experimental.set_visible_devices([], 'GPU')
+    rng = jax.random.key(FLAGS.seed)
+
+    physics_constants = jax_dynamics.Coefficients()
+    problem = SwerveProblem(physics_constants)
+    mpc = MPC(solver='ipopt', N=(FLAGS.horizon + 1))
+
+    if FLAGS.compileonly:
+        return
+
+    for i in range(FLAGS.num_solutions):
+        rng, rng_init = jax.random.split(rng)
+        experience = collect_experience(problem, mpc, rng_init)
+
+        save_experience(problem, mpc, experience, i)
+        logging.info('Solved problem %d', i)
+
+
+if __name__ == '__main__':
+    app.run(main)
diff --git a/frc971/control_loops/swerve/multi_experience_collector.py b/frc971/control_loops/swerve/multi_experience_collector.py
new file mode 100644
index 0000000..d0551e1
--- /dev/null
+++ b/frc971/control_loops/swerve/multi_experience_collector.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+from absl import app
+from absl import flags
+import sys
+from multiprocessing.pool import ThreadPool
+import pathlib
+import subprocess
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('outdir', '/tmp/swerve', "Directory to write results to.")
+flags.DEFINE_integer('num_actors', 20, 'Number of actors to run in parallel.')
+flags.DEFINE_integer('num_solutions', 100,
+                     'Number of random problems to solve.')
+
+
+def collect_experience(agent_number):
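+    """Runs one experience_collector subprocess, logging to its own subdir."""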
+    filename = f'{agent_number}'
+    if FLAGS.outdir:
+        subdir = pathlib.Path(FLAGS.outdir) / filename
+    else:
+        subdir = pathlib.Path(filename)
+    subdir.mkdir(parents=True, exist_ok=True)
+
+    with open(f'{subdir.resolve()}/log', 'w') as output:
+        subprocess.check_call(
+            args=[
+                sys.executable,
+                "frc971/control_loops/swerve/experience_collector",
+                f"--seed={agent_number}",
+                f"--outputdir={subdir.resolve()}",
+                "--quiet",
+                f"--num_solutions={FLAGS.num_solutions}",
+            ],
+            stdout=output,
+            stderr=output,
+        )
+
+
+def main(argv):
+    # Run a single --compileonly pass first so the solver compiles while the
+    # system is lightly loaded.  That is faster on a processor with frequency
+    # boosting.
+    subprocess.check_call(args=[
+        sys.executable,
+        "frc971/control_loops/swerve/experience_collector",
+        "--compileonly",
+    ])
+
+    # Now collect experience from a bunch of random seeds in parallel.
+    with ThreadPool(FLAGS.num_actors) as pool:
+        pool.map(collect_experience, range(FLAGS.num_actors))
+
+
+if __name__ == '__main__':
+    app.run(main)