Add support for JAX
This opens up more exciting solvers, and better optimization than casadi
Change-Id: I824d9ec51c562f0a06b8d2ec1fc27a0e1c8a2ad9
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/tools/python/requirements.lock.txt b/tools/python/requirements.lock.txt
index 0bb13ca..4398d95 100644
--- a/tools/python/requirements.lock.txt
+++ b/tools/python/requirements.lock.txt
@@ -147,8 +147,49 @@
--hash=sha256:d5059f9f1e8e41f80e9c56c2ee58811450c31984dfa625329ffd7c0dad88a73b \
--hash=sha256:d84d17e21670ec07990e1044a99efe8d615d860fd176fc29ef5c306068fda313
# via
+ # jax
# markdown
# mkdocs
+jax[cuda12]==0.4.30 \
+ --hash=sha256:289b30ae03b52f7f4baf6ef082a9f4e3e29c1080e22d13512c5ecf02d5f1a55b \
+ --hash=sha256:94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577
+ # via -r tools/python/requirements.txt
+jax-cuda12-pjrt==0.4.30 \
+ --hash=sha256:895d0198ad99638fcaf976c47592e2a543eef79ea15fabd24a402d055390c328 \
+ --hash=sha256:c36fb1e0c236563bf3a87e70f4d1ab28a31d7cf5d722c9ede30c4172116e8bcb
+ # via jax-cuda12-plugin
+jax-cuda12-plugin[with_cuda]==0.4.30 \
+ --hash=sha256:04546765f0d21afa8a67bee3a032443ee5ac79f2d53b4137674f160c66ec99e4 \
+ --hash=sha256:14aad59e427e16e077da42c42aba58c98b7dfc586180a8df1528dfca6a9227ad \
+ --hash=sha256:8021452c32d39fcf7f639caea4529b2091bf36b10a3714b02dd3429ab39435e7 \
+ --hash=sha256:85f85e4fd70e8f0af6e97707b370ff5928ede08c5a1e22bca7abc05957d5812f \
+ --hash=sha256:a51c753b3064d8de6e769199ac44d26ce8bb616acef92b2b580d1d082ed833e8 \
+ --hash=sha256:cb8edccdce358451205f689e3536272200761c625c8e8059ab10523984cf8b61 \
+ --hash=sha256:d3f47541c9550b048c3563731259d0566c447f7aad12998f84b96c958fd2b629 \
+ --hash=sha256:d8d196241b9253ecb1144a4409b5deacbb9771624f097b2bbf025da3c7d8f4f8
+ # via jax
+jaxlib==0.4.30 \
+ --hash=sha256:0a3850e76278038e21685975a62b622bcf3708485f13125757a0561ee4512940 \
+ --hash=sha256:11602d5556e8baa2f16314c36518e9be4dfae0c2c256a361403fb29dc9dc79a4 \
+ --hash=sha256:16b2ab18ea90d2e15941bcf45de37afc2f289a029129c88c8d7aba0404dd0043 \
+ --hash=sha256:28e032c9b394ab7624d89b0d9d3bbcf4d1d71694fe8b3e09d3fe64122eda7b0c \
+ --hash=sha256:3a2e2c11c179f8851a72249ba1ae40ae817dfaee9877d23b3b8f7c6b7a012f76 \
+ --hash=sha256:3d31e01191ce8052bd611aaf16ff967d8d0ec0b63f1ea4b199020cecb248d667 \
+ --hash=sha256:4bdfda6a3c7a2b0cc0a7131009eb279e98ca4a6f25679fabb5302dd135a5e349 \
+ --hash=sha256:54987e97a22db70f3829b437b9329e4799d653634bacc8b398554d3b90c76b2a \
+ --hash=sha256:57090d33477fd0f0c99dc686274882ea75c44c7d712ae42dd2460b10f896131d \
+ --hash=sha256:7704db5962b32a2be3cc07185433cbbcc94ed90ee50c84021a3f8a1ecfd66ee3 \
+ --hash=sha256:974998cd8a78550402e6c09935c1f8d850cad9cc19ccd7488bde45b6f7f99c12 \
+ --hash=sha256:a56678b28f96b524ded6da8ef4b38e72a532356d139cfd434da804abf4234e14 \
+ --hash=sha256:b7079a5b1ab6864a7d4f2afaa963841451186d22c90f39719a3ff85735ce3915 \
+ --hash=sha256:bfb5d85b69c29c3c6e8051a0ea715ac1e532d6e54494c8d9c3813dcc00deac30 \
+ --hash=sha256:c40856e28f300938c6824ab1a615166193d6997dec946578823f6d402ad454e5 \
+ --hash=sha256:c58a8071c4e00898282118169f6a5a97eb15a79c2897858f3a732b17891c99ab \
+ --hash=sha256:d83f36ef42a403bbf7c7f2da526b34ba286988e170f4df5e58b3bb735417868c \
+ --hash=sha256:e93eb0646b41ba213252b51b0b69096b9cd1d81a35ea85c9d06663b5d11efe45 \
+ --hash=sha256:ea3a00005faafbe3c18b178d3b534208b3b4027b2be6230227e7b87ce399fc29 \
+ --hash=sha256:f74a6b0e09df4b5e2ee399ebb9f0e01190e26e84ccb0a758fadb516415c07f18
+ # via jax
jinja2==3.1.2 \
--hash=sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852 \
--hash=sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61
@@ -323,6 +364,27 @@
--hash=sha256:8947af423a6d0facf41ea1195b8e1e8c85ad94ac95ae307fe11232e0424b11c5 \
--hash=sha256:c8856a832c1e56702577023cd64cc5f84948280c1c0fcc6af4cd39006ea6aa8c
# via -r tools/python/requirements.txt
+ml-dtypes==0.4.0 \
+ --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \
+ --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \
+ --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \
+ --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \
+ --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \
+ --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \
+ --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \
+ --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \
+ --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \
+ --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \
+ --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \
+ --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \
+ --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \
+ --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \
+ --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \
+ --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
+ --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
+ # via
+ # jax
+ # jaxlib
mpmath==1.3.0 \
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
--hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
@@ -358,13 +420,65 @@
# bokeh
# casadi
# contourpy
+ # jax
+ # jaxlib
# matplotlib
+ # ml-dtypes
# opencv-python
+ # opt-einsum
# osqp
# pandas
# qdldl
# scipy
# shapely
+nvidia-cublas-cu12==12.6.1.4 \
+ --hash=sha256:5dd125ece5469dbdceebe2e9536ad8fc4abd38aa394a7ace42fc8a930a1e81e3 \
+ --hash=sha256:5e5d384583d72ac364064ced3dd92a5caa59a8a57568595c9f82e83d255b2481 \
+ --hash=sha256:c25ab29026a265d46c1063b5fb3cb9440f5f2eb88041c6b7c6711bcb3361789f
+ # via jax-cuda12-plugin
+nvidia-cuda-cupti-cu12==12.6.68 \
+ --hash=sha256:13408a021727de6473d138a0c5e8080b23437f761508e2b11d2530fed24f4ea0 \
+ --hash=sha256:5ad6a1fcfcb42c8628f7e547079575116d428d0cb3b4fab98362e08a9ea0b842 \
+ --hash=sha256:7487f59d73a090bf661fa8da84bad649f019a249dbac3a6cc58b039e15c28d91
+ # via jax-cuda12-plugin
+nvidia-cuda-nvcc-cu12==12.6.68 \
+ --hash=sha256:3999aa4a42ac8723c09a8aafd06bc4a6ec1a0b05c53bc96c8d6cf195e84f6935 \
+ --hash=sha256:9c0a18d76f0d1de99ba1d5fd70cffb32c0249e4abc42de9c0504e34d90ff421c \
+ --hash=sha256:d2faca18a3d5dd48865ad259262f7da43358d0940d53554026102d70c14ea2f9
+ # via jax-cuda12-plugin
+nvidia-cuda-runtime-cu12==12.6.68 \
+ --hash=sha256:3d421aa4ff608b2d8c650e0208a0fb28b4b6792a35b42bd2769d802149f85238 \
+ --hash=sha256:806b51a1dd266aac41ae09ca6142faee1686d119ced006cb9b76dfd331c75ab8 \
+ --hash=sha256:846987485889786d257f6d7bdcf7544a36452936514e20dd710527b896c0fe12
+ # via jax-cuda12-plugin
+nvidia-cudnn-cu12==9.3.0.75 \
+ --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \
+ --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \
+ --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6
+ # via jax-cuda12-plugin
+nvidia-cufft-cu12==11.2.6.59 \
+ --hash=sha256:251df5b20b11bb2af6d3964ac01b85a94094222d081c90f27e8df3bf533d3257 \
+ --hash=sha256:2ea19d2101d309228daeb1045397d8e28eb3ec1ec45f226bdc12ac6e9c1c59d4 \
+ --hash=sha256:998bbd77799dc427f9c48e5d57a316a7370d231fd96121fb018b370f67fc4909
+ # via jax-cuda12-plugin
+nvidia-cusolver-cu12==11.6.4.69 \
+ --hash=sha256:07d9a1fc00049cba615ec3475eca5320943df3175b05d358d2559286bb7f1fa6 \
+ --hash=sha256:1c799e473bbd369a34490322ebf6bbf8862831e199f5d6da6868d5f6f7332fff \
+ --hash=sha256:ec0419e653587d25f399736eaf1d26a6562d8bcaeb44b1e3daef87e13b669963
+ # via jax-cuda12-plugin
+nvidia-cusparse-cu12==12.5.3.3 \
+ --hash=sha256:76030755020d3a969b40273f43b8c496c4e122ee2a01fd724cf1398421bcadd8 \
+ --hash=sha256:bfa07cb86edfd6112dbead189c182a924fd9cb3e48ae117b1ac4cd3084078bc0 \
+ --hash=sha256:c9d0ff7870672b1e0c7ffc1e47e9b87b51e38ad32ae39e05f08fc68933983a80
+ # via jax-cuda12-plugin
+nvidia-nccl-cu12==2.22.3 \
+ --hash=sha256:f9f5e03c00269dee2cd1aa57019f9a024478a74ae6e9b32d5341c849fe6f6302
+ # via jax-cuda12-plugin
+nvidia-nvjitlink-cu12==12.6.68 \
+ --hash=sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab \
+ --hash=sha256:a55744c98d70317c5e23db14866a8cc2b733f7324509e941fc96276f9f37801d \
+ --hash=sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b
+ # via jax-cuda12-plugin
opencv-python==4.6.0.66 \
--hash=sha256:0dc82a3d8630c099d2f3ac1b1aabee164e8188db54a786abb7a4e27eba309440 \
--hash=sha256:5af8ba35a4fcb8913ffb86e92403e9a656a4bff4a645d196987468f0f8947875 \
@@ -374,6 +488,10 @@
--hash=sha256:e6e448b62afc95c5b58f97e87ef84699e6607fe5c58730a03301c52496005cae \
--hash=sha256:f482e78de6e7b0b060ff994ffd859bddc3f7f382bb2019ef157b0ea8ca8712f5
# via -r tools/python/requirements.txt
+opt-einsum==3.3.0 \
+ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
+ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
+ # via jax
osqp==0.6.2.post8 \
--hash=sha256:02175818a0b1715ae0aab88a23678a44b269587af0ef655457042ca69a45eddd \
--hash=sha256:0a6e36151d088a9196b24fffc6b1d3a8bf79dcf9e7a5bd5f9c76c9ee1e019edf \
@@ -642,6 +760,8 @@
--hash=sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027
# via
# -r tools/python/requirements.txt
+ # jax
+ # jaxlib
# osqp
# qdldl
shapely==2.0.0 \
diff --git a/tools/python/requirements.txt b/tools/python/requirements.txt
index db0ec26..09121c8 100644
--- a/tools/python/requirements.txt
+++ b/tools/python/requirements.txt
@@ -26,3 +26,5 @@
tabulate
casadi>=3.6.6
+
+jax[cuda12]
diff --git a/tools/python/runtime_binary.sh b/tools/python/runtime_binary.sh
index d251a0b..8498408 100755
--- a/tools/python/runtime_binary.sh
+++ b/tools/python/runtime_binary.sh
@@ -38,5 +38,7 @@
exit 1
fi
+export XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda
+
# Prevent Python from importing the host's installed packages.
exec "$PYTHON_BIN" -sS "$@"
diff --git a/tools/python/whl_overrides.json b/tools/python/whl_overrides.json
index 4de8223..5dd81b7 100644
--- a/tools/python/whl_overrides.json
+++ b/tools/python/whl_overrides.json
@@ -51,6 +51,22 @@
"sha256": "d84d17e21670ec07990e1044a99efe8d615d860fd176fc29ef5c306068fda313",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/importlib_metadata-5.1.0-py3-none-any.whl"
},
+ "jax==0.4.30": {
+ "sha256": "289b30ae03b52f7f4baf6ef082a9f4e3e29c1080e22d13512c5ecf02d5f1a55b",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jax-0.4.30-py3-none-any.whl"
+ },
+ "jax_cuda12_pjrt==0.4.30": {
+ "sha256": "895d0198ad99638fcaf976c47592e2a543eef79ea15fabd24a402d055390c328",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jax_cuda12_pjrt-0.4.30-py3-none-manylinux2014_x86_64.whl"
+ },
+ "jax_cuda12_plugin==0.4.30": {
+ "sha256": "d8d196241b9253ecb1144a4409b5deacbb9771624f097b2bbf025da3c7d8f4f8",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jax_cuda12_plugin-0.4.30-cp39-cp39-manylinux2014_x86_64.whl"
+ },
+ "jaxlib==0.4.30": {
+ "sha256": "f74a6b0e09df4b5e2ee399ebb9f0e01190e26e84ccb0a758fadb516415c07f18",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl"
+ },
"jinja2==3.1.2": {
"sha256": "6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/Jinja2-3.1.2-py3-none-any.whl"
@@ -79,6 +95,10 @@
"sha256": "c8856a832c1e56702577023cd64cc5f84948280c1c0fcc6af4cd39006ea6aa8c",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/mkdocs-1.4.2-py3-none-any.whl"
},
+ "ml_dtypes==0.4.0": {
+ "sha256": "f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+ },
"mpmath==1.3.0": {
"sha256": "a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/mpmath-1.3.0-py3-none-any.whl"
@@ -87,10 +107,54 @@
"sha256": "d7806500e4f5bdd04095e849265e55de20d8cc4b661b038957354327f6d9b295",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/numpy-1.25.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
},
+ "nvidia_cublas_cu12==12.6.1.4": {
+ "sha256": "5dd125ece5469dbdceebe2e9536ad8fc4abd38aa394a7ace42fc8a930a1e81e3",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cublas_cu12-12.6.1.4-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_cuda_cupti_cu12==12.6.68": {
+ "sha256": "13408a021727de6473d138a0c5e8080b23437f761508e2b11d2530fed24f4ea0",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cuda_cupti_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_cuda_nvcc_cu12==12.6.68": {
+ "sha256": "3999aa4a42ac8723c09a8aafd06bc4a6ec1a0b05c53bc96c8d6cf195e84f6935",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cuda_nvcc_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_cuda_runtime_cu12==12.6.68": {
+ "sha256": "846987485889786d257f6d7bdcf7544a36452936514e20dd710527b896c0fe12",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cuda_runtime_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_cudnn_cu12==9.3.0.75": {
+ "sha256": "9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cudnn_cu12-9.3.0.75-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_cufft_cu12==11.2.6.59": {
+ "sha256": "251df5b20b11bb2af6d3964ac01b85a94094222d081c90f27e8df3bf533d3257",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cufft_cu12-11.2.6.59-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_cusolver_cu12==11.6.4.69": {
+ "sha256": "ec0419e653587d25f399736eaf1d26a6562d8bcaeb44b1e3daef87e13b669963",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cusolver_cu12-11.6.4.69-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_cusparse_cu12==12.5.3.3": {
+ "sha256": "76030755020d3a969b40273f43b8c496c4e122ee2a01fd724cf1398421bcadd8",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_cusparse_cu12-12.5.3.3-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_nccl_cu12==2.22.3": {
+ "sha256": "f9f5e03c00269dee2cd1aa57019f9a024478a74ae6e9b32d5341c849fe6f6302",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_nccl_cu12-2.22.3-py3-none-manylinux2014_x86_64.whl"
+ },
+ "nvidia_nvjitlink_cu12==12.6.68": {
+ "sha256": "125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl"
+ },
"opencv_python==4.6.0.66": {
"sha256": "dbdc84a9b4ea2cbae33861652d25093944b9959279200b7ae0badd32439f74de",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
},
+ "opt_einsum==3.3.0": {
+ "sha256": "2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/opt_einsum-3.3.0-py3-none-any.whl"
+ },
"osqp==0.6.2.post8": {
"sha256": "22724b3ac4eaf17582e3ff35cb6660c026e71138f27fc21dbae4f1dc60904c64",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/osqp-0.6.2.post8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl"