Add an experience buffer and tests.

This sets us up to shard across the actor dimension and randomly sample
experience.
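
Rough shape of the API (the 'observation' key and the shapes below are
illustrative; experience_buffer_test.py has the real usage):

    buffer = make_experience_buffer(
        num_agents=2, sample_batch_size=10, length=20)
    # init() takes a single example item to size the storage.
    state = buffer.init({'observation': jax.numpy.zeros((3, ))})
    # add() takes one timestep per agent: shape (num_agents, ...).
    state = buffer.add(state, {'observation': jax.numpy.ones((2, 3))})
    # sample() returns (num_agents, sample_batch_size // num_agents, ...).
    batch = buffer.sample(state, jax.random.key(0))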

Change-Id: I9e5952ddd1765276766025731c89cbaa961830d5
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
new file mode 100644
index 0000000..5e06420
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -0,0 +1,24 @@
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+    name = "experience_buffer",
+    srcs = [
+        "experience_buffer.py",
+    ],
+    deps = [
+        "@pip//flashbax",
+        "@pip//jax",
+    ],
+)
+
+py_test(
+    name = "experience_buffer_test",
+    srcs = [
+        "experience_buffer_test.py",
+    ],
+    deps = [
+        ":experience_buffer",
+        "@pip//flashbax",
+        "@pip//jax",
+    ],
+)
diff --git a/frc971/control_loops/swerve/velocity_controller/experience_buffer.py b/frc971/control_loops/swerve/velocity_controller/experience_buffer.py
new file mode 100644
index 0000000..1af3466
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/experience_buffer.py
@@ -0,0 +1,88 @@
+import flashbax
+from flashbax.buffers.trajectory_buffer import TrajectoryBufferState, Experience, TrajectoryBufferSample
+import jax
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+from jax.experimental import mesh_utils
+from flax.typing import PRNGKey
+
+
+def make_experience_buffer(num_agents, sample_batch_size, length):
+    """Makes a random, sharded, fifo experience buffer."""
+    mesh = jax.sharding.Mesh(
+        devices=mesh_utils.create_device_mesh(len(jax.devices())),
+        axis_names=('batch', ),
+    )
+
+    # Shard all the data along the agents axis.
+    sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('batch'))
+    replicated_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec())
+
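+    # Each agent contributes an equal share of every sampled batch and owns
+    # an equal slice of the buffer's length.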
+    sample_batch_size = sample_batch_size // num_agents
+    trajectory_buffer = flashbax.make_trajectory_buffer(
+        max_length_time_axis=length // num_agents,
+        min_length_time_axis=1,
+        add_batch_size=num_agents,
+        sample_batch_size=sample_batch_size,
+        sample_sequence_length=1,
+        period=1,
+    )
+
+    def add_fn(state: TrajectoryBufferState,
+               batch: Experience) -> TrajectoryBufferState[Experience]:
+        # Add a unit time axis so the shape matches what flashbax expects.
+        batch_size, = flashbax.utils.get_tree_shape_prefix(batch, n_axes=1)
+        expanded_batch = jax.tree.map(
+            lambda x: x.reshape((batch_size, 1, *x.shape[1:])), batch)
+        return trajectory_buffer.add(state, expanded_batch)
+
+    def sample_fn(state: TrajectoryBufferState,
+                  rng_key: PRNGKey) -> TrajectoryBufferSample[Experience]:
+        batch_size, = flashbax.utils.get_tree_shape_prefix(state.experience,
+                                                           n_axes=1)
+
+        # Build up an RNG per actor so we can vmap the randomness.
+        sample_keys = jax.device_put(jax.random.split(rng_key, num=batch_size),
+                                     sharding)
+
+        # Now, randomly select the indices to sample for a single agent.
+        def single_item_indices(rng_key):
+            return jax.random.randint(
+                rng_key, (sample_batch_size, ), 0,
+                jax.lax.select(state.is_full, length // num_agents,
+                               state.current_index))
+
+        # And do them all at once via vmap.
+        item_indices = jax.vmap(single_item_indices)(sample_keys)
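+        # item_indices has shape (num_agents, per-agent sample_batch_size).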
+
+        # Actually sample them now, and vmap to do it for each agent.
+        vmap_sample_item_indices = jax.vmap(
+            lambda item_indices, x: x[item_indices])
+
+        # And apply it to the tree.
+        sampled_batch = jax.tree.map(
+            lambda x: vmap_sample_item_indices(item_indices, x),
+            state.experience)
+
+        return flashbax.buffers.trajectory_buffer.TrajectoryBufferSample(
+            experience=sampled_batch)
+
+    def init_fn(experience: Experience):
+        state = trajectory_buffer.init(experience)
+
+        # Push each element of the tree out across the devices to shard it.
+        sharded_experience = jax.tree.map(
+            lambda x: jax.device_put(x, sharding), state.experience)
+
+        return flashbax.buffers.trajectory_buffer.TrajectoryBufferState(
+            experience=sharded_experience,
+            is_full=jax.device_put(state.is_full, replicated_sharding),
+            current_index=jax.device_put(state.current_index,
+                                         replicated_sharding),
+        )
+
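+    # Swap the sharded add/sample/init implementations into the flashbax
+    # buffer.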
+    return trajectory_buffer.replace(add=add_fn,
+                                     sample=sample_fn,
+                                     init=init_fn)
diff --git a/frc971/control_loops/swerve/velocity_controller/experience_buffer_test.py b/frc971/control_loops/swerve/velocity_controller/experience_buffer_test.py
new file mode 100644
index 0000000..016a56d
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/experience_buffer_test.py
@@ -0,0 +1,126 @@
+import os
+
+os.environ['XLA_FLAGS'] = ' '.join([
+    # Teach it where to find CUDA
+    '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+    # Pretend the host has 2 devices so the sharded code paths get exercised
+    '--xla_force_host_platform_device_count=2',
+    # Dump XLA to /tmp/foo to aid debugging
+    #'--xla_dump_to=/tmp/foo',
+    #'--xla_gpu_enable_command_buffer='
+])
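+# Force the CPU platform so the forced host device count above takes effect.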
+os.environ['JAX_PLATFORMS'] = 'cpu'
+#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+import experience_buffer
+import jax
+import numpy
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+import unittest
+
+
+class TestExperienceBuffer(unittest.TestCase):
+
+    def setUp(self):
+        self.mesh = jax.sharding.Mesh(
+            devices=mesh_utils.create_device_mesh(len(jax.devices())),
+            axis_names=('batch', ),
+        )
+
+        self.sharding = jax.sharding.NamedSharding(self.mesh,
+                                                   PartitionSpec('batch'))
+        self.replicated_sharding = jax.sharding.NamedSharding(
+            self.mesh, PartitionSpec())
+
+        self.num_agents = 2
+        self.sample_batch_size = 10
+        self.length = 20
+
+        self.buffer = experience_buffer.make_experience_buffer(
+            self.num_agents, self.sample_batch_size, self.length)
+
+    def test_shape(self):
+        """Tests that the shapes coming out are right."""
+        buffer_state = self.buffer.init({
+            'key': jax.numpy.zeros((2, )),
+        })
+
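+        # Add one timestep per agent at a time, with a distinct value in
+        # every slot.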
+        for i in range(self.sample_batch_size // self.num_agents):
+            buffer_state = self.buffer.add(
+                buffer_state, {
+                    'key':
+                    jax.numpy.array(
+                        [[i * 4, i * 4 + 1], [i * 4 + 2, i * 4 + 3]],
+                        dtype=jax.numpy.float32)
+                })
+
+        rng = jax.random.key(0)
+
+        for i in range(2):
+            rng, sample_rng = jax.random.split(rng)
+            batch = self.buffer.sample(buffer_state, sample_rng)
+            self.assertEqual(batch.experience['key'].shape, (2, 5, 2))
+
+    def test_randomness(self):
+        """Tests that no sample is more or less likely."""
+        rng = jax.random.key(0)
+
+        # Adds an element to the buffer and, once the buffer is full, accumulates sample counts.
+        def loop(i, val):
+            counts, buffer_state, rng = val
+
+            buffer_state = self.buffer.add(
+                buffer_state,
+                {'key': jax.numpy.array([[-i], [i]], dtype=jax.numpy.int32)})
+
+            rng, sample_rng = jax.random.split(rng)
+
+            def do_count(counts):
+                batch = self.buffer.sample(buffer_state, sample_rng)
+                for a in range(self.num_agents):
+                    for s in range(5):
+                        sampled_index = jax.numpy.abs(batch.experience['key'][
+                            a, s, 0]) % (self.length // self.num_agents)
+                        prior = counts[a, sampled_index]
+                        counts = counts.at[a, sampled_index].set(prior + 1)
+
+                return counts
+
+            # If we are full, start randomly picking and counting.
+            counts = jax.lax.cond(i >= self.length // self.num_agents,
+                                  do_count, lambda counts: counts, counts)
+
+            return counts, buffer_state, rng
+
+        @jax.jit
+        def doit(rng):
+            buffer_state = self.buffer.init({
+                'key':
+                jax.numpy.zeros((1, ), dtype=jax.numpy.int32),
+            })
+
+            counts, buffer_state, rng = jax.lax.fori_loop(
+                0,
+                10000,
+                loop,
+                (jax.numpy.zeros(
+                    (self.num_agents, self.length // self.num_agents)),
+                 buffer_state, rng),
+            )
+            return counts
+
+        # Do this all in jax and JIT it to make it fast; it is many times faster than a Python loop.
+        counts = numpy.array(doit(rng), dtype=numpy.int32)
+        print(counts.min(), counts.max())
+
+        # Make sure things are decently balanced.
+        self.assertGreater(counts.min(), 4800)
+        self.assertLess(counts.max(), 5200)
+
+
+if __name__ == "__main__":
+    unittest.main()