Add code to train a simple turret controller
This uses Soft Actor-Critic with automatic temperature adjustment to
train a controller for last year's turret. It is a stepping stone
towards training a more complicated controller for the swerve. I still
need to parallelize it, tune the hyper-parameters, and teach it to
track a nonzero goal.
Change-Id: I1357b5fbf8549acac4ee0b94ef8f2636867c28ad
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/physics_test.py b/frc971/control_loops/swerve/physics_test.py
index 6e7ddf4..a45bf46 100644
--- a/frc971/control_loops/swerve/physics_test.py
+++ b/frc971/control_loops/swerve/physics_test.py
@@ -59,6 +59,7 @@
Xdot = self.position_swerve_full_dynamics(X, U)
Xdot_jax = jax_dynamics.full_dynamics(self.coefficients, X[:, 0], U[:,
0])
+ self.assertEqual(Xdot.shape[0], Xdot_jax.shape[0])
self.assertLess(
numpy.linalg.norm(Xdot[:, 0] - Xdot_jax),
@@ -71,6 +72,9 @@
velocity_physics_jax = jax_dynamics.velocity_dynamics(
self.coefficients, X_velocity[:, 0], U[:, 0])
+ self.assertEqual(velocity_physics.shape[0],
+ velocity_physics_jax.shape[0])
+
self.assertLess(
numpy.linalg.norm(velocity_physics[:, 0] - velocity_physics_jax),
2e-8,
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
index 5e06420..f0e199b 100644
--- a/frc971/control_loops/swerve/velocity_controller/BUILD
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -22,3 +22,49 @@
"@pip//jax",
],
)
+
+py_binary(
+ name = "main",
+ srcs = [
+ "main.py",
+ "model.py",
+ "physics.py",
+ "train.py",
+ ],
+ deps = [
+ ":experience_buffer",
+ "//frc971/control_loops/swerve:jax_dynamics",
+ "@pip//absl_py",
+ "@pip//aim",
+ "@pip//clu",
+ "@pip//flashbax",
+ "@pip//flax",
+ "@pip//jax",
+ "@pip//jaxtyping",
+ "@pip//matplotlib",
+ "@pip//numpy",
+ "@pip//tensorflow",
+ ],
+)
+
+py_binary(
+ name = "lqr_plot",
+ srcs = [
+ "lqr_plot.py",
+ "model.py",
+ "physics.py",
+ ],
+ deps = [
+ ":experience_buffer",
+ "//frc971/control_loops/swerve:jax_dynamics",
+ "@pip//absl_py",
+ "@pip//flashbax",
+ "@pip//flax",
+ "@pip//jax",
+ "@pip//jaxtyping",
+ "@pip//matplotlib",
+ "@pip//numpy",
+ "@pip//pygobject",
+ "@pip//tensorflow",
+ ],
+)
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
new file mode 100644
index 0000000..30f5630
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -0,0 +1,489 @@
+#!/usr/bin/env python3
+
+import os
+
+# Set up JAX to run on the CPU
+os.environ['XLA_FLAGS'] = ' '.join([
+ # Teach it where to find CUDA
+ '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+ # Use up to 20 cores
+ '--xla_force_host_platform_device_count=20',
+ # Dump XLA to /tmp/foo to aid debugging
+ '--xla_dump_to=/tmp/foo',
+ '--xla_gpu_enable_command_buffer='
+])
+
+os.environ['JAX_PLATFORMS'] = 'cpu'
+# Don't pre-allocate memory
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+import absl
+from absl import logging
+from matplotlib.animation import FuncAnimation
+import matplotlib
+import numpy
+import scipy
+import time
+
+matplotlib.use("gtk3agg")
+
+from matplotlib import pylab
+from matplotlib import pyplot
+from flax.training import checkpoints
+from frc971.control_loops.python import controls
+import tensorflow as tf
+import threading
+import collections
+
+import jax
+
+jax._src.deprecations.accelerate('tracer-hash')
+jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
+
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+# Container for the data.
+Data = collections.namedtuple('Data', [
+ 't', 'X', 'X_lqr', 'U', 'U_lqr', 'cost', 'cost_lqr', 'q1_grid', 'q2_grid',
+ 'q_grid', 'target_q_grid', 'lqr_grid', 'pi_grid_U', 'lqr_grid_U', 'grid_X',
+ 'grid_Y', 'reward', 'reward_lqr', 'step'
+])
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+absl.flags.DEFINE_integer(
+ 'horizon',
+ default=100,
+ help='Horizon to simulate',
+)
+
+absl.flags.DEFINE_float(
+ 'alpha',
+ default=0.2,
+ help='Entropy. If negative, automatically solve for it.',
+)
+
+numpy.set_printoptions(linewidth=200, )
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+ return checkpoints.restore_checkpoint(workdir, state)
+
+
+# Hard-coded simulation parameters for the turret.
+dt = 0.005
+A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
+B = numpy.matrix([[0.00065992], [0.25610763]])
+
+Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
+R = numpy.matrix([[0.00694444]])
+
+# Compute the optimal LQR cost + controller.
+F, P = controls.dlqr(A, B, Q, R, optimal_cost_function=True)
+
+
+def generate_data(step=None):
+ grid_X = numpy.arange(-1, 1, 0.1)
+ grid_Y = numpy.arange(-10, 10, 0.1)
+ grid_X, grid_Y = numpy.meshgrid(grid_X, grid_Y)
+ grid_X = jax.numpy.array(grid_X)
+ grid_Y = jax.numpy.array(grid_Y)
+ # Load the training state.
+ physics_constants = jax_dynamics.Coefficients()
+
+ rng = jax.random.key(0)
+ rng, init_rng = jax.random.split(rng)
+
+ state = create_train_state(init_rng, physics_constants,
+ FLAGS.q_learning_rate, FLAGS.pi_learning_rate)
+
+ state = restore_checkpoint(state, FLAGS.workdir)
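+    # If the checkpoint hasn't advanced since the last plot, return None so
+    # the caller keeps its existing data.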
+ if step is not None and state.step == step:
+ return None
+
+ print('F:', F)
+ print('P:', P)
+
+ X = jax.numpy.array([1.0, 0.0])
+ X_lqr = X.copy()
+ goal = jax.numpy.array([0.0, 0.0])
+
+ logging.info('X: %s', X)
+ logging.info('goal: %s', goal)
+ logging.debug('params: %s', state.params)
+
+ # Compute the various cost surfaces for plotting.
+ def compute_q1(X, Y):
+ return state.q1_apply(
+ state.params,
+ unwrap_angles(jax.numpy.array([X, Y])),
+ jax.numpy.array([0.]),
+ )[0]
+
+ def compute_q2(X, Y):
+ return state.q2_apply(
+ state.params,
+ unwrap_angles(jax.numpy.array([X, Y])),
+ jax.numpy.array([0.]),
+ )[0]
+
+ def lqr_cost(X, Y):
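+        # Negative of the LQR cost-to-go (-x^T P x), so it is directly
+        # comparable to the learned Q, which estimates reward-to-go.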
+ x = jax.numpy.array([X, Y])
+ return -x.T @ jax.numpy.array(P) @ x
+
+ def compute_q(params, x, y):
+ X = unwrap_angles(jax.numpy.array([x, y]))
+
+ return jax.numpy.minimum(
+ state.q1_apply(
+ params,
+ X,
+ jax.numpy.array([0.]),
+ )[0],
+ state.q2_apply(
+ params,
+ X,
+ jax.numpy.array([0.]),
+ )[0])
+
+ cost_grid1 = jax.vmap(jax.vmap(compute_q1))(grid_X, grid_Y)
+ cost_grid2 = jax.vmap(jax.vmap(compute_q2))(grid_X, grid_Y)
+ cost_grid = jax.vmap(jax.vmap(lambda x, y: compute_q(state.params, x, y)))(
+ grid_X, grid_Y)
+ target_cost_grid = jax.vmap(
+ jax.vmap(lambda x, y: compute_q(state.target_params, x, y)))(grid_X,
+ grid_Y)
+ lqr_cost_grid = jax.vmap(jax.vmap(lqr_cost))(grid_X, grid_Y)
+
+ # TODO(austin): Stuff to figure out:
+ # 3: Make it converge faster. Use both GPUs better?
+ # 4: Can we feed in a reference position and get it to learn to stabilize there?
+
+ # Now compute the two controller surfaces.
+ def compute_lqr_U(X, Y):
+ x = jax.numpy.array([X, Y])
+ return (-jax.numpy.array(F.reshape((2, ))) @ x)[0]
+
+ def compute_pi_U(X, Y):
+ x = jax.numpy.array([X, Y])
+ U, _, _, _ = state.pi_apply(rng,
+ state.params,
+ observation=unwrap_angles(x),
+ deterministic=True)
+ return U[0]
+
+ lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
+ pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
+
+ logging.info('Finished cost')
+
+ # Now simulate the robot, accumulating up things to plot.
+ def loop(i, val):
+ X, X_lqr, data, params = val
+ t = data.t.at[i].set(i * dt)
+
+ U, _, _, _ = state.pi_apply(rng,
+ params,
+ observation=unwrap_angles(X),
+ deterministic=True)
+ U_lqr = F @ (goal - X_lqr)
+
+ cost = jax.numpy.minimum(state.q1_apply(params, unwrap_angles(X), U),
+ state.q2_apply(params, unwrap_angles(X), U))
+
+ U_plot = data.U.at[i, :].set(U)
+ U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
+ X_plot = data.X.at[i, :].set(X)
+ X_lqr_plot = data.X_lqr.at[i, :].set(X_lqr)
+ cost_plot = data.cost.at[i, :].set(cost)
+ cost_lqr_plot = data.cost_lqr.at[i, :].set(lqr_cost(*X_lqr))
+
+ X = A @ X + B @ U
+ X_lqr = A @ X_lqr + B @ U_lqr
+
+ reward = data.reward - physics.state_cost(X, U, goal)
+ reward_lqr = data.reward_lqr - physics.state_cost(X_lqr, U_lqr, goal)
+
+ return X, X_lqr, data._replace(
+ t=t,
+ U=U_plot,
+ U_lqr=U_lqr_plot,
+ X=X_plot,
+ X_lqr=X_lqr_plot,
+ cost=cost_plot,
+ cost_lqr=cost_lqr_plot,
+ reward=reward,
+ reward_lqr=reward_lqr,
+ ), params
+
+ # Do it.
+ @jax.jit
+ def integrate(data, X, X_lqr, params):
+ return jax.lax.fori_loop(0, FLAGS.horizon, loop,
+ (X, X_lqr, data, params))
+
+ X, X_lqr, data, params = integrate(
+ Data(
+ t=jax.numpy.zeros((FLAGS.horizon, )),
+ X=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
+ X_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
+ U=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+ U_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+ cost=jax.numpy.zeros((FLAGS.horizon, 1)),
+ cost_lqr=jax.numpy.zeros((FLAGS.horizon, 1)),
+ q1_grid=cost_grid1,
+ q2_grid=cost_grid2,
+ q_grid=cost_grid,
+ target_q_grid=target_cost_grid,
+ lqr_grid=lqr_cost_grid,
+ pi_grid_U=pi_cost_U,
+ lqr_grid_U=lqr_cost_U,
+ grid_X=grid_X,
+ grid_Y=grid_Y,
+ reward=0.0,
+ reward_lqr=0.0,
+ step=state.step,
+ ), X, X_lqr, state.params)
+
+ logging.info('Finished integrating, reward of %f, lqr reward of %f',
+ data.reward, data.reward_lqr)
+
+ # Convert back to numpy for plotting.
+ return Data(
+ t=numpy.array(data.t),
+ X=numpy.array(data.X),
+ X_lqr=numpy.array(data.X_lqr),
+ U=numpy.array(data.U),
+ U_lqr=numpy.array(data.U_lqr),
+ cost=numpy.array(data.cost),
+ cost_lqr=numpy.array(data.cost_lqr),
+ q1_grid=numpy.array(data.q1_grid),
+ q2_grid=numpy.array(data.q2_grid),
+ q_grid=numpy.array(data.q_grid),
+ target_q_grid=numpy.array(data.target_q_grid),
+ lqr_grid=numpy.array(data.lqr_grid),
+ pi_grid_U=numpy.array(data.pi_grid_U),
+ lqr_grid_U=numpy.array(data.lqr_grid_U),
+ grid_X=numpy.array(data.grid_X),
+ grid_Y=numpy.array(data.grid_Y),
+ reward=float(data.reward),
+ reward_lqr=float(data.reward_lqr),
+ step=data.step,
+ )
+
+
+class Plotter(object):
+
+ def __init__(self, data):
+ # Make all the plots and axis.
+ self.fig0, self.axs0 = pylab.subplots(3)
+ self.fig0.supxlabel('Seconds')
+
+ self.x, = self.axs0[0].plot([], [], label="x")
+ self.v, = self.axs0[0].plot([], [], label="v")
+ self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
+ self.v_lqr, = self.axs0[0].plot([], [], label="v_lqr")
+
+ self.axs0[0].set_ylabel('Velocity')
+ self.axs0[0].legend()
+
+ self.uaxis, = self.axs0[1].plot([], [], label="U")
+ self.uaxis_lqr, = self.axs0[1].plot([], [], label="U_lqr")
+
+ self.axs0[1].set_ylabel('Amps')
+ self.axs0[1].legend()
+
+ self.costaxis, = self.axs0[2].plot([], [], label="cost")
+ self.costlqraxis, = self.axs0[2].plot([], [], label="cost lqr")
+ self.axs0[2].set_ylabel('Cost')
+ self.axs0[2].legend()
+
+ self.costfig = pyplot.figure(figsize=pyplot.figaspect(0.5))
+ self.cost3dax = [
+ self.costfig.add_subplot(2, 3, 1, projection='3d'),
+ self.costfig.add_subplot(2, 3, 2, projection='3d'),
+ self.costfig.add_subplot(2, 3, 3, projection='3d'),
+ self.costfig.add_subplot(2, 3, 4, projection='3d'),
+ self.costfig.add_subplot(2, 3, 5, projection='3d'),
+ self.costfig.add_subplot(2, 3, 6, projection='3d'),
+ ]
+
+ self.Ufig = pyplot.figure(figsize=pyplot.figaspect(0.5))
+ self.Uax = [
+ self.Ufig.add_subplot(1, 3, 1, projection='3d'),
+ self.Ufig.add_subplot(1, 3, 2, projection='3d'),
+ self.Ufig.add_subplot(1, 3, 3, projection='3d'),
+ ]
+
+ self.last_trajectory_step = 0
+ self.last_cost_step = 0
+ self.last_U_step = 0
+
+ def update_trajectory_plot(self, data):
+ if data.step == self.last_trajectory_step:
+ return
+ self.last_trajectory_step = data.step
+ logging.info('Updating trajectory plots')
+
+ # Put data in the trajectory plots.
+ self.x.set_data(data.t, data.X[:, 0])
+ self.v.set_data(data.t, data.X[:, 1])
+ self.x_lqr.set_data(data.t, data.X_lqr[:, 0])
+ self.v_lqr.set_data(data.t, data.X_lqr[:, 1])
+
+ self.uaxis.set_data(data.t, data.U[:, 0])
+ self.uaxis_lqr.set_data(data.t, data.U_lqr[:, 0])
+ self.costaxis.set_data(data.t, data.cost[:, 0] - data.cost[-1, 0])
+ self.costlqraxis.set_data(data.t, data.cost_lqr[:, 0])
+
+ self.axs0[0].relim()
+ self.axs0[0].autoscale_view()
+
+ self.axs0[1].relim()
+ self.axs0[1].autoscale_view()
+
+ self.axs0[2].relim()
+ self.axs0[2].autoscale_view()
+
+ return self.x, self.v, self.uaxis, self.costaxis, self.costlqraxis
+
+ def update_cost_plot(self, data):
+ if data.step == self.last_cost_step:
+ return
+ logging.info('Updating cost plots')
+ self.last_cost_step = data.step
+ # Put data in the cost plots.
+ if hasattr(self, 'costsurf'):
+ for surf in self.costsurf:
+ surf.remove()
+
+ plots = [
+ (data.q1_grid, 'q1'),
+ (data.q2_grid, 'q2'),
+ (data.q_grid, 'q'),
+ (data.target_q_grid, 'target q'),
+ (data.lqr_grid, 'lqr'),
+ (data.q_grid - data.q_grid.max() - data.lqr_grid, 'error'),
+ ]
+
+ self.costsurf = [
+ self.cost3dax[i].plot_surface(
+ data.grid_X,
+ data.grid_Y,
+ Z,
+ cmap="magma",
+ label=label,
+ ) for i, (Z, label) in enumerate(plots)
+ ]
+
+ for axis in self.cost3dax:
+ axis.legend()
+ axis.relim()
+ axis.autoscale_view()
+
+ return self.costsurf
+
+ def update_U_plot(self, data):
+ if data.step == self.last_U_step:
+ return
+ self.last_U_step = data.step
+ logging.info('Updating U plots')
+ # Put data in the controller plots.
+ if hasattr(self, 'Usurf'):
+ for surf in self.Usurf:
+ surf.remove()
+
+ plots = [
+ (data.lqr_grid_U, 'lqr'),
+ (data.pi_grid_U, 'pi'),
+ ((data.lqr_grid_U - data.pi_grid_U), 'error'),
+ ]
+
+ self.Usurf = [
+ self.Uax[i].plot_surface(
+ data.grid_X,
+ data.grid_Y,
+ Z,
+ cmap="magma",
+ label=label,
+ ) for i, (Z, label) in enumerate(plots)
+ ]
+
+ for axis in self.Uax:
+ axis.legend()
+ axis.relim()
+ axis.autoscale_view()
+
+ return self.Usurf
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise absl.app.UsageError('Too many command-line arguments.')
+
+ tf.config.experimental.set_visible_devices([], 'GPU')
+
+ lock = threading.Lock()
+
+ # Load data.
+ data = generate_data()
+
+ plotter = Plotter(data)
+
+ # Event for shutting down the thread.
+ shutdown = threading.Event()
+
+ # Thread to grab new data periodically.
+ def do_update():
+ while True:
+ nonlocal data
+
+ my_data = generate_data(data.step)
+
+ if my_data is not None:
+ with lock:
+ data = my_data
+
+ if shutdown.wait(timeout=3):
+ return
+
+ update_thread = threading.Thread(target=do_update)
+ update_thread.start()
+
+ # Now, update each of the plots every second with the new data.
+ def update0(frame):
+ with lock:
+ my_data = data
+
+ return plotter.update_trajectory_plot(my_data)
+
+ def update1(frame):
+ with lock:
+ my_data = data
+
+ return plotter.update_cost_plot(my_data)
+
+ def update2(frame):
+ with lock:
+ my_data = data
+
+ return plotter.update_U_plot(my_data)
+
+ animation0 = FuncAnimation(plotter.fig0, update0, interval=1000)
+ animation1 = FuncAnimation(plotter.costfig, update1, interval=1000)
+ animation2 = FuncAnimation(plotter.Ufig, update2, interval=1000)
+
+ pyplot.show()
+
+ shutdown.set()
+ update_thread.join()
+
+
+if __name__ == '__main__':
+ absl.flags.mark_flags_as_required(['workdir'])
+ absl.app.run(main)
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
new file mode 100644
index 0000000..7ef687c
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -0,0 +1,72 @@
+import os
+
+# Setup XLA first.
+os.environ['XLA_FLAGS'] = ' '.join([
+ # Teach it where to find CUDA
+ '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+ # Use up to 20 cores
+ #'--xla_force_host_platform_device_count=6',
+ # Dump XLA to /tmp/foo to aid debugging
+ #'--xla_dump_to=/tmp/foo',
+ #'--xla_gpu_enable_command_buffer='
+ # Dump sharding
+ #"--xla_dump_to=/tmp/foo",
+ #"--xla_dump_hlo_pass_re=spmd|propagation"
+])
+os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
+
+from absl import app
+from absl import flags
+from absl import logging
+from clu import platform
+import jax
+import tensorflow as tf
+from frc971.control_loops.swerve import jax_dynamics
+
+jax._src.deprecations.accelerate('tracer-hash')
+# Enable the compilation cache
+jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
+jax.config.update('jax_threefry_partitionable', True)
+
+import train
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ # Hide any GPUs from TensorFlow. Otherwise it might reserve memory.
+ tf.config.experimental.set_visible_devices([], 'GPU')
+
+ logging.info('JAX process: %d / %d', jax.process_index(),
+ jax.process_count())
+ logging.info('JAX local devices: %r', jax.local_devices())
+
+ # Add a note so that we can tell which task is which JAX host.
+ # (Depending on the platform task 0 is not guaranteed to be host 0)
+ platform.work_unit().set_task_status(
+ f'process_index: {jax.process_index()}, '
+ f'process_count: {jax.process_count()}')
+ platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
+ FLAGS.workdir, 'workdir')
+
+ logging.info(
+ 'Visualize results with: bazel run -c opt @pip_deps_tensorboard//:rules_python_wheel_entry_point_tensorboard -- --logdir %s',
+ FLAGS.workdir,
+ )
+
+ physics_constants = jax_dynamics.Coefficients()
+ state = train.train(FLAGS.workdir, physics_constants)
+
+
+if __name__ == '__main__':
+ flags.mark_flags_as_required(['workdir'])
+ app.run(main)
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
new file mode 100644
index 0000000..ae67efe
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -0,0 +1,447 @@
+from __future__ import annotations
+import flax
+import flashbax
+from typing import Any, Callable
+import dataclasses
+import absl
+from absl import logging
+import numpy
+import jax
+from flax import linen as nn
+from jaxtyping import Array, ArrayLike
+import optax
+from flax.training import train_state
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve.velocity_controller import physics
+from frc971.control_loops.swerve.velocity_controller import experience_buffer
+
+from flax.typing import PRNGKey
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_integer(
+ 'num_agents',
+ default=10,
+ help='Training batch size.',
+)
+
+absl.flags.DEFINE_float(
+ 'q_learning_rate',
+ default=0.002,
+ help='Training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+ 'final_q_learning_rate',
+ default=0.00002,
+ help='Training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+ 'pi_learning_rate',
+ default=0.002,
+ help='Training learning rate.',
+)
+
+absl.flags.DEFINE_float(
+ 'final_pi_learning_rate',
+ default=0.00002,
+ help='Training learning rate.',
+)
+
+absl.flags.DEFINE_integer(
+ 'replay_size',
+ default=2000000,
+ help='Number of steps to save in our replay buffer',
+)
+
+absl.flags.DEFINE_integer(
+ 'batch_size',
+ default=10000,
+ help='Batch size for learning Q and pi',
+)
+
+HIDDEN_WEIGHTS = 256
+
+LOG_STD_MIN = -20
+LOG_STD_MAX = 2
+
+
+def gaussian_likelihood(noise: ArrayLike, log_std: ArrayLike):
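+    """Log probability of a Gaussian sample given its standardized noise and log std."""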
+ pre_sum = -0.5 * (noise**2 + 2 * log_std + jax.numpy.log(2 * jax.numpy.pi))
+
+ if len(pre_sum.shape) > 1:
+ return jax.numpy.sum(pre_sum, keepdims=True, axis=1)
+ else:
+ return jax.numpy.sum(pre_sum, keepdims=True)
+
+
+class SquashedGaussianMLPActor(nn.Module):
+ """Actor model."""
+
+ # Number of dimensions in the action space
+ action_space: int = 8
+
+ hidden_sizes: list[int] = dataclasses.field(
+ default_factory=lambda: [HIDDEN_WEIGHTS, HIDDEN_WEIGHTS])
+
+ # Activation function
+ activation: Callable = nn.activation.relu
+
+ # Max action we can apply
+ action_limit: float = 30.0
+
+ @nn.compact
+ def __call__(self,
+ observations: ArrayLike,
+ deterministic: bool = False,
+ rng: PRNGKey | None = None):
+ x = observations
+ # Apply the dense layers
+ for i, hidden_size in enumerate(self.hidden_sizes):
+ x = nn.Dense(
+ name=f'denselayer{i}',
+ features=hidden_size,
+ )(x)
+ x = self.activation(x)
+
+ # Average policy is a dense function of the space.
+ mu = nn.Dense(
+ features=self.action_space,
+ name='mu',
+ kernel_init=nn.initializers.zeros,
+ )(x)
+
+ log_std_layer = nn.Dense(features=self.action_space,
+ name='log_std_layer')(x)
+
+ # Clip the log of the standard deviation in a soft manner.
+ log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (
+ flax.linen.activation.tanh(log_std_layer) + 1)
+
+ std = jax.numpy.exp(log_std)
+
+ if rng is None:
+ rng = self.make_rng('pi')
+
+ # Grab a random sample
+ random_sample = jax.random.normal(rng, shape=std.shape)
+
+ if deterministic:
+ # We are testing the optimal policy, just use the mean.
+ pi_action = mu
+ else:
+            # Use the reparameterization trick. Adjust the unit Gaussian with
+            # something we can solve for to get the desired noise.
+ pi_action = random_sample * std + mu
+
+ logp_pi = gaussian_likelihood(random_sample, log_std)
+ # Adjustment to log prob
+ # NOTE: This formula is a little bit magic. To get an understanding of where it
+ # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
+ # appendix C. This is a more numerically-stable equivalent to Eq 21.
+ delta = (2.0 * (jax.numpy.log(2.0) - pi_action -
+ flax.linen.softplus(-2.0 * pi_action)))
+
+ if len(delta.shape) > 1:
+ delta = jax.numpy.sum(delta, keepdims=True, axis=1)
+ else:
+ delta = jax.numpy.sum(delta, keepdims=True)
+
+ logp_pi = logp_pi - delta
+
+ # Now, saturate the action to the limit using tanh
+ pi_action = self.action_limit * flax.linen.activation.tanh(pi_action)
+
+ return pi_action, logp_pi, self.action_limit * std, random_sample
+
+
+class MLPQFunction(nn.Module):
+
+ # Number and size of the hidden layers.
+ hidden_sizes: list[int] = dataclasses.field(
+ default_factory=lambda: [HIDDEN_WEIGHTS, HIDDEN_WEIGHTS])
+
+ activation: Callable = nn.activation.tanh
+
+ @nn.compact
+ def __call__(self, observations, actions):
+ # Estimate Q with a simple multi layer dense network.
+ x = jax.numpy.hstack((observations, actions))
+ for i, hidden_size in enumerate(self.hidden_sizes):
+ x = nn.Dense(
+ name=f'denselayer{i}',
+ features=hidden_size,
+ )(x)
+ x = self.activation(x)
+
+ x = nn.Dense(name=f'q', features=1,
+ kernel_init=nn.initializers.zeros)(x)
+
+ return x
+
+
+class TrainState(flax.struct.PyTreeNode):
+ physics_constants: jax_dynamics.CoefficientsType = flax.struct.field(
+ pytree_node=False)
+
+ step: int | jax.Array = flax.struct.field(pytree_node=True)
+ substep: int | jax.Array = flax.struct.field(pytree_node=True)
+
+    params: flax.core.FrozenDict[str, Any] = flax.struct.field(
+ pytree_node=True)
+
+    target_params: flax.core.FrozenDict[str, Any] = flax.struct.field(
+ pytree_node=True)
+
+ pi_apply_fn: Callable = flax.struct.field(pytree_node=False)
+ q1_apply_fn: Callable = flax.struct.field(pytree_node=False)
+ q2_apply_fn: Callable = flax.struct.field(pytree_node=False)
+
+ pi_tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
+ pi_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+ q_tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
+ q_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+
+ alpha_tx: optax.GradientTransformation = flax.struct.field(
+ pytree_node=False)
+ alpha_opt_state: optax.OptState = flax.struct.field(pytree_node=True)
+
+ target_entropy: float = flax.struct.field(pytree_node=True)
+
+ mesh: Mesh = flax.struct.field(pytree_node=False)
+ sharding: NamedSharding = flax.struct.field(pytree_node=False)
+ replicated_sharding: NamedSharding = flax.struct.field(pytree_node=False)
+
+ replay_buffer: flashbax.buffers.trajectory_buffer.TrajectoryBuffer = flax.struct.field(
+ pytree_node=False)
+
+ def pi_apply(self,
+ rng: PRNGKey,
+ params,
+ observation,
+ deterministic: bool = False):
+ return self.pi_apply_fn({'params': params['pi']},
+ physics.unwrap_angles(observation),
+ deterministic=deterministic,
+ rngs={'pi': rng})
+
+ def q1_apply(self, params, observation, action):
+ return self.q1_apply_fn({'params': params['q1']},
+ physics.unwrap_angles(observation), action)
+
+ def q2_apply(self, params, observation, action):
+ return self.q2_apply_fn({'params': params['q2']},
+ physics.unwrap_angles(observation), action)
+
+ def pi_apply_gradients(self, step, grads):
+ updates, new_pi_opt_state = self.pi_tx.update(grads, self.pi_opt_state,
+ self.params)
+ new_params = optax.apply_updates(self.params, updates)
+
+ return self.replace(
+ step=step,
+ substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+ params=new_params,
+ pi_opt_state=new_pi_opt_state,
+ )
+
+ def q_apply_gradients(self, step, grads):
+ updates, new_q_opt_state = self.q_tx.update(grads, self.q_opt_state,
+ self.params)
+ new_params = optax.apply_updates(self.params, updates)
+
+ return self.replace(
+ step=step,
+ substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+ params=new_params,
+ q_opt_state=new_q_opt_state,
+ )
+
+ def target_apply_gradients(self, step):
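+        # Polyak-average the live parameters into the target network.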
+ new_target_params = optax.incremental_update(self.params,
+ self.target_params,
+ 1 - FLAGS.polyak)
+
+ return self.replace(
+ step=step,
+ substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+ target_params=new_target_params,
+ )
+
+ def alpha_apply_gradients(self, step, grads):
+ updates, new_alpha_opt_state = self.alpha_tx.update(
+ grads, self.alpha_opt_state, self.params)
+ new_params = optax.apply_updates(self.params, updates)
+
+ return self.replace(
+ step=step,
+ substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+ params=new_params,
+ alpha_opt_state=new_alpha_opt_state,
+ )
+
+ def update_step(self, step):
+ return self.replace(
+ step=step,
+ substep=jax.lax.select(step != self.step, 0, self.substep + 1),
+ )
+
+ @classmethod
+ def create(cls, *, physics_constants: jax_dynamics.CoefficientsType,
+ params, pi_tx, q_tx, alpha_tx, pi_apply_fn, q1_apply_fn,
+ q2_apply_fn, **kwargs):
+ """Creates a new instance with ``step=0`` and initialized ``opt_state``."""
+
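+        # Each optimizer trains only its own parameter sub-tree; everything
+        # else is frozen with set_to_zero() so the pi, Q, and alpha updates
+        # can share a single params pytree.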
+ pi_tx = optax.multi_transform(
+ {
+ 'train': pi_tx,
+ 'freeze': optax.set_to_zero()
+ },
+ param_labels=flax.traverse_util.path_aware_map(
+ lambda path, x: 'train'
+ if path[0] == 'pi' else 'freeze', params),
+ )
+ pi_opt_state = pi_tx.init(params)
+
+ q_tx = optax.multi_transform(
+ {
+ 'train': q_tx,
+ 'freeze': optax.set_to_zero()
+ },
+ param_labels=flax.traverse_util.path_aware_map(
+ lambda path, x: 'train'
+ if path[0] == 'q1' or path[0] == 'q2' else 'freeze', params),
+ )
+ q_opt_state = q_tx.init(params)
+
+ alpha_tx = optax.multi_transform(
+ {
+ 'train': alpha_tx,
+ 'freeze': optax.set_to_zero()
+ },
+ param_labels=flax.traverse_util.path_aware_map(
+ lambda path, x: 'train'
+ if path[0] == 'logalpha' else 'freeze', params),
+ )
+ alpha_opt_state = alpha_tx.init(params)
+
+ mesh = Mesh(
+ devices=mesh_utils.create_device_mesh(len(jax.devices())),
+ axis_names=('batch', ),
+ )
+ print('Devices:', jax.devices())
+ sharding = NamedSharding(mesh, PartitionSpec('batch'))
+ replicated_sharding = NamedSharding(mesh, PartitionSpec())
+
+ replay_buffer = experience_buffer.make_experience_buffer(
+ num_agents=FLAGS.num_agents,
+ sample_batch_size=FLAGS.batch_size,
+ length=FLAGS.replay_size)
+
+ result = cls(
+ physics_constants=physics_constants,
+ step=0,
+ substep=0,
+ params=params,
+ target_params=params,
+ q1_apply_fn=q1_apply_fn,
+ q2_apply_fn=q2_apply_fn,
+ pi_apply_fn=pi_apply_fn,
+ pi_tx=pi_tx,
+ pi_opt_state=pi_opt_state,
+ q_tx=q_tx,
+ q_opt_state=q_opt_state,
+ alpha_tx=alpha_tx,
+ alpha_opt_state=alpha_opt_state,
+ target_entropy=-physics.NUM_STATES,
+ mesh=mesh,
+ sharding=sharding,
+ replicated_sharding=replicated_sharding,
+ replay_buffer=replay_buffer,
+ )
+
+ return jax.device_put(result, replicated_sharding)
+
+
+def create_train_state(rng, physics_constants: jax_dynamics.CoefficientsType,
+ q_learning_rate, pi_learning_rate):
+ """Creates initial `TrainState`."""
+ pi = SquashedGaussianMLPActor(action_space=physics.NUM_OUTPUTS)
+    # We want q1 and q2 to have different network architectures so they pick
+    # up different things.
+ q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
+ q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
+
+ @jax.jit
+ def init_params(rng):
+ pi_rng, q1_rng, q2_rng = jax.random.split(rng, num=3)
+
+ pi_params = pi.init(
+ pi_rng,
+ jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+ )['params']
+ q1_params = q1.init(
+ q1_rng,
+ jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+ jax.numpy.ones([physics.NUM_OUTPUTS]),
+ )['params']
+ q2_params = q2.init(
+ q2_rng,
+ jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+ jax.numpy.ones([physics.NUM_OUTPUTS]),
+ )['params']
+
+ if FLAGS.alpha < 0.0:
+ logalpha = 0.0
+ else:
+ logalpha = jax.numpy.log(FLAGS.alpha)
+
+ return {
+ 'q1': q1_params,
+ 'q2': q2_params,
+ 'pi': pi_params,
+ 'logalpha': logalpha,
+ }
+
+ pi_tx = optax.sgd(learning_rate=pi_learning_rate)
+ q_tx = optax.sgd(learning_rate=q_learning_rate)
+ alpha_tx = optax.sgd(learning_rate=q_learning_rate)
+
+ result = TrainState.create(
+ physics_constants=physics_constants,
+ params=init_params(rng),
+ pi_tx=pi_tx,
+ q_tx=q_tx,
+ alpha_tx=alpha_tx,
+ pi_apply_fn=pi.apply,
+ q1_apply_fn=q1.apply,
+ q2_apply_fn=q2.apply,
+ )
+
+ return result
+
+
+def create_learning_rate_fn(
+ base_learning_rate: float,
+ final_learning_rate: float,
+):
+ """Create learning rate schedule."""
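+    # Hold the rate at base_learning_rate for warmup_steps, then cosine-decay
+    # it over the remaining steps.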
+ warmup_fn = optax.linear_schedule(
+ init_value=base_learning_rate,
+ end_value=base_learning_rate,
+ transition_steps=FLAGS.warmup_steps,
+ )
+
+ cosine_epochs = max(FLAGS.steps - FLAGS.warmup_steps, 1)
+    # optax's `alpha` is the final value expressed as a fraction of
+    # init_value, so convert the absolute final learning rate into that
+    # fraction.
+    cosine_fn = optax.cosine_decay_schedule(
+        init_value=base_learning_rate,
+        decay_steps=cosine_epochs,
+        alpha=final_learning_rate / base_learning_rate)
+
+ schedule_fn = optax.join_schedules(
+ schedules=[warmup_fn, cosine_fn],
+ boundaries=[FLAGS.warmup_steps],
+ )
+ return schedule_fn
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
new file mode 100644
index 0000000..9454f14
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -0,0 +1,69 @@
+import jax
+from functools import partial
+from frc971.control_loops.swerve import dynamics
+from absl import logging
+from absl import flags
+from frc971.control_loops.swerve import jax_dynamics
+
+FLAGS = flags.FLAGS
+
+
+@partial(jax.jit, static_argnums=[0])
+def xdot_physics(physics_constants: jax_dynamics.CoefficientsType, X, U):
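+    # Continuous-time turret velocity dynamics, Xdot = A @ X + B @ U.
+    # physics_constants is unused for this hard-coded single-joint model.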
+ A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+ B = jax.numpy.array([[0.], [56.08534375]])
+
+ return A @ X + B @ U
+
+
+def unwrap_angles(X):
+ return X
+
+
+@jax.jit
+def swerve_cost(X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike):
+
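+    # Quadratic LQR-style state/actuation cost. `goal` is currently unused;
+    # the cost is measured about the origin until nonzero goals are supported.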
+ Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
+ R = jax.numpy.array([[0.00694444]])
+
+ return (X).T @ Q @ (X) + U.T @ R @ U
+
+
+@partial(jax.jit, static_argnums=[0])
+def integrate_dynamics(physics_constants: jax_dynamics.CoefficientsType, X, U):
+ m = 2 # RK4 steps per interval
+ dt = 0.005 / m
+
+ def iteration(i, X):
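+        # Row 0 scales the previous slope when evaluating the next one; row 1
+        # holds the 1-2-2-1 RK4 combination weights (divided by 6 below).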
+ weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0], [1.0, 2.0, 2.0, 1.0]])
+
+ def rk_iteration(i, val):
+ kx_previous, weighted_sum = val
+ kx = xdot_physics(physics_constants,
+ X + dt * weights[0, i] * kx_previous, U)
+ weighted_sum += dt * weights[1, i] * kx / 6.0
+ return (kx, weighted_sum)
+
+ return jax.lax.fori_loop(lower=0,
+ upper=4,
+ body_fun=rk_iteration,
+ init_val=(X, X))[1]
+
+ return jax.lax.fori_loop(lower=0, upper=m, body_fun=iteration, init_val=X)
+
+
+state_cost = swerve_cost
+NUM_STATES = 2
+NUM_UNWRAPPED_STATES = 2
+NUM_OUTPUTS = 1
+ACTION_LIMIT = 10.0
+
+
+def random_states(rng, dimensions=None):
+ rng1, rng2 = jax.random.split(rng)
+
+ return jax.numpy.hstack(
+ (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
+ minval=-1.0,
+ maxval=1.0),
+ jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
+ minval=-10.0,
+ maxval=10.0)))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
new file mode 100644
index 0000000..5002eee
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -0,0 +1,511 @@
+import absl
+import time
+import collections
+from absl import logging
+import flax
+import matplotlib
+from matplotlib import pyplot
+from flax import linen as nn
+from flax.training import train_state
+from flax.training import checkpoints
+import jax
+import inspect
+import aim
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+import optax
+import numpy
+from frc971.control_loops.swerve import jax_dynamics
+from functools import partial
+import flashbax
+from jax.experimental.ode import odeint
+from frc971.control_loops.swerve import dynamics
+import orbax.checkpoint
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+numpy.set_printoptions(linewidth=200, )
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_integer(
+ 'horizon',
+ default=25,
+ help='MPC horizon',
+)
+
+absl.flags.DEFINE_integer(
+ 'start_steps',
+ default=10000,
+ help='Number of steps to randomly sample before using the policy',
+)
+
+absl.flags.DEFINE_integer(
+ 'steps',
+ default=400000,
+ help='Number of steps to run and train the agent',
+)
+
+absl.flags.DEFINE_integer(
+ 'warmup_steps',
+ default=300000,
+ help='Number of steps to warm up training',
+)
+
+absl.flags.DEFINE_float(
+ 'gamma',
+ default=0.999,
+ help='Future discount.',
+)
+
+absl.flags.DEFINE_float(
+ 'alpha',
+ default=0.2,
+ help='Entropy. If negative, automatically solve for it.',
+)
+
+absl.flags.DEFINE_float(
+ 'polyak',
+ default=0.995,
+ help='Time constant in polyak averaging for the target network.',
+)
+
+absl.flags.DEFINE_bool(
+ 'debug_nan',
+ default=False,
+ help='If true, explode on any NaNs found, and print them.',
+)
+
+
+def random_actions(rng, dimensions=None):
+ """Produces a uniformly random action in the action space."""
+ return jax.random.uniform(
+ rng,
+ (dimensions or FLAGS.num_agents, physics.NUM_OUTPUTS),
+ minval=-physics.ACTION_LIMIT,
+ maxval=physics.ACTION_LIMIT,
+ )
+
+
+def save_checkpoint(state: TrainState, workdir: str):
+ """Saves a checkpoint in the workdir."""
+ # TODO(austin): use orbax directly.
+ step = int(state.step)
+ logging.info('Saving checkpoint step %d.', step)
+ checkpoints.save_checkpoint_multiprocess(workdir, state, step, keep=10)
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+ """Restores the latest checkpoint from the workdir."""
+ return checkpoints.restore_checkpoint(workdir, state)
+
+
+def has_nan(x):
+ if isinstance(x, jnp.ndarray):
+ return jnp.any(jnp.isnan(x))
+ else:
+ return jnp.any(
+ jax.numpy.array([has_nan(v)
+ for v in jax.tree_util.tree_leaves(x)]))
+
+
+def print_nan(step, x):
+ if not FLAGS.debug_nan:
+ return
+
+ caller = inspect.getframeinfo(inspect.stack()[1][0])
+ # TODO(austin): It is helpful to sometimes start printing at a certain step.
+ jax.lax.cond(
+ has_nan(x), lambda: jax.debug.print(caller.filename +
+ ':{l} step {s} isnan(X) -> {x}',
+ l=caller.lineno,
+ s=step,
+ x=x), lambda: None)
+
+
+@jax.jit
+def compute_loss_q(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
+ """Computes the Soft Actor-Critic loss for Q1 and Q2."""
+ observations1 = data['observations1']
+ actions = data['actions']
+ rewards = data['rewards']
+ observations2 = data['observations2']
+
+ # Compute the ending actions from the current network.
+ action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
+ params=params,
+ observation=observations2)
+
+ # Compute target network Q values
+ q1_pi_target = state.q1_apply(state.target_params, observations2, action2)
+ q2_pi_target = state.q2_apply(state.target_params, observations2, action2)
+ q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
+
+ alpha = jax.numpy.exp(params['logalpha'])
+
+ # Now we can compute the Bellman backup
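+    #   backup = r + gamma * (min(Q1', Q2') - alpha * log pi(a2|s2)),
+    # with the target networks providing Q' and alpha weighting the entropy
+    # bonus.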
+ bellman_backup = jax.lax.stop_gradient(rewards + FLAGS.gamma *
+ (q_pi_target - alpha * logp_pi2))
+
+ # Compute the starting Q values from the Q network being optimized.
+ q1 = state.q1_apply(params, observations1, actions)
+ q2 = state.q2_apply(params, observations1, actions)
+
+ # Mean squared error loss against Bellman backup
+ q1_loss = ((q1 - bellman_backup)**2).mean()
+ q2_loss = ((q2 - bellman_backup)**2).mean()
+ return q1_loss + q2_loss
+
+
+@jax.jit
+def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+
+ def bound_compute_loss_q(rng, data):
+ return compute_loss_q(state, rng, params, data)
+
+ return jax.vmap(bound_compute_loss_q)(
+ jax.random.split(rng, FLAGS.num_agents),
+ data,
+ ).mean()
+
+
+@jax.jit
+def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
+ """Computes the Soft Actor-Critic loss for pi."""
+ observations1 = data['observations1']
+ # TODO(austin): We've got differentiable policy and differentiable physics. Can we use those here? Have Q learn the future, not the current step?
+
+ # Compute the action
+ pi, logp_pi, _, _ = state.pi_apply(rng=rng,
+ params=params,
+ observation=observations1)
+ q1_pi = state.q1_apply(jax.lax.stop_gradient(params), observations1, pi)
+ q2_pi = state.q2_apply(jax.lax.stop_gradient(params), observations1, pi)
+
+ # And compute the Q of that action.
+ q_pi = jax.numpy.minimum(q1_pi, q2_pi)
+
+ alpha = jax.lax.stop_gradient(jax.numpy.exp(params['logalpha']))
+
+ # Compute the entropy-regularized policy loss
+ return (alpha * logp_pi - q_pi).mean()
+
+
+@jax.jit
+def compute_batched_loss_pi(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+
+ def bound_compute_loss_pi(rng, data):
+ return compute_loss_pi(state, rng, params, data)
+
+ return jax.vmap(bound_compute_loss_pi)(
+ jax.random.split(rng, FLAGS.num_agents),
+ data,
+ ).mean()
+
+
+@jax.jit
+def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+ """Computes the Soft Actor-Critic loss for alpha."""
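+    # Automatic temperature adjustment: minimizing this loss scales logalpha
+    # so that the policy entropy is driven toward target_entropy.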
+ observations1 = data['observations1']
+ pi, logp_pi, _, _ = jax.lax.stop_gradient(
+ state.pi_apply(rng=rng, params=params, observation=observations1))
+
+ return (-jax.numpy.exp(params['logalpha']) *
+ (logp_pi + state.target_entropy)).mean()
+
+
+@jax.jit
+def compute_batched_loss_alpha(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+
+ def bound_compute_loss_alpha(rng, data):
+ return compute_loss_alpha(state, rng, params, data)
+
+ return jax.vmap(bound_compute_loss_alpha)(
+ jax.random.split(rng, FLAGS.num_agents),
+ data,
+ ).mean()
+
+
+@jax.jit
+def train_step(state: TrainState, data, action_data, update_rng: PRNGKey,
+ step: int) -> TrainState:
+ """Updates the parameters for Q, Pi, target Q, and alpha."""
+ update_rng, q_grad_rng = jax.random.split(update_rng)
+ print_nan(step, data)
+
+ # Update Q
+ q_grad_fn = jax.value_and_grad(
+ lambda params: compute_batched_loss_q(state, q_grad_rng, params, data))
+ q_loss, q_grads = q_grad_fn(state.params)
+ print_nan(step, q_loss)
+ print_nan(step, q_grads)
+
+ state = state.q_apply_gradients(step=step, grads=q_grads)
+
+ update_rng, pi_grad_rng = jax.random.split(update_rng)
+
+ # Update pi
+ pi_grad_fn = jax.value_and_grad(lambda params: compute_batched_loss_pi(
+ state, pi_grad_rng, params, action_data))
+ pi_loss, pi_grads = pi_grad_fn(state.params)
+
+ print_nan(step, pi_loss)
+ print_nan(step, pi_grads)
+
+ state = state.pi_apply_gradients(step=step, grads=pi_grads)
+
+ update_rng, alpha_grad_rng = jax.random.split(update_rng)
+
+ if FLAGS.alpha < 0.0:
+ # Update alpha
+ alpha_grad_fn = jax.value_and_grad(
+ lambda params: compute_batched_loss_alpha(state, alpha_grad_rng,
+ params, data))
+ alpha_loss, alpha_grads = alpha_grad_fn(state.params)
+ print_nan(step, alpha_loss)
+ print_nan(step, alpha_grads)
+ state = state.alpha_apply_gradients(step=step, grads=alpha_grads)
+ else:
+ alpha_loss = 0.0
+
+ return state, q_loss, pi_loss, alpha_loss
+
+
+@jax.jit
+def collect_experience(state: TrainState, replay_buffer_state, R: ArrayLike,
+ step_rng: PRNGKey, step):
+ """Collects experience by simulating."""
+ pi_rng = jax.random.fold_in(step_rng, step)
+ pi_rng, initialization_rng = jax.random.split(pi_rng)
+
+ observation = jax.lax.with_sharding_constraint(
+ physics.random_states(initialization_rng, FLAGS.num_agents),
+ state.sharding)
+
+ def loop(i, val):
+ """Runs 1 step of our simulation."""
+ observation, pi_rng, replay_buffer_state = val
+ pi_rng, action_rng = jax.random.split(pi_rng)
+ logging.info('Observation shape: %s', observation.shape)
+
+ def true_fn(i):
+ # We are at the beginning of the process, pick a random action.
+ return random_actions(action_rng, FLAGS.num_agents)
+
+ def false_fn(i):
+ # We are past the beginning of the process, use the trained network.
+ pi_action, logp_pi, std, random_sample = state.pi_apply(
+ rng=action_rng,
+ params=state.params,
+ observation=observation,
+ deterministic=False)
+ return pi_action
+
+ pi_action = jax.lax.cond(
+ step <= FLAGS.start_steps,
+ true_fn,
+ false_fn,
+ i,
+ )
+
+ # Compute the destination state.
+ observation2 = jax.vmap(lambda o, pi: physics.integrate_dynamics(
+ state.physics_constants, o, pi),
+ in_axes=(0, 0))(observation, pi_action)
+
+ # Soft Actor-Critic is designed to maximize reward. LQR minimizes
+ # cost. There is nothing which assumes anything about the sign of
+ # the reward, so use the negative of the cost.
+ reward = (-jax.vmap(state_cost)(observation2, pi_action, R)).reshape(
+ (FLAGS.num_agents, 1))
+
+ logging.info('Observation shape: %s', observation.shape)
+ logging.info('Observation2 shape: %s', observation2.shape)
+        logging.info('Action shape: %s', pi_action.shape)
+ replay_buffer_state = state.replay_buffer.add(
+ replay_buffer_state, {
+ 'observations1': observation,
+ 'observations2': observation2,
+ 'actions': pi_action,
+ 'rewards': reward,
+ })
+
+ return observation2, pi_rng, replay_buffer_state
+
+ # Run 1 horizon of simulation
+ final_observation, final_pi_rng, final_replay_buffer_state = jax.lax.fori_loop(
+ 0, FLAGS.horizon + 1, loop, (observation, pi_rng, replay_buffer_state))
+
+ return state, final_replay_buffer_state
+
+
+@jax.jit
+def update_gradients(rng: PRNGKey, state: TrainState, replay_buffer_state,
+ step: int):
+ rng, sample_rng = jax.random.split(rng)
+
+ action_data = state.replay_buffer.sample(replay_buffer_state, sample_rng)
+
+ def update_iteration(i, val):
+ rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = val
+ rng, sample_rng, update_rng = jax.random.split(rng, 3)
+
+ batch = state.replay_buffer.sample(replay_buffer_state, sample_rng)
+
+ print_nan(i, replay_buffer_state)
+ print_nan(i, batch)
+
+ state, q_loss, pi_loss, alpha_loss = train_step(
+ state,
+ data=batch.experience,
+ action_data=batch.experience,
+ update_rng=update_rng,
+ step=i)
+
+ return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data
+
+ rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = jax.lax.fori_loop(
+ step, step + FLAGS.horizon + 1, update_iteration,
+ (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, action_data))
+
+ state = state.target_apply_gradients(step=state.step)
+
+ return rng, state, q_loss, pi_loss, alpha_loss
+
+
+def train(
+ workdir: str, physics_constants: jax_dynamics.CoefficientsType
+) -> train_state.TrainState:
+ """Trains a Soft Actor-Critic controller."""
+ rng = jax.random.key(0)
+ rng, r_rng = jax.random.split(rng)
+
+ run = aim.Run(repo='aim://127.0.0.1:53800')
+
+ run['hparams'] = {
+ 'q_learning_rate': FLAGS.q_learning_rate,
+ 'pi_learning_rate': FLAGS.pi_learning_rate,
+ 'batch_size': FLAGS.batch_size,
+ 'horizon': FLAGS.horizon,
+ 'warmup_steps': FLAGS.warmup_steps,
+ }
+
+ # Setup TrainState
+ rng, init_rng = jax.random.split(rng)
+ q_learning_rate = create_learning_rate_fn(
+ base_learning_rate=FLAGS.q_learning_rate,
+ final_learning_rate=FLAGS.final_q_learning_rate)
+ pi_learning_rate = create_learning_rate_fn(
+ base_learning_rate=FLAGS.pi_learning_rate,
+ final_learning_rate=FLAGS.final_pi_learning_rate)
+ state = create_train_state(
+ init_rng,
+ physics_constants,
+ q_learning_rate=q_learning_rate,
+ pi_learning_rate=pi_learning_rate,
+ )
+ state = restore_checkpoint(state, workdir)
+
+ state_sharding = nn.get_sharding(state, state.mesh)
+ logging.info(state_sharding)
+
+ # TODO(austin): Let the goal change.
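+    # R holds a per-agent goal; the current cost ignores it, so this only
+    # matters once nonzero goals are supported.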
+ R = jax.numpy.hstack((
+ jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=0.9, maxval=1.1),
+ jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=-0.1,
+ maxval=0.1),
+ jax.numpy.zeros((FLAGS.num_agents, 1)),
+ ))
+
+ replay_buffer_state = state.replay_buffer.init({
+ 'observations1':
+ jax.numpy.zeros((physics.NUM_STATES, )),
+ 'observations2':
+ jax.numpy.zeros((physics.NUM_STATES, )),
+ 'actions':
+ jax.numpy.zeros((physics.NUM_OUTPUTS, )),
+ 'rewards':
+ jax.numpy.zeros((1, )),
+ })
+
+ replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
+ state.mesh)
+ logging.info(replay_buffer_state_sharding)
+
+ logging.info('Shape: %s, Sharding %s',
+ replay_buffer_state.experience['observations1'].shape,
+ replay_buffer_state.experience['observations1'].sharding)
+ logging.info('Shape: %s, Sharding %s',
+ replay_buffer_state.experience['observations2'].shape,
+ replay_buffer_state.experience['observations2'].sharding)
+ logging.info('Shape: %s, Sharding %s',
+ replay_buffer_state.experience['actions'].shape,
+ replay_buffer_state.experience['actions'].sharding)
+ logging.info('Shape: %s, Sharding %s',
+ replay_buffer_state.experience['rewards'].shape,
+ replay_buffer_state.experience['rewards'].sharding)
+
+    # Number of simulation steps to collect before starting gradient descent.
+ update_after = FLAGS.batch_size // FLAGS.num_agents
+
+ @partial(jax.jit, donate_argnums=(1, 2))
+ def train_loop(state: TrainState, replay_buffer_state, rng: PRNGKey,
+ step: int):
+ rng, step_rng = jax.random.split(rng)
+ # Collect experience
+ state, replay_buffer_state = collect_experience(
+ state,
+ replay_buffer_state,
+ R,
+ step_rng,
+ step,
+ )
+
+ def nop(rng, state, replay_buffer_state, step):
+ return rng, state.update_step(step=step), 0.0, 0.0, 0.0
+
+ # Train
+ rng, state, q_loss, pi_loss, alpha_loss = jax.lax.cond(
+ step >= update_after, update_gradients, nop, rng, state,
+ replay_buffer_state, step)
+
+ return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss
+
+ for step in range(0, FLAGS.steps, FLAGS.horizon):
+ state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss = train_loop(
+ state, replay_buffer_state, rng, step)
+
+ if FLAGS.debug_nan and has_nan(state.params):
+ logging.fatal('Nan params, aborting')
+
+ logging.info(
+ 'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s',
+ step,
+ q_loss,
+ pi_loss,
+ alpha_loss,
+ q_learning_rate(step),
+ pi_learning_rate(step),
+ jax.numpy.exp(state.params['logalpha']),
+ )
+
+ run.track(
+ {
+ 'q_loss': float(q_loss),
+ 'pi_loss': float(pi_loss),
+ 'alpha_loss': float(alpha_loss),
+ 'alpha': float(jax.numpy.exp(state.params['logalpha']))
+ },
+ step=step)
+
+ if step % 1000 == 0 and step > update_after:
+ # TODO(austin): Simulate a rollout and accumulate the reward. How good are we doing?
+ save_checkpoint(state, workdir)
+
+ return state