Add an experience buffer and tests.
This sets us up to shard the stored experience across the actor
dimension and to sample it uniformly at random.
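
For reference, a minimal usage sketch (the sizes and the 'observation'
key below are illustrative, not part of this change):

    import jax
    import experience_buffer

    # 4 agents, 8 samples per call (2 per agent), 40 stored items
    # total (10 per agent).
    buffer = experience_buffer.make_experience_buffer(
        num_agents=4, sample_batch_size=8, length=40)

    # Initialize from a single per-agent experience to fix shapes
    # and dtypes.
    state = buffer.init({'observation': jax.numpy.zeros((3,))})

    # Add one timestep for every agent at once; the leading axis is
    # the agent axis.
    state = buffer.add(state, {'observation': jax.numpy.ones((4, 3))})

    # Sample; each agent draws uniformly from its own shard.
    batch = buffer.sample(state, jax.random.key(0))
    assert batch.experience['observation'].shape == (4, 2, 3)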
Change-Id: I9e5952ddd1765276766025731c89cbaa961830d5
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
new file mode 100644
index 0000000..5e06420
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -0,0 +1,24 @@
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+ name = "experience_buffer",
+ srcs = [
+ "experience_buffer.py",
+ ],
+ deps = [
+ "@pip//flashbax",
+ "@pip//jax",
+ ],
+)
+
+py_test(
+ name = "experience_buffer_test",
+ srcs = [
+ "experience_buffer_test.py",
+ ],
+ deps = [
+ ":experience_buffer",
+ "@pip//flashbax",
+ "@pip//jax",
+ ],
+)
diff --git a/frc971/control_loops/swerve/velocity_controller/experience_buffer.py b/frc971/control_loops/swerve/velocity_controller/experience_buffer.py
new file mode 100644
index 0000000..1af3466
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/experience_buffer.py
@@ -0,0 +1,88 @@
+import flashbax
+from flashbax.buffers.trajectory_buffer import TrajectoryBufferState, Experience, TrajectoryBufferSample
+import jax
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+from jax.experimental import mesh_utils
+from flax.typing import PRNGKey
+
+
+def make_experience_buffer(num_agents, sample_batch_size, length):
+ """Makes a random, sharded, fifo experience buffer."""
+ mesh = jax.sharding.Mesh(
+ devices=mesh_utils.create_device_mesh(len(jax.devices())),
+ axis_names=('batch', ),
+ )
+
+ # Shard all the data along the agents axis.
+ sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('batch'))
+ replicated_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec())
+
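+ # Each agent gets an even share of the sample batch and of the total length.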
+ sample_batch_size = sample_batch_size // num_agents
+ trajectory_buffer = flashbax.make_trajectory_buffer(
+ max_length_time_axis=length // num_agents,
+ min_length_time_axis=1,
+ add_batch_size=num_agents,
+ sample_batch_size=sample_batch_size,
+ sample_sequence_length=1,
+ period=1,
+ )
+
+ def add_fn(state: TrajectoryBufferState,
+ batch: Experience) -> TrajectoryBufferState[Experience]:
+ # Add a length-1 time axis so the batch matches the [batch, time, ...] layout flashbax expects.
+ batch_size, = flashbax.utils.get_tree_shape_prefix(batch, n_axes=1)
+ expanded_batch = jax.tree.map(
+ lambda x: x.reshape((batch_size, 1, *x.shape[1:])), batch)
+ return trajectory_buffer.add(state, expanded_batch)
+
+ def sample_fn(state: TrajectoryBufferState,
+ rng_key: PRNGKey) -> TrajectoryBufferSample[Experience]:
+ batch_size, = flashbax.utils.get_tree_shape_prefix(state.experience,
+ n_axes=1)
+
+ # Build up an RNG per actor so we can vmap the randomness.
+ sample_keys = jax.device_put(jax.random.split(rng_key, num=batch_size),
+ sharding)
+
+ # Now, randomly select the indices to sample for a single agent.
+ def single_item_indices(rng_key):
+ return jax.random.randint(
+ rng_key, (sample_batch_size, ), 0,
+ jax.lax.select(state.is_full, length // num_agents,
+ state.current_index))
+
+ # And do them all at once via vmap.
+ item_indices = jax.vmap(single_item_indices)(sample_keys)
+
+ # Actually sample them now, and vmap to do it for each agent.
+ vmap_sample_item_indices = jax.vmap(
+ lambda item_indices, x: x[item_indices])
+
+ # And apply it to the tree.
+ sampled_batch = jax.tree.map(
+ lambda x: vmap_sample_item_indices(item_indices, x),
+ state.experience)
+
+ return flashbax.buffers.trajectory_buffer.TrajectoryBufferSample(
+ experience=sampled_batch)
+
+ def init_fn(experience: Experience):
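+ # Build the default flashbax state first; its arrays start out unsharded.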
+ state = trajectory_buffer.init(experience)
+
+ # Push each element of the tree out across the devices to shard it.
+ sharded_experience = jax.tree_util.tree_map(
+ lambda x: jax.device_put(x, sharding), state.experience)
+
+ return flashbax.buffers.trajectory_buffer.TrajectoryBufferState(
+ experience=sharded_experience,
+ is_full=jax.device_put(state.is_full, replicated_sharding),
+ current_index=jax.device_put(state.current_index,
+ replicated_sharding),
+ )
+
+ return trajectory_buffer.replace(add=add_fn,
+ sample=sample_fn,
+ init=init_fn)
diff --git a/frc971/control_loops/swerve/velocity_controller/experience_buffer_test.py b/frc971/control_loops/swerve/velocity_controller/experience_buffer_test.py
new file mode 100644
index 0000000..016a56d
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/experience_buffer_test.py
@@ -0,0 +1,125 @@
+import os
+
+os.environ['XLA_FLAGS'] = ' '.join([
+ # Teach it where to find CUDA
+ '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+ # Pretend the host has 2 devices so sharding gets exercised
+ '--xla_force_host_platform_device_count=2',
+ # Dump XLA to /tmp/foo to aid debugging
+ #'--xla_dump_to=/tmp/foo',
+ #'--xla_gpu_enable_command_buffer='
+])
+os.environ['JAX_PLATFORMS'] = 'cpu'
+#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+import experience_buffer
+import jax
+import numpy
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+import unittest
+
+
+class TestExperienceBuffer(unittest.TestCase):
+
+ def setUp(self):
+ self.mesh = jax.sharding.Mesh(
+ devices=mesh_utils.create_device_mesh(len(jax.devices())),
+ axis_names=('batch', ),
+ )
+
+ self.sharding = jax.sharding.NamedSharding(self.mesh,
+ PartitionSpec('batch'))
+ self.replicated_sharding = jax.sharding.NamedSharding(
+ self.mesh, PartitionSpec())
+
+ self.num_agents = 2
+ self.sample_batch_size = 10
+ self.length = 20
+
+ self.buffer = experience_buffer.make_experience_buffer(
+ self.num_agents, self.sample_batch_size, self.length)
+
+ def test_shape(self):
+ """Tests that the shapes coming out are right."""
+ buffer_state = self.buffer.init({
+ 'key': jax.numpy.zeros((2, )),
+ })
+
+ for i in range(self.sample_batch_size // self.num_agents):
+ buffer_state = self.buffer.add(
+ buffer_state, {
+ 'key':
+ jax.numpy.array(
+ [[i * 4, i * 4 + 1], [i * 4 + 2, i * 4 + 3]],
+ dtype=jax.numpy.float32)
+ })
+
+ rng = jax.random.key(0)
+
+ for i in range(2):
+ rng, sample_rng = jax.random.split(rng)
+ batch = self.buffer.sample(buffer_state, sample_rng)
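+ # Expect (num_agents, per-agent samples, feature size) == (2, 5, 2).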
+ self.assertEqual(batch.experience['key'].shape, (2, 5, 2))
+
+ def test_randomness(self):
+ """Tests that no sample is more or less likely."""
+ rng = jax.random.key(0)
+
+ # Adds an element to the buffer and accumulates sample counts.
+ def loop(i, val):
+ counts, buffer_state, rng = val
+
+ buffer_state = self.buffer.add(
+ buffer_state,
+ {'key': jax.numpy.array([[-i], [i]], dtype=jax.numpy.int32)})
+
+ rng, sample_rng = jax.random.split(rng)
+
+ def do_count(counts):
+ batch = self.buffer.sample(buffer_state, sample_rng)
+ for a in range(self.num_agents):
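+ # 5 == sample_batch_size // num_agents, the per-agent sample count.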
+ for s in range(5):
+ sampled_agent = jax.numpy.abs(batch.experience['key'][
+ a, s, 0]) % (self.length // self.num_agents)
+ prior = counts[a, sampled_agent]
+ counts = counts.at[a, sampled_agent].set(prior + 1)
+
+ return counts
+
+ # If we are full, start randomly picking and counting.
+ counts = jax.lax.cond(i >= self.length // self.num_agents,
+ do_count, lambda counts: counts, counts)
+
+ return counts, buffer_state, rng
+
+ @jax.jit
+ def doit(rng):
+ buffer_state = self.buffer.init({
+ 'key':
+ jax.numpy.zeros((1, ), dtype=jax.numpy.int32),
+ })
+
+ counts, buffer_state, rng = jax.lax.fori_loop(
+ 0,
+ 10000,
+ loop,
+ (jax.numpy.zeros(
+ (self.num_agents, self.length // self.num_agents)),
+ buffer_state, rng),
+ )
+ return counts
+
+ # Run the whole loop under jax.jit to make it fast; many times faster than a Python loop.
+ counts = numpy.array(doit(rng), dtype=numpy.int32)
+ print(counts.min(), counts.max())
+
+ # Make sure things are decently balanced.
+ self.assertGreater(counts.min(), 4800)
+ self.assertLess(counts.max(), 5200)
+
+
+if __name__ == "__main__":
+ unittest.main()