Add flashbax support

This gives us a JAX-native experience replay buffer for Soft Actor-Critic (SAC).

Also add jaxtyping so we can annotate array shapes and dtypes in type hints.

Change-Id: I4c9f0071f9e912dcab0c883da8d9d7990ed06c46
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/tools/python/requirements.lock.txt b/tools/python/requirements.lock.txt
index 7faf08a..93171a9 100644
--- a/tools/python/requirements.lock.txt
+++ b/tools/python/requirements.lock.txt
@@ -192,7 +192,9 @@
 chex==0.1.86 \
     --hash=sha256:251c20821092323a3d9c28e1cf80e4a58180978bec368f531949bd9847eee568 \
     --hash=sha256:e8b0f96330eba4144659e1617c0f7a57b161e8cbb021e55c6d5056c7378091d1
-    # via optax
+    # via
+    #   flashbax
+    #   optax
 click==8.1.7 \
     --hash=sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28 \
     --hash=sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de
@@ -338,6 +340,10 @@
     #   optax
     #   orbax-checkpoint
     #   tensorflow-datasets
+flashbax==0.1.2 \
+    --hash=sha256:ac50b2580808ce63787da0ae240db14986e3404ade98a33335e6d7a5efe84859 \
+    --hash=sha256:b566ac5a78975b3e0a0a404a8844a26aa45e9cacfaad2829dcbcac2ffb3d5f7a
+    # via -r tools/python/requirements.txt
 flask==3.0.3 \
     --hash=sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3 \
     --hash=sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842
@@ -352,6 +358,7 @@
     # via
     #   -r tools/python/requirements.txt
     #   clu
+    #   flashbax
 fonttools==4.53.1 \
     --hash=sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122 \
     --hash=sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397 \
@@ -527,6 +534,7 @@
     #   -r tools/python/requirements.txt
     #   chex
     #   clu
+    #   flashbax
     #   flax
     #   optax
     #   orbax-checkpoint
@@ -568,9 +576,14 @@
     # via
     #   chex
     #   clu
+    #   flashbax
     #   jax
     #   optax
     #   orbax-checkpoint
+jaxtyping==0.2.34 \
+    --hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
+    --hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
+    # via -r tools/python/requirements.txt
 jinja2==3.1.4 \
     --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \
     --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d
@@ -969,6 +982,7 @@
     #   clu
     #   contourpy
     #   etils
+    #   flashbax
     #   flax
     #   h5py
     #   jax
@@ -1698,6 +1712,7 @@
     --hash=sha256:cc315029f49c0f294f0721462c221e0ef4c15360a526cc34392ac81565fd63b8 \
     --hash=sha256:f47597209ce11228cfe6b94999f582788aac5571e85c3e8dcaa43b1f07660589
     # via
+    #   flashbax
     #   flax
     #   orbax-checkpoint
 termcolor==2.4.0 \
@@ -1735,6 +1750,10 @@
     --hash=sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd \
     --hash=sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad
     # via tensorflow-datasets
+typeguard==2.13.3 \
+    --hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
+    --hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
+    # via jaxtyping
 typing-extensions==4.12.2 \
     --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
     --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
@@ -1742,6 +1761,7 @@
     #   chex
     #   clu
     #   etils
+    #   flashbax
     #   flax
     #   optree
     #   orbax-checkpoint