Add flashbax support
This gives us a JAX-based experience replay buffer for Soft Actor-Critic.
Also add jaxtyping so we can write more precise type annotations for JAX arrays.
Change-Id: I4c9f0071f9e912dcab0c883da8d9d7990ed06c46
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/tools/python/requirements.lock.txt b/tools/python/requirements.lock.txt
index 7faf08a..93171a9 100644
--- a/tools/python/requirements.lock.txt
+++ b/tools/python/requirements.lock.txt
@@ -192,7 +192,9 @@
chex==0.1.86 \
--hash=sha256:251c20821092323a3d9c28e1cf80e4a58180978bec368f531949bd9847eee568 \
--hash=sha256:e8b0f96330eba4144659e1617c0f7a57b161e8cbb021e55c6d5056c7378091d1
- # via optax
+ # via
+ # flashbax
+ # optax
click==8.1.7 \
--hash=sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28 \
--hash=sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de
@@ -338,6 +340,10 @@
# optax
# orbax-checkpoint
# tensorflow-datasets
+flashbax==0.1.2 \
+ --hash=sha256:ac50b2580808ce63787da0ae240db14986e3404ade98a33335e6d7a5efe84859 \
+ --hash=sha256:b566ac5a78975b3e0a0a404a8844a26aa45e9cacfaad2829dcbcac2ffb3d5f7a
+ # via -r tools/python/requirements.txt
flask==3.0.3 \
--hash=sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3 \
--hash=sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842
@@ -352,6 +358,7 @@
# via
# -r tools/python/requirements.txt
# clu
+ # flashbax
fonttools==4.53.1 \
--hash=sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122 \
--hash=sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397 \
@@ -527,6 +534,7 @@
# -r tools/python/requirements.txt
# chex
# clu
+ # flashbax
# flax
# optax
# orbax-checkpoint
@@ -568,9 +576,14 @@
# via
# chex
# clu
+ # flashbax
# jax
# optax
# orbax-checkpoint
+jaxtyping==0.2.34 \
+ --hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
+ --hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
+ # via -r tools/python/requirements.txt
jinja2==3.1.4 \
--hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \
--hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d
@@ -969,6 +982,7 @@
# clu
# contourpy
# etils
+ # flashbax
# flax
# h5py
# jax
@@ -1698,6 +1712,7 @@
--hash=sha256:cc315029f49c0f294f0721462c221e0ef4c15360a526cc34392ac81565fd63b8 \
--hash=sha256:f47597209ce11228cfe6b94999f582788aac5571e85c3e8dcaa43b1f07660589
# via
+ # flashbax
# flax
# orbax-checkpoint
termcolor==2.4.0 \
@@ -1735,6 +1750,10 @@
--hash=sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd \
--hash=sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad
# via tensorflow-datasets
+typeguard==2.13.3 \
+ --hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
+ --hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
+ # via jaxtyping
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
@@ -1742,6 +1761,7 @@
# chex
# clu
# etils
+ # flashbax
# flax
# optree
# orbax-checkpoint
diff --git a/tools/python/requirements.txt b/tools/python/requirements.txt
index e7f23c3..2b806f5 100644
--- a/tools/python/requirements.txt
+++ b/tools/python/requirements.txt
@@ -25,9 +25,13 @@
bokeh
tabulate
+flask
+
casadi>=3.6.6
+# ML libraries
jax[cuda12]
+jaxtyping
optax
flax
clu
@@ -38,4 +42,5 @@
tensorflow
tensorflow_datasets
-flask
+# Experience buffer for reinforcement learning
+flashbax
diff --git a/tools/python/whl_overrides.json b/tools/python/whl_overrides.json
index c3d9a89..586a9d3 100644
--- a/tools/python/whl_overrides.json
+++ b/tools/python/whl_overrides.json
@@ -63,6 +63,10 @@
"sha256": "6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/etils-1.5.2-py3-none-any.whl"
},
+ "flashbax==0.1.2": {
+ "sha256": "ac50b2580808ce63787da0ae240db14986e3404ade98a33335e6d7a5efe84859",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/flashbax-0.1.2-py3-none-any.whl"
+ },
"flask==3.0.3": {
"sha256": "34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/flask-3.0.3-py3-none-any.whl"
@@ -143,6 +147,10 @@
"sha256": "f74a6b0e09df4b5e2ee399ebb9f0e01190e26e84ccb0a758fadb516415c07f18",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl"
},
+ "jaxtyping==0.2.34": {
+ "sha256": "2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jaxtyping-0.2.34-py3-none-any.whl"
+ },
"jinja2==3.1.4": {
"sha256": "bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jinja2-3.1.4-py3-none-any.whl"
@@ -443,6 +451,10 @@
"sha256": "90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tqdm-4.66.5-py3-none-any.whl"
},
+ "typeguard==2.13.3": {
+ "sha256": "5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/typeguard-2.13.3-py3-none-any.whl"
+ },
"typing_extensions==4.12.2": {
"sha256": "04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/typing_extensions-4.12.2-py3-none-any.whl"