Merge "incorporate arm into superstructure.cc by adding arm constants" into main
diff --git a/WORKSPACE b/WORKSPACE
index ad59314..6fd635e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -702,14 +702,16 @@
cc_library(
name = 'api-cpp',
visibility = ['//visibility:public'],
- hdrs = glob(['ctre/phoenix6/**/*.hpp']),
+ hdrs = glob(['ctre/phoenix6/**/*.hpp', 'ctre/unit/**/*.h']),
includes = ["."],
- deps = ["@//third_party/allwpilib/wpimath"],
+ deps = ["@//third_party/allwpilib/wpimath",
+ "@ctre_phoenix6_tools_headers//:tools",
+ ],
)
""",
- sha256 = "1b9295ac868d619d0263f285b372a22297798e45c8fdeb83ca99154a097c5c98",
+ sha256 = "67fdd9a4bf275c69666bbc8bf38312eea8a85bf9fda3901d02af3a18136ffb3e",
urls = [
- "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.1.0/api-cpp-24.1.0-headers.zip",
+ "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.50.0-alpha-2/api-cpp-24.50.0-alpha-2-headers.zip",
],
)
@@ -731,9 +733,9 @@
target_compatible_with = ['@//tools/platforms/hardware:roborio'],
)
""",
- sha256 = "51f4921f8995e3e80ba347a90f6fa5c0cb7d0e438923a1a736554f263e2d5b8e",
+ sha256 = "00da14f437cfeb2c344674dee59fe70477a3a0d612eb93a1441188dc94a5136f",
urls = [
- "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.1.0/api-cpp-24.1.0-linuxathena.zip",
+ "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/api-cpp/24.50.0-alpha-2/api-cpp-24.50.0-alpha-2-linuxathena.zip",
],
)
@@ -743,12 +745,12 @@
cc_library(
name = 'tools',
visibility = ['//visibility:public'],
- hdrs = glob(['ctre/**/*.h', 'ctre/**/*.hpp']),
+ hdrs = glob(['ctre/**/*.h', 'ctre/phoenix/**/*.hpp', 'ctre/phoenix6/**/*.hpp']),
)
""",
- sha256 = "72be3e50205e2546736361e3aba9c0de5d5318b263c195600f9b558c3d92cb9e",
+ sha256 = "77624291ec19a03c9e068347ef800e435782444722793d56c9e39f6108da33a8",
urls = [
- "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.1.0/tools-24.1.0-headers.zip",
+ "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.50.0-alpha-2/tools-24.50.0-alpha-2-headers.zip",
],
)
@@ -770,9 +772,9 @@
target_compatible_with = ['@//tools/platforms/hardware:roborio'],
)
""",
- sha256 = "76d30cb8c0988eb6accab10e77ceae492782edb64c06f7df628b8c161cf0220a",
+ sha256 = "217bcd6aecb71224fd32725f857da58e61a6ea451ca07f85fb55e49f83dbf113",
urls = [
- "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.1.0/tools-24.1.0-linuxathena.zip",
+ "https://maven.ctr-electronics.com/release/com/ctre/phoenix6/tools/24.50.0-alpha-2/tools-24.50.0-alpha-2-linuxathena.zip",
],
)
@@ -782,12 +784,13 @@
cc_library(
name = 'api-cpp',
visibility = ['//visibility:public'],
- hdrs = glob(['ctre/phoenix/**/*.h']),
+ hdrs = glob(['ctre/phoenix/**/*.h', 'ctre/unit/**/*.h']),
+ includes = ["."]
)
""",
- sha256 = "d2790e2495a59a3b14ab4b0fd03c2e25292d6874168f5642165acb0d1bab2f99",
+ sha256 = "4fba20441ea1d61f8487897682c723a48ee2c2741946eab4062b4248434c0afc",
urls = [
- "https://maven.ctr-electronics.com/release/com/ctre/phoenix/api-cpp/5.33.0/api-cpp-5.33.0-headers.zip",
+ "https://maven.ctr-electronics.com/release/com/ctre/phoenix/api-cpp/5.34.0-alpha-1/api-cpp-5.34.0-alpha-1-headers.zip",
],
)
@@ -824,9 +827,9 @@
hdrs = glob(['ctre/phoenix/**/*.h']),
)
""",
- sha256 = "56cb8272d623721560eea360730d2508f4f365e2ac1060aec5fbd546cffc0771",
+ sha256 = "cd39f32341037a7ec1074710044b332db6ca607dfd548b5f951c75f7a8506ab5",
urls = [
- "https://maven.ctr-electronics.com/release/com/ctre/phoenix/cci/5.33.0/cci-5.33.0-headers.zip",
+ "https://maven.ctr-electronics.com/release/com/ctre/phoenix/cci/5.34.0-alpha-1/cci-5.34.0-alpha-1-headers.zip",
],
)
@@ -867,6 +870,7 @@
target_compatible_with = ['@platforms//cpu:arm64'],
)
+
# TODO(max): Use cc_import once they add a defines property.
# See: https://github.com/bazelbuild/bazel/issues/19753
cc_library(
@@ -891,9 +895,9 @@
],
)
""",
- sha256 = "635b19f019d1283749fb23a95ad9e13ddd9d5013cff4c838303d898e7d9556bd",
+ sha256 = "0f1312f39eacc490fb253198c2d0e61e48ae00eff6a87cfd362358b1ad36a930",
urls = [
- "https://software.frc971.org/Build-Dependencies/phoenix6_24.2.0_5.4.2024.tar.gz",
+ "https://software.frc971.org/Build-Dependencies/phoenix6_24.50.0-alpha-2_arm64-2024.10.26.tar.gz",
],
)
diff --git a/aos/containers/error_list.h b/aos/containers/error_list.h
index 7ceb1fe..762d2ad 100644
--- a/aos/containers/error_list.h
+++ b/aos/containers/error_list.h
@@ -5,6 +5,7 @@
#include <algorithm>
+#include "absl/log/check.h"
#include "flatbuffers/buffer.h"
#include "flatbuffers/flatbuffer_builder.h"
#include "flatbuffers/vector.h"
@@ -128,7 +129,11 @@
flatbuffers::FlatBufferBuilder *fbb) const {
return fbb->CreateVector(array_.data(), array_.size());
}
-}; // namespace aos
+ template <typename StaticBuilder>
+ void ToStaticFlatbuffer(StaticBuilder *vector_builder) const {
+ CHECK(vector_builder->FromData(array_.data(), array_.size()));
+ }
+};
} // namespace aos
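
Note on ToStaticFlatbuffer() above: it complements the existing ToFlatbuffer() for code using the static flatbuffers API, filling a pre-allocated vector builder in place and CHECK-failing if the builder cannot hold the data. A minimal sketch of the call pattern, assuming generated static flatbuffer code (the StatusStatic type, ErrorEnum, and add_errors() accessor are hypothetical stand-ins):

    // Any builder exposing bool FromData(const T *, size_t) satisfies the
    // StaticBuilder template parameter.
    void FillStatus(const aos::ErrorList<ErrorEnum> &errors,
                    StatusStatic *status) {
      errors.ToStaticFlatbuffer(status->add_errors());
    }
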
diff --git a/aos/events/logging/config_remapper.cc b/aos/events/logging/config_remapper.cc
index 763cbc3..edc2326 100644
--- a/aos/events/logging/config_remapper.cc
+++ b/aos/events/logging/config_remapper.cc
@@ -566,7 +566,9 @@
// Add the schema if it doesn't exist.
if (schema_map.find(c->type()->string_view()) == schema_map.end()) {
- CHECK(c->has_schema());
+ if (!c->has_schema()) {
+ LOG(FATAL) << "Could not find schema for " << c->type()->string_view();
+ }
schema_map.insert(std::make_pair(c->type()->string_view(),
RecursiveCopyFlatBuffer(c->schema())));
}
diff --git a/aos/network/team_number.cc b/aos/network/team_number.cc
index 175c3a9..bd90b96 100644
--- a/aos/network/team_number.cc
+++ b/aos/network/team_number.cc
@@ -130,6 +130,8 @@
return std::string_view("pi");
} else if (hostname.substr(0, 5) == "orin-") {
return std::string_view("orin");
+ } else if (hostname.substr(0, 4) == "imu-") {
+ return std::string_view("orin");
} else
return std::nullopt;
}
diff --git a/aos/network/team_number_test.cc b/aos/network/team_number_test.cc
index 9922b1b..278a364 100644
--- a/aos/network/team_number_test.cc
+++ b/aos/network/team_number_test.cc
@@ -28,7 +28,7 @@
EXPECT_FALSE(ParseRoborioTeamNumber("roboRIO--FRC"));
}
-TEST(TeamNumberTest, ParsePiOrOrinTeamNumber) {
+TEST(HostnameParseTest, ParsePiOrOrinTeamNumber) {
EXPECT_EQ(971u, *ParsePiOrOrinTeamNumber("pi-971-1"));
EXPECT_EQ(8971u, *ParsePiOrOrinTeamNumber("pi-8971-22"));
EXPECT_EQ(8971u, *ParsePiOrOrinTeamNumber("pi-8971-"));
@@ -37,12 +37,16 @@
EXPECT_EQ(8971u, *ParsePiOrOrinTeamNumber("orin-8971-22"));
EXPECT_EQ(8971u, *ParsePiOrOrinTeamNumber("orin-8971-"));
+ EXPECT_FALSE(ParsePiOrOrinTeamNumber("roboRIO-971-FRC"));
+
EXPECT_FALSE(ParseRoborioTeamNumber("pi"));
EXPECT_FALSE(ParseRoborioTeamNumber("pi-"));
EXPECT_FALSE(ParseRoborioTeamNumber("pi-971"));
EXPECT_FALSE(ParseRoborioTeamNumber("pi-971a-1"));
EXPECT_FALSE(ParseRoborioTeamNumber("orin-971-1"));
+}
+TEST(HostnameParseTest, ParsePiOrOrinNumber) {
EXPECT_EQ(1u, *ParsePiOrOrinNumber("pi-971-1"));
EXPECT_EQ(22u, *ParsePiOrOrinNumber("pi-8971-22"));
EXPECT_EQ(1u, *ParsePiOrOrinNumber("orin-971-1"));
@@ -59,4 +63,19 @@
EXPECT_FALSE(ParsePiOrOrinNumber("orin-971"));
}
+TEST(HostnameParseTest, ParsePiOrOrin) {
+ EXPECT_EQ("pi", *ParsePiOrOrin("pi-971-1"));
+ EXPECT_EQ("pi", *ParsePiOrOrin("pi-8971-22"));
+ EXPECT_EQ("pi", *ParsePiOrOrin("pi-8971-"));
+
+ EXPECT_EQ("orin", *ParsePiOrOrin("orin-971-1"));
+ EXPECT_EQ("orin", *ParsePiOrOrin("orin-8971-22"));
+ EXPECT_EQ("orin", *ParsePiOrOrin("orin-8971-"));
+
+ EXPECT_EQ("orin", *ParsePiOrOrin("imu-971-1"));
+
+ EXPECT_FALSE(ParsePiOrOrin("roboRIO-971-FRC"));
+ EXPECT_FALSE(ParsePiOrOrin("laptop"));
+}
+
} // namespace aos::network::testing
diff --git a/frc971/control_loops/dlqr.h b/frc971/control_loops/dlqr.h
index 57aa88b..d31f492 100644
--- a/frc971/control_loops/dlqr.h
+++ b/frc971/control_loops/dlqr.h
@@ -41,7 +41,9 @@
::Eigen::Matrix<double, kM, kN> *K,
::Eigen::Matrix<double, kN, kN> *S) {
*K = ::Eigen::Matrix<double, kM, kN>::Zero();
- *S = ::Eigen::Matrix<double, kN, kN>::Zero();
+ if (S != nullptr) {
+ *S = ::Eigen::Matrix<double, kN, kN>::Zero();
+ }
// Discrete (not continuous)
char DICO = 'D';
// B and R are provided instead of G.
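
With this guard, callers that only need the gain matrix K can pass nullptr for the cost-to-go matrix S. A minimal sketch, under the assumption that the rest of dlqr() likewise skips its writes to S when it is null (only the zero-initialization guard is visible in this hunk):

    // Discrete double integrator with dt = 0.01; only K is requested.
    Eigen::Matrix<double, 2, 2> A{{1.0, 0.01}, {0.0, 1.0}};
    Eigen::Matrix<double, 2, 1> B{{0.0}, {0.01}};
    Eigen::Matrix<double, 2, 2> Q = Eigen::Matrix<double, 2, 2>::Identity();
    Eigen::Matrix<double, 1, 1> R = Eigen::Matrix<double, 1, 1>::Identity();
    Eigen::Matrix<double, 1, 2> K;
    Eigen::Matrix<double, 2, 2> *S = nullptr;  // Cost-to-go not needed.
    frc971::controls::dlqr(A, B, Q, R, &K, S);
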
diff --git a/frc971/control_loops/swerve/BUILD b/frc971/control_loops/swerve/BUILD
index 2c3ad50..6be2604 100644
--- a/frc971/control_loops/swerve/BUILD
+++ b/frc971/control_loops/swerve/BUILD
@@ -14,8 +14,17 @@
)
static_flatbuffer(
+ name = "swerve_drivetrain_joystick_goal_fbs",
+ srcs = ["swerve_drivetrain_joystick_goal.fbs"],
+)
+
+static_flatbuffer(
name = "swerve_drivetrain_goal_fbs",
srcs = ["swerve_drivetrain_goal.fbs"],
+ deps = [
+ ":swerve_drivetrain_joystick_goal_fbs",
+ "//frc971/math:matrix_fbs",
+ ],
)
static_flatbuffer(
@@ -292,6 +301,33 @@
],
)
+py_binary(
+ name = "experience_collector",
+ srcs = [
+ "experience_collector.py",
+ ],
+ deps = [
+ ":casadi_velocity_mpc_lib",
+ ":jax_dynamics",
+ "//frc971/control_loops/swerve/velocity_controller:physics",
+ "@pip//absl_py",
+ "@pip//matplotlib",
+ "@pip//numpy",
+ "@pip//pygobject",
+ "@pip//scipy",
+ "@pip//tensorflow",
+ ],
+)
+
+py_binary(
+ name = "multi_experience_collector",
+ srcs = ["multi_experience_collector.py"],
+ data = [":experience_collector"],
+ deps = [
+ "@pip//absl_py",
+ ],
+)
+
py_library(
name = "physics_test_utils",
srcs = [
@@ -371,3 +407,92 @@
"@com_google_absl//absl/log:check",
],
)
+
+cc_library(
+ name = "linearization_utils",
+ hdrs = ["linearization_utils.h"],
+)
+
+cc_library(
+ name = "linearized_controller",
+ hdrs = ["linearized_controller.h"],
+ deps = [
+ ":eigen_dynamics",
+ ":linearization_utils",
+ "//frc971/control_loops:c2d",
+ "//frc971/control_loops:dlqr",
+ "//frc971/control_loops:jacobian",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ ],
+)
+
+cc_test(
+ name = "linearized_controller_test",
+ srcs = ["linearized_controller_test.cc"],
+ deps = [
+ ":linearized_controller",
+ "//aos/testing:googletest",
+ ],
+)
+
+cc_library(
+ name = "auto_diff_jacobian",
+ hdrs = ["auto_diff_jacobian.h"],
+ target_compatible_with = ["@platforms//os:linux"],
+ deps = [
+ "@com_google_ceres_solver//:ceres",
+ ],
+)
+
+cc_test(
+ name = "auto_diff_jacobian_test",
+ srcs = ["auto_diff_jacobian_test.cc"],
+ deps = [
+ ":auto_diff_jacobian",
+ "//aos/testing:googletest",
+ ],
+)
+
+cc_library(
+ name = "simplified_dynamics",
+ hdrs = ["simplified_dynamics.h"],
+ deps = [
+ ":auto_diff_jacobian",
+ ":eigen_dynamics",
+ ":motors",
+ "//aos/util:math",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ ],
+)
+
+cc_test(
+ name = "simplified_dynamics_test",
+ srcs = ["simplified_dynamics_test.cc"],
+ deps = [
+ ":simplified_dynamics",
+ "//aos/testing:googletest",
+ "//aos/time",
+ "//frc971/control_loops:jacobian",
+ "@com_google_absl//absl/log",
+ ],
+)
+
+cc_library(
+ name = "inverse_kinematics",
+ hdrs = ["inverse_kinematics.h"],
+ deps = [
+ ":simplified_dynamics",
+ "//aos/util:math",
+ ],
+)
+
+cc_test(
+ name = "inverse_kinematics_test",
+ srcs = ["inverse_kinematics_test.cc"],
+ deps = [
+ ":inverse_kinematics",
+ "//aos/testing:googletest",
+ ],
+)
diff --git a/frc971/control_loops/swerve/auto_diff_jacobian.h b/frc971/control_loops/swerve/auto_diff_jacobian.h
new file mode 100644
index 0000000..2e7947a
--- /dev/null
+++ b/frc971/control_loops/swerve/auto_diff_jacobian.h
@@ -0,0 +1,62 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_AUTO_DIFF_JACOBIAN_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_AUTO_DIFF_JACOBIAN_H_
+#include "include/ceres/tiny_solver.h"
+#include "include/ceres/tiny_solver_autodiff_function.h"
+
+namespace frc971::control_loops::swerve {
+// Class to conveniently scope a function that makes use of Ceres'
+// autodifferentiation methods to calculate the jacobian of the provided method.
+// Template parameters:
+// Scalar: scalar type to use (typically double or float); this lets you
+//     control the precision of the calculations. Scalar itself cannot be a
+//     ceres Jet, because this class uses Jets internally.
+// Function: The type of the function itself. A Function f must be callable as
+// Eigen::Matrix<Scalar, kNumOutputs, 1> = f(Eigen::Matrix<Scalar,
+// kNumInputs, 1>{});
+template <typename Scalar, typename Function, size_t kNumInputs,
+ size_t kNumOutputs>
+class AutoDiffJacobian {
+ public:
+ // Calculates the jacobian of the provided method, function, at the provided
+ // input X.
+ static Eigen::Matrix<Scalar, kNumOutputs, kNumInputs> Jacobian(
+ const Function &function, const Eigen::Matrix<Scalar, kNumInputs, 1> &X) {
+ AutoDiffCeresFunctor ceres_functor(function);
+ TinySolverFunctor tiny_solver(ceres_functor);
+ // residual is unused, it's just a place to store the evaluated function at
+ // the current state/input.
+ Eigen::Matrix<Scalar, kNumOutputs, 1> residual;
+ Eigen::Matrix<Scalar, kNumOutputs, kNumInputs> jacobian;
+ tiny_solver(X.data(), residual.data(), jacobian.data());
+ return jacobian;
+ }
+
+ private:
+  // Borrow the TinySolver's auto-differentiation machinery to calculate the
+  // jacobian here. We aren't actually doing any solving, just letting it do
+  // the jacobian calculation for us.
+  // As such, construct a "residual" function whose residuals are just the
+  // outputs of the wrapped function and whose parameters are its inputs.
+ class AutoDiffCeresFunctor {
+ public:
+ AutoDiffCeresFunctor(const Function &function) : function_(function) {}
+ template <typename ScalarT>
+ bool operator()(const ScalarT *const parameters,
+ ScalarT *const residuals) const {
+ const Eigen::Map<const Eigen::Matrix<ScalarT, kNumInputs, 1>>
+ eigen_parameters(parameters);
+ Eigen::Map<Eigen::Matrix<ScalarT, kNumOutputs, 1>> eigen_residuals(
+ residuals);
+ eigen_residuals = function_(eigen_parameters);
+ return true;
+ }
+
+ private:
+ const Function &function_;
+ };
+ typedef ceres::TinySolverAutoDiffFunction<AutoDiffCeresFunctor, kNumOutputs,
+ kNumInputs, Scalar>
+ TinySolverFunctor;
+};
+} // namespace frc971::control_loops::swerve
+#endif // FRC971_CONTROL_LOOPS_SWERVE_AUTO_DIFF_JACOBIAN_H_
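
The Function contract is worth spelling out: operator() must be templated on its scalar type and must accept an Eigen::Map, because Jacobian() evaluates it with ceres::Jet scalars internally. A small nonlinear sketch, assuming nothing beyond the header above (the test below uses a linear function, which exercises the plumbing but whose jacobian is just the matrix itself):

    struct PolarToCartesian {
      template <typename Scalar>
      Eigen::Matrix<Scalar, 2, 1> operator()(
          const Eigen::Map<const Eigen::Matrix<Scalar, 2, 1>> X) const {
        using std::cos;
        using std::sin;
        // (r, theta) -> (x, y); unqualified cos/sin also resolve for
        // ceres::Jet via argument-dependent lookup.
        return Eigen::Matrix<Scalar, 2, 1>(X(0) * cos(X(1)),
                                           X(0) * sin(X(1)));
      }
    };

    // At r = 1, theta = 0 the jacobian is the 2x2 identity.
    const Eigen::Matrix<double, 2, 2> J =
        AutoDiffJacobian<double, PolarToCartesian, 2, 2>::Jacobian(
            PolarToCartesian{}, Eigen::Vector2d(1.0, 0.0));
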
diff --git a/frc971/control_loops/swerve/auto_diff_jacobian_test.cc b/frc971/control_loops/swerve/auto_diff_jacobian_test.cc
new file mode 100644
index 0000000..d04726c
--- /dev/null
+++ b/frc971/control_loops/swerve/auto_diff_jacobian_test.cc
@@ -0,0 +1,23 @@
+#include "frc971/control_loops/swerve/auto_diff_jacobian.h"
+
+#include <functional>
+
+#include "absl/log/log.h"
+#include "gtest/gtest.h"
+
+namespace frc971::control_loops::swerve::testing {
+struct TestFunction {
+ template <typename Scalar>
+ Eigen::Matrix<Scalar, 3, 1> operator()(
+ const Eigen::Map<const Eigen::Matrix<Scalar, 2, 1>> X) const {
+ return Eigen::Matrix<double, 3, 2>{{1, 2}, {3, 4}, {5, 6}} * X;
+ }
+};
+
+TEST(AutoDiffJacobianTest, EvaluatesJacobian) {
+ EXPECT_EQ((AutoDiffJacobian<double, TestFunction, 2, 3>::Jacobian(
+ TestFunction{}, Eigen::Vector2d::Zero())),
+ (Eigen::Matrix<double, 3, 2>{{1, 2}, {3, 4}, {5, 6}}));
+}
+
+} // namespace frc971::control_loops::swerve::testing
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc.py b/frc971/control_loops/swerve/casadi_velocity_mpc.py
index d62f7ed..54e33f1 100644
--- a/frc971/control_loops/swerve/casadi_velocity_mpc.py
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc.py
@@ -24,11 +24,6 @@
flags.DEFINE_bool('pickle', False, 'Write optimization results.')
flags.DEFINE_string('outputdir', None, 'Directory to write problem results to')
-# Full print level on ipopt. Piping to a file and using a filter or search method is suggested
-# grad_x prints out the gradient at each iteration in the following sequence: U0, X1, U1, etc.
-flags.DEFINE_bool('full_debug', False,
- 'If true, turn on all the debugging in the solver.')
-
class Solver(object):
diff --git a/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
index f464268..e358422 100644
--- a/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
+++ b/frc971/control_loops/swerve/casadi_velocity_mpc_lib.py
@@ -3,11 +3,19 @@
from frc971.control_loops.swerve import dynamics
import casadi
import numpy
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+# Full print level on ipopt. Piping to a file and using a filter or search method is suggested
+# grad_x prints out the gradient at each iteration in the following sequence: U0, X1, U1, etc.
+flags.DEFINE_bool('full_debug', False,
+ 'If true, turn on all the debugging in the solver.')
class MPC(object):
- def __init__(self, solver='fatrop', jit=True):
+ def __init__(self, solver='fatrop', jit=True, N=200):
self.fdot = dynamics.swerve_full_dynamics(
casadi.SX.sym("X", dynamics.NUM_STATES, 1),
casadi.SX.sym("U", 8, 1))
@@ -47,7 +55,7 @@
self.next_X = self.make_physics()
self.cost = self.make_cost()
- self.N = 200
+ self.N = N
# Start with an empty nonlinear program.
self.w = []
diff --git a/frc971/control_loops/swerve/experience_collector.py b/frc971/control_loops/swerve/experience_collector.py
new file mode 100644
index 0000000..b4c17fb
--- /dev/null
+++ b/frc971/control_loops/swerve/experience_collector.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+import os, sys
+
+# Setup XLA first.
+os.environ['XLA_FLAGS'] = ' '.join([
+ # Teach it where to find CUDA
+ '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+ # Use up to 20 cores
+ #'--xla_force_host_platform_device_count=6',
+ # Dump XLA to /tmp/foo to aid debugging
+ #'--xla_dump_to=/tmp/foo',
+ #'--xla_gpu_enable_command_buffer='
+ # Dump sharding
+ #"--xla_dump_to=/tmp/foo",
+ #"--xla_dump_hlo_pass_re=spmd|propagation"
+])
+os.environ['JAX_PLATFORMS'] = 'cpu'
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
+
+from absl import flags
+from absl import app
+from absl import logging
+import pickle
+import numpy
+from frc971.control_loops.swerve import dynamics
+from frc971.control_loops.swerve.casadi_velocity_mpc_lib import MPC
+import jax
+import tensorflow as tf
+from frc971.control_loops.swerve.velocity_controller.physics import SwerveProblem
+from frc971.control_loops.swerve import jax_dynamics
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_bool('compileonly', False,
+                  'If true, compile the casadi problem and exit without solving.')
+
+flags.DEFINE_float('vx', 1.0, 'Goal velocity in m/s in x')
+flags.DEFINE_float('vy', 0.0, 'Goal velocity in m/s in y')
+flags.DEFINE_float('omega', 0.0, 'Goal angular velocity in rad/s')
+flags.DEFINE_integer('seed', 0, 'Seed for random initial state.')
+
+flags.DEFINE_bool('save_plots', True,
+ 'If true, save plots for each run as well.')
+flags.DEFINE_string('outputdir', None, 'Directory to write problem results to')
+flags.DEFINE_bool('quiet', False, 'If true, print a lot less')
+
+flags.DEFINE_integer('num_solutions', 100,
+ 'Number of random problems to solve.')
+flags.DEFINE_integer('horizon', 200, 'Horizon to solve for')
+
+try:
+ from matplotlib import pylab
+except ModuleNotFoundError:
+ pass
+
+
+def collect_experience(problem, mpc, rng):
+ X_initial = numpy.array(problem.random_states(rng,
+ dimensions=1)).transpose()
+
+ R_goal = numpy.zeros((3, 1))
+ R_goal[0, 0] = FLAGS.vx
+ R_goal[1, 0] = FLAGS.vy
+ R_goal[2, 0] = FLAGS.omega
+
+ solution = mpc.solve(p=numpy.vstack((X_initial, R_goal)))
+ sys.stderr.flush()
+ sys.stdout.flush()
+
+ # Solver doesn't solve for the last state. So we get N-1 states back.
+ experience = {
+ 'observations1': numpy.zeros((mpc.N - 1, problem.num_states)),
+ 'observations2': numpy.zeros((mpc.N - 1, problem.num_states)),
+ 'actions': numpy.zeros((mpc.N - 1, problem.num_outputs)),
+ 'rewards': numpy.zeros((mpc.N - 1, 1)),
+ 'goals': numpy.zeros((mpc.N - 1, problem.num_goals)),
+ }
+
+ if not FLAGS.quiet:
+ print('x(0):', X_initial.transpose())
+
+ logging.info('Finished solving')
+ X_prior = X_initial.squeeze()
+ for j in range(mpc.N - 1):
+ if not FLAGS.quiet:
+ print(f'u({j}): ', mpc.unpack_u(solution, j))
+ print(f'x({j+1}): ', mpc.unpack_x(solution, j + 1))
+ experience['observations1'][j, :] = X_prior
+ X_prior = mpc.unpack_x(solution, j + 1)
+ experience['observations2'][j, :] = X_prior
+ experience['actions'][j, :] = mpc.unpack_u(solution, j)
+ experience['goals'][j, :] = R_goal[:, 0]
+
+ logging.info('Finished all but reward')
+ for j in range(mpc.N - 1):
+ experience['rewards'][j, :] = problem.reward(
+ X=X_prior,
+ U=mpc.unpack_u(solution, j),
+ goal=R_goal[:, 0],
+ )
+ sys.stderr.flush()
+ sys.stdout.flush()
+
+ return experience
+
+
+def save_experience(problem, mpc, experience, experience_number):
+ with open(f'experience_{experience_number}.pkl', 'wb') as f:
+ pickle.dump(experience, f)
+
+ if not FLAGS.save_plots:
+ return
+
+ fig0, axs0 = pylab.subplots(3)
+ fig1, axs1 = pylab.subplots(2)
+
+ axs0[0].clear()
+ axs0[1].clear()
+ axs0[2].clear()
+
+ t = problem.dt * numpy.array(list(range(mpc.N - 1)))
+
+ X_plot = experience['observations1']
+ U_plot = experience['actions']
+
+ axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_VX], label="vx")
+ axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_VY], label="vy")
+ axs0[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_OMEGA], label="omega")
+ axs0[0].legend()
+
+ axs0[1].plot(t, U_plot[:, 0], label="Is0")
+ axs0[1].plot(t, U_plot[:, 2], label="Is1")
+ axs0[1].plot(t, U_plot[:, 4], label="Is2")
+ axs0[1].plot(t, U_plot[:, 6], label="Is3")
+ axs0[1].legend()
+
+ axs0[2].plot(t, U_plot[:, 1], label="Id0")
+ axs0[2].plot(t, U_plot[:, 3], label="Id1")
+ axs0[2].plot(t, U_plot[:, 5], label="Id2")
+ axs0[2].plot(t, U_plot[:, 7], label="Id3")
+ axs0[2].legend()
+
+ axs1[0].clear()
+ axs1[1].clear()
+
+ axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS0], label='steer0')
+ axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS1], label='steer1')
+ axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS2], label='steer2')
+ axs1[0].plot(t, X_plot[:, dynamics.VELOCITY_STATE_THETAS3], label='steer3')
+ axs1[0].legend()
+ axs1[1].plot(t,
+ X_plot[:, dynamics.VELOCITY_STATE_OMEGAS0],
+ label='steer_velocity0')
+ axs1[1].plot(t,
+ X_plot[:, dynamics.VELOCITY_STATE_OMEGAS1],
+ label='steer_velocity1')
+ axs1[1].plot(t,
+ X_plot[:, dynamics.VELOCITY_STATE_OMEGAS2],
+ label='steer_velocity2')
+ axs1[1].plot(t,
+ X_plot[:, dynamics.VELOCITY_STATE_OMEGAS3],
+ label='steer_velocity3')
+ axs1[1].legend()
+
+ fig0.savefig(f'state_{experience_number}.svg')
+ fig1.savefig(f'steer_{experience_number}.svg')
+
+ # Free the memory associated with the figures.
+ fig0.clf()
+ fig1.clf()
+ pylab.close(fig0)
+ pylab.close(fig1)
+
+
+def main(argv):
+ if FLAGS.outputdir:
+ os.chdir(FLAGS.outputdir)
+
+ # Hide any GPUs from TensorFlow. Otherwise it might reserve memory.
+ tf.config.experimental.set_visible_devices([], 'GPU')
+ rng = jax.random.key(FLAGS.seed)
+
+ physics_constants = jax_dynamics.Coefficients()
+ problem = SwerveProblem(physics_constants)
+ mpc = MPC(solver='ipopt', N=(FLAGS.horizon + 1))
+
+ if FLAGS.compileonly:
+ return
+
+ for i in range(FLAGS.num_solutions):
+ rng, rng_init = jax.random.split(rng)
+ experience = collect_experience(problem, mpc, rng_init)
+ logging.info('Solved problem %d', i)
+
+ save_experience(problem, mpc, experience, i)
+ logging.info('Wrote problem %d', i)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/frc971/control_loops/swerve/generate_physics.cc b/frc971/control_loops/swerve/generate_physics.cc
index de52d79..8a7fe16 100644
--- a/frc971/control_loops/swerve/generate_physics.cc
+++ b/frc971/control_loops/swerve/generate_physics.cc
@@ -339,7 +339,7 @@
result_h.emplace_back(
"inline constexpr size_t kNumVelocityStates = "
"static_cast<size_t>(VelocityStates::kNumStates);");
- result_h.emplace_back("struct Inputs {");
+ result_h.emplace_back("struct InputStates {");
result_h.emplace_back("enum States {");
result_h.emplace_back(" kIs0 = 0,");
result_h.emplace_back(" kId0 = 1,");
@@ -354,7 +354,7 @@
result_h.emplace_back("};");
result_h.emplace_back(
"inline constexpr size_t kNumInputs = "
- "static_cast<size_t>(Inputs::kNumInputs);");
+ "static_cast<size_t>(InputStates::kNumInputs);");
result_h.emplace_back("");
result_h.emplace_back("// Returns the derivative of our state vector");
result_h.emplace_back(
@@ -365,29 +365,6 @@
result_h.emplace_back(
" Eigen::Ref<const Eigen::Matrix<double, kNumInputs, 1>> U);");
result_h.emplace_back("");
- result_h.emplace_back(
- "Eigen::Matrix<double, kNumVelocityStates, 1> ToVelocityState(");
- result_h.emplace_back(
- " Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
- "1>> X);");
- result_h.emplace_back("");
- result_h.emplace_back(
- "Eigen::Matrix<double, kNumFullDynamicsStates, 1> FromVelocityState(");
- result_h.emplace_back(
- " Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> "
- "X);");
- result_h.emplace_back("");
- result_h.emplace_back(
- "inline Eigen::Matrix<double, kNumVelocityStates, 1> VelocityPhysics(");
- result_h.emplace_back(
- " Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> "
- "X,");
- result_h.emplace_back(
- " Eigen::Ref<const Eigen::Matrix<double, kNumInputs, 1>> U) {");
- result_h.emplace_back(
- " return ToVelocityState(SwervePhysics(FromVelocityState(X), U));");
- result_h.emplace_back("}");
- result_h.emplace_back("");
result_h.emplace_back("} // namespace frc971::control_loops::swerve");
result_h.emplace_back("");
result_h.emplace_back(absl::Substitute("#endif // $0_", include_guard));
@@ -401,40 +378,6 @@
result_cc.emplace_back("namespace frc971::control_loops::swerve {");
result_cc.emplace_back("");
result_cc.emplace_back(
- "Eigen::Matrix<double, kNumVelocityStates, 1> ToVelocityState(");
- result_cc.emplace_back(
- " Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
- "1>> X) {");
- result_cc.emplace_back(
- " Eigen::Matrix<double, kNumVelocityStates, 1> velocity;");
- const std::vector<std::string_view> velocity_states = {
- "kThetas0", "kOmegas0", "kThetas1", "kOmegas1", "kThetas2", "kOmegas2",
- "kThetas3", "kOmegas3", "kTheta", "kVx", "kVy", "kOmega"};
- for (const std::string_view velocity_state : velocity_states) {
- result_cc.emplace_back(absl::StrFormat(
- " velocity(VelocityStates::%s) = X(FullDynamicsStates::%s);",
- velocity_state, velocity_state));
- }
- result_cc.emplace_back(" return velocity;");
- result_cc.emplace_back("}");
- result_cc.emplace_back("");
- result_cc.emplace_back(
- "Eigen::Matrix<double, kNumFullDynamicsStates, 1> FromVelocityState(");
- result_cc.emplace_back(
- " Eigen::Ref<const Eigen::Matrix<double, kNumVelocityStates, 1>> X) "
- "{");
- result_cc.emplace_back(
- " Eigen::Matrix<double, kNumFullDynamicsStates, 1> full;");
- result_cc.emplace_back(" full.setZero();");
- for (const std::string_view velocity_state : velocity_states) {
- result_cc.emplace_back(absl::StrFormat(
- " full(FullDynamicsStates::%s) = X(VelocityStates::%s);",
- velocity_state, velocity_state));
- }
- result_cc.emplace_back(" return full;");
- result_cc.emplace_back("}");
- result_cc.emplace_back("");
- result_cc.emplace_back(
"Eigen::Matrix<double, kNumFullDynamicsStates, 1> SwervePhysics(");
result_cc.emplace_back(
" Eigen::Ref<const Eigen::Matrix<double, kNumFullDynamicsStates, "
diff --git a/frc971/control_loops/swerve/inverse_kinematics.h b/frc971/control_loops/swerve/inverse_kinematics.h
new file mode 100644
index 0000000..a4e43e1
--- /dev/null
+++ b/frc971/control_loops/swerve/inverse_kinematics.h
@@ -0,0 +1,86 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_INVERSE_KINEMATICS_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_INVERSE_KINEMATICS_H_
+#include "aos/util/math.h"
+#include "frc971/control_loops/swerve/simplified_dynamics.h"
+namespace frc971::control_loops::swerve {
+// Class to do straightforward inverse kinematics of a swerve drivebase. This
+// is meant largely as a sanity-check/initializer for more sophisticated
+// methods. It calculates the directions in which the modules must point in
+// order to be aligned with the drivebase's direction of motion. Accounting
+// for slip angles and the like must be done as part of more sophisticated
+// inverse dynamics.
+template <typename Scalar>
+class InverseKinematics {
+ public:
+ using ModuleParams = SimplifiedDynamics<Scalar>::ModuleParams;
+ using Parameters = SimplifiedDynamics<Scalar>::Parameters;
+ using States = SimplifiedDynamics<Scalar>::States;
+ using State = SimplifiedDynamics<Scalar>::template VelocityState<Scalar>;
+  InverseKinematics(const Parameters &params) : params_(params) {}
+
+ // Uses kVx, kVy, kTheta, and kOmega from the input goal state for the
+ // absolute kinematics. Also uses the specified theta values to bias theta
+ // output values towards the current state (i.e., if the module 0 theta is
+ // currently 0 and we are asked to drive straight backwards, this will prefer
+ // a theta of zero rather than a theta of pi).
+ State Solve(const State &goal) {
+ State result = goal;
+ for (size_t module_index = 0; module_index < params_.modules.size();
+ ++module_index) {
+ SolveModule(goal, params_.modules[module_index].position,
+ &result(States::kThetas0 + 2 * module_index),
+ &result(States::kOmegas0 + 2 * module_index));
+ }
+ return result;
+ }
+
+ void SolveModule(const State &goal,
+ const Eigen::Matrix<Scalar, 2, 1> &module_position,
+ Scalar *module_theta, Scalar *module_omega) {
+ const Scalar vx = goal(States::kVx);
+ const Scalar vy = goal(States::kVy);
+ const Scalar omega = goal(States::kOmega);
+ // module_velocity_in_robot_frame = R(-theta) * robot_vel +
+ // omega.cross(module_position);
+ // module_vel_x = (cos(-theta) * vx - sin(-theta) * vy) - omega * module_y
+ // module_vel_y = (sin(-theta) * vx + cos(-theta) * vy) + omega * module_x
+ // module_theta = atan2(module_vel_y, module_vel_x)
+ // module_omega = datan2(module_vel_y, module_vel_x) / dt
+ // datan2(y, x) / dt = (x * dy/dt - y * dx / dt) / (x^2 + y^2)
+ // robot accelerations are assumed to be zero.
+ // dmodule_vel_x / dt = (sin(-theta) * vx + cos(-theta) * vy) * omega
+ // dmodule_vel_y / dt = (-cos(-theta) * vx + sin(-theta) * vy) * omega
+ const Scalar ctheta = std::cos(-goal(States::kTheta));
+ const Scalar stheta = std::sin(-goal(States::kTheta));
+ const Scalar module_vel_x =
+ (ctheta * vx - stheta * vy) - omega * module_position.y();
+ const Scalar module_vel_y =
+ (stheta * vx + ctheta * vy) + omega * module_position.x();
+ const Scalar nominal_module_theta = atan2(module_vel_y, module_vel_x);
+ // If the current module theta is more than 90 deg from the desired theta,
+ // flip the desired theta by 180 deg.
+ if (std::abs(aos::math::DiffAngle(nominal_module_theta, *module_theta)) >
+ M_PI_2) {
+ *module_theta = aos::math::NormalizeAngle(nominal_module_theta + M_PI);
+ } else {
+ *module_theta = nominal_module_theta;
+ }
+ const Scalar module_accel_x = (stheta * vx + ctheta * vy) * omega;
+ const Scalar module_accel_y = (-ctheta * vx + stheta * vy) * omega;
+ const Scalar module_vel_norm_squared =
+ (module_vel_x * module_vel_x + module_vel_y * module_vel_y);
+ if (module_vel_norm_squared < 1e-5) {
+ // Prevent poor conditioning of module velocities at near-zero speeds.
+ *module_omega = 0.0;
+ } else {
+ *module_omega =
+ (module_vel_x * module_accel_y - module_vel_y * module_accel_x) /
+ module_vel_norm_squared;
+ }
+ }
+
+ private:
+ Parameters params_;
+};
+} // namespace frc971::control_loops::swerve
+#endif // FRC971_CONTROL_LOOPS_SWERVE_INVERSE_KINEMATICS_H_
diff --git a/frc971/control_loops/swerve/inverse_kinematics_test.cc b/frc971/control_loops/swerve/inverse_kinematics_test.cc
new file mode 100644
index 0000000..407d540
--- /dev/null
+++ b/frc971/control_loops/swerve/inverse_kinematics_test.cc
@@ -0,0 +1,150 @@
+#include "frc971/control_loops/swerve/inverse_kinematics.h"
+
+#include "gtest/gtest.h"
+
+namespace frc971::control_loops::swerve::testing {
+class InverseKinematicsTest : public ::testing::Test {
+ protected:
+ typedef double Scalar;
+ using State = InverseKinematics<Scalar>::State;
+ using States = InverseKinematics<Scalar>::States;
+ using ModuleParams = InverseKinematics<Scalar>::ModuleParams;
+ using Parameters = InverseKinematics<Scalar>::Parameters;
+ static ModuleParams MakeModule(const Eigen::Matrix<Scalar, 2, 1> &position) {
+ return ModuleParams{.position = position,
+ .slip_angle_coefficient = 0.0,
+ .slip_angle_alignment_coefficient = 0.0,
+ .steer_motor = KrakenFOC(),
+ .drive_motor = KrakenFOC(),
+ .steer_ratio = 1.0,
+ .drive_ratio = 1.0};
+ }
+ static Parameters MakeParams() {
+ return {.mass = 1.0,
+ .moment_of_inertia = 1.0,
+ .modules = {
+ MakeModule({1.0, 1.0}),
+ MakeModule({-1.0, 1.0}),
+ MakeModule({-1.0, -1.0}),
+ MakeModule({1.0, -1.0}),
+ }};
+ }
+
+ InverseKinematicsTest() : inverse_kinematics_(MakeParams()) {}
+
+ struct Goal {
+ Scalar vx;
+ Scalar vy;
+ Scalar omega;
+ Scalar theta;
+ };
+
+ void CheckState(
+ const Goal &goal, const std::array<Scalar, 4> &expected_thetas,
+ const std::optional<Eigen::Vector4d> &expected_omegas = std::nullopt) {
+ State goal_state = State::Zero();
+ goal_state(States::kVx) = goal.vx;
+ goal_state(States::kVy) = goal.vy;
+ goal_state(States::kOmega) = goal.omega;
+ goal_state(States::kTheta) = goal.theta;
+ SCOPED_TRACE(goal_state.bottomRows<4>().transpose());
+ const State nominal_state = inverse_kinematics_.Solve(goal_state);
+ // Now, calculate the numerical derivative of the state and validate that it
+ // matches expectations.
+ const Scalar kDt = 1e-5;
+ const Scalar dtheta = kDt * goal.omega;
+ goal_state(States::kTheta) += dtheta / 2.0;
+ const State state_eps_pos = inverse_kinematics_.Solve(goal_state);
+ goal_state(States::kTheta) -= dtheta;
+ const State state_eps_neg = inverse_kinematics_.Solve(goal_state);
+ const State state_derivative = (state_eps_pos - state_eps_neg) / kDt;
+ for (size_t module_index = 0; module_index < 4; ++module_index) {
+ SCOPED_TRACE(module_index);
+ const int omega_idx = States::kOmegas0 + 2 * module_index;
+ const int theta_idx = States::kThetas0 + 2 * module_index;
+ EXPECT_NEAR(nominal_state(omega_idx), state_derivative(theta_idx), 1e-10);
+ EXPECT_NEAR(nominal_state(theta_idx), expected_thetas[module_index],
+ 1e-10);
+ if (expected_omegas.has_value()) {
+ EXPECT_NEAR(nominal_state(omega_idx),
+ expected_omegas.value()(module_index), 1e-10);
+ }
+ }
+ }
+
+ InverseKinematics<Scalar> inverse_kinematics_;
+};
+
+// Tests that if we are driving straight with no yaw that we get sane
+// kinematics.
+TEST_F(InverseKinematicsTest, StraightDrivingNoYaw) {
+ // Sanity-check zero-speed operation.
+ CheckState({.vx = 0.0, .vy = 0.0, .omega = 0.0, .theta = 0.0},
+ {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+
+ CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = 0.0},
+ {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+ // Reverse should prefer to bias the modules towards [-pi/2, pi/2] due to
+ // hysteresis from starting the modules at thetas of 0.
+ CheckState({.vx = -1.0, .vy = 0.0, .omega = 0.0, .theta = 0.0},
+ {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+
+ CheckState({.vx = 0.0, .vy = 1.0, .omega = 0.0, .theta = 0.0},
+ {M_PI_2, M_PI_2, M_PI_2, M_PI_2}, Eigen::Vector4d::Zero());
+ // For module hysteresis, this is a corner case where we are exactly 90 deg
+ // from the current value; the exact result is unimportant.
+ CheckState({.vx = 0.0, .vy = -1.0, .omega = 0.0, .theta = 0.0},
+ {-M_PI_2, -M_PI_2, -M_PI_2, -M_PI_2}, Eigen::Vector4d::Zero());
+
+ CheckState({.vx = 1.0, .vy = 1.0, .omega = 0.0, .theta = 0.0},
+ {M_PI_4, M_PI_4, M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+ // Reverse should prefer to bias the modules towards [-pi/2, pi/2] due to
+ // hysteresis from starting the modules at thetas of 0.
+ CheckState({.vx = -1.0, .vy = -1.0, .omega = 0.0, .theta = 0.0},
+ {M_PI_4, M_PI_4, M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+}
+
+// Tests that if we are driving straight with non-zero yaw that we get sane
+// kinematics.
+TEST_F(InverseKinematicsTest, StraightDrivingYawed) {
+ CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = 0.1},
+ {-0.1, -0.1, -0.1, -0.1}, Eigen::Vector4d::Zero());
+
+ CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = M_PI_2},
+ {-M_PI_2, -M_PI_2, -M_PI_2, -M_PI_2}, Eigen::Vector4d::Zero());
+ CheckState({.vx = 1.0, .vy = 0.0, .omega = 0.0, .theta = M_PI},
+ {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+
+ CheckState({.vx = 0.0, .vy = 1.0, .omega = 0.0, .theta = M_PI_2},
+ {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+ // Reverse should prefer to bias the modules towards [-pi/2, pi/2] due to
+ // hysteresis from starting the modules at thetas of 0.
+ CheckState({.vx = 0.0, .vy = -1.0, .omega = 0.0, .theta = M_PI_2},
+ {0.0, 0.0, 0.0, 0.0}, Eigen::Vector4d::Zero());
+}
+
+// Tests that we can spin in place.
+TEST_F(InverseKinematicsTest, SpinInPlace) {
+ CheckState({.vx = 0.0, .vy = 0.0, .omega = 1.0, .theta = 0.0},
+ {-M_PI_4, M_PI_4, -M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+ // And changing the current theta should not matter.
+ CheckState({.vx = 0.0, .vy = 0.0, .omega = 1.0, .theta = 1.0},
+ {-M_PI_4, M_PI_4, -M_PI_4, M_PI_4}, Eigen::Vector4d::Zero());
+}
+
+// Tests that if we are spinning while moving that we correctly calculate module
+// yaw rates.
+TEST_F(InverseKinematicsTest, SpinWhileMoving) {
+ // Set up a situation where we are driving straight forwards, with a
+ // yaw rate of 1 rad / sec.
+ // The modules are all at radii of sqrt(2), so the contribution from both
+ // the translational and rotational velocities should be equal; in this case
+ // each module will have an angle that is an equal combination of straight
+ // forwards (0 deg) and some 45 deg offset.
+ CheckState({.vx = std::sqrt(static_cast<Scalar>(2)),
+ .vy = 0.0,
+ .omega = 1.0,
+ .theta = 0.0},
+ {3 * M_PI_4 / 2, -3 * M_PI_4 / 2, -M_PI_4 / 2, M_PI_4 / 2});
+}
+} // namespace frc971::control_loops::swerve::testing
diff --git a/frc971/control_loops/swerve/jax_dynamics.py b/frc971/control_loops/swerve/jax_dynamics.py
index 19a48de..58d5fcf 100644
--- a/frc971/control_loops/swerve/jax_dynamics.py
+++ b/frc971/control_loops/swerve/jax_dynamics.py
@@ -354,6 +354,7 @@
])
+@jax.jit
def mpc_cost(coefficients: CoefficientsType, X, U, goal):
J = 0
diff --git a/frc971/control_loops/swerve/linearization_utils.h b/frc971/control_loops/swerve/linearization_utils.h
new file mode 100644
index 0000000..ee1275e
--- /dev/null
+++ b/frc971/control_loops/swerve/linearization_utils.h
@@ -0,0 +1,13 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_LINEARIZATION_UTILS_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_LINEARIZATION_UTILS_H_
+
+namespace frc971::control_loops::swerve {
+template <typename State, typename Input>
+struct DynamicsInterface {
+ virtual ~DynamicsInterface() {}
+ // To be overridden by the implementation; returns the derivative of the state
+ // at the given state with the provided control input.
+ virtual State operator()(const State &X, const Input &U) const = 0;
+};
+} // namespace frc971::control_loops::swerve
+#endif // FRC971_CONTROL_LOOPS_SWERVE_LINEARIZATION_UTILS_H_
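
DynamicsInterface is the seam through which LinearizedController (next file) numerically linearizes arbitrary dynamics. A minimal sketch of an implementation, here a hypothetical unit-mass double integrator with a single input:

    using State = Eigen::Matrix<double, 2, 1>;
    using Input = Eigen::Matrix<double, 1, 1>;

    struct DoubleIntegrator
        : public frc971::control_loops::swerve::DynamicsInterface<State, Input> {
      // xdot = v; vdot = u (unit mass, no damping).
      State operator()(const State &X, const Input &U) const override {
        return State(X(1), U(0));
      }
    };
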
diff --git a/frc971/control_loops/swerve/linearized_controller.h b/frc971/control_loops/swerve/linearized_controller.h
new file mode 100644
index 0000000..1df9640
--- /dev/null
+++ b/frc971/control_loops/swerve/linearized_controller.h
@@ -0,0 +1,126 @@
+#include <memory>
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include <Eigen/Dense>
+
+#include "frc971/control_loops/c2d.h"
+#include "frc971/control_loops/dlqr.h"
+#include "frc971/control_loops/jacobian.h"
+#include "frc971/control_loops/swerve/dynamics.h"
+#include "frc971/control_loops/swerve/linearization_utils.h"
+
+namespace frc971::control_loops::swerve {
+
+// Provides a simple LQR controller that takes a non-linear system, linearizes
+// the dynamics at each time step, recalculates the LQR gains for those
+// dynamics, and calculates the relevant feedback inputs to provide.
+template <int NStates, typename Scalar = double>
+class LinearizedController {
+ public:
+ typedef Eigen::Matrix<Scalar, NStates, 1> State;
+ typedef Eigen::Matrix<Scalar, NStates, NStates> StateSquare;
+ typedef Eigen::Matrix<Scalar, kNumInputs, 1> Input;
+ typedef Eigen::Matrix<Scalar, kNumInputs, kNumInputs> InputSquare;
+ typedef Eigen::Matrix<Scalar, NStates, kNumInputs> BMatrix;
+ typedef DynamicsInterface<State, Input> Dynamics;
+
+ struct Parameters {
+ // State cost matrix.
+ StateSquare Q;
+ // Input cost matrix.
+ InputSquare R;
+    // Period at which the controller is called.
+ std::chrono::nanoseconds dt;
+ // The dynamics to use.
+ // TODO(james): I wrote this before creating the auto-differentiation
+ // functions; we should swap to the auto-differentiation, since the
+    // numerical linearization is one of the bigger time sinks in this controller
+ // right now.
+ std::unique_ptr<Dynamics> dynamics;
+ };
+
+ // Represents the linearized dynamics of the system.
+ struct LinearDynamics {
+ StateSquare A;
+ BMatrix B;
+ };
+
+ // Debug information for a given cycle of the controller.
+ struct ControllerDebug {
+ // Feedforward input which we provided.
+ Input U_ff;
+ // Calculated feedback input to provide.
+ Input U_feedback;
+ Eigen::Matrix<Scalar, kNumInputs, NStates> feedback_contributions;
+ };
+
+ struct ControllerResult {
+ // Control input to provide to the robot.
+ Input U;
+ ControllerDebug debug;
+ };
+
+ LinearizedController(Parameters params) : params_(std::move(params)) {}
+
+ // Runs the controller for a given iteration, relinearizing the dynamics about
+ // the provided current state X, attempting to control the robot to the
+ // desired goal state.
+ // The U_ff input will be added into the returned control input.
+ ControllerResult RunController(const State &X, const State &goal,
+ Input U_ff) {
+ auto start_time = aos::monotonic_clock::now();
+ // TODO(james): Swap this to the auto-diff methods; this is currently about
+ // a third of the total time spent in this method when run on the roborio.
+ const struct LinearDynamics continuous_dynamics =
+ LinearizeDynamics(X, U_ff);
+ auto linearization_time = aos::monotonic_clock::now();
+ struct LinearDynamics discrete_dynamics;
+ frc971::controls::C2D(continuous_dynamics.A, continuous_dynamics.B,
+ params_.dt, &discrete_dynamics.A,
+ &discrete_dynamics.B);
+ auto c2d_time = aos::monotonic_clock::now();
+ VLOG(2) << "Controllability of dynamics (ideally should be " << NStates
+ << "): "
+ << frc971::controls::Controllability(discrete_dynamics.A,
+ discrete_dynamics.B);
+ Eigen::Matrix<Scalar, kNumInputs, NStates> K;
+ Eigen::Matrix<Scalar, NStates, NStates> S;
+ // TODO(james): Swap this to a cheaper DARE solver; we should probably just
+ // do something like we do in Trajectory::CalculatePathGains for the tank
+ // spline controller where we approximate the infinite-horizon DARE solution
+ // by doing a finite-horizon LQR.
+ // Currently the dlqr call represents ~60% of the time spent in the
+ // RunController() method.
+ frc971::controls::dlqr(discrete_dynamics.A, discrete_dynamics.B, params_.Q,
+ params_.R, &K, &S);
+ auto dlqr_time = aos::monotonic_clock::now();
+ const Input U_feedback = K * (goal - X);
+ const Input U = U_ff + U_feedback;
+ Eigen::Matrix<Scalar, kNumInputs, NStates> feedback_contributions;
+ for (int state_idx = 0; state_idx < NStates; ++state_idx) {
+ feedback_contributions.col(state_idx) =
+ K.col(state_idx) * (goal - X)(state_idx);
+ }
+ VLOG(2) << "linearization time "
+ << aos::time::DurationInSeconds(linearization_time - start_time)
+ << " c2d time "
+ << aos::time::DurationInSeconds(c2d_time - linearization_time)
+ << " dlqr time "
+ << aos::time::DurationInSeconds(dlqr_time - c2d_time);
+ return {.U = U,
+ .debug = {.U_ff = U_ff,
+ .U_feedback = U_feedback,
+ .feedback_contributions = feedback_contributions}};
+ }
+
+ LinearDynamics LinearizeDynamics(const State &X, const Input &U) {
+ return {.A = NumericalJacobianX(*params_.dynamics, X, U),
+ .B = NumericalJacobianU(*params_.dynamics, X, U)};
+ }
+
+ private:
+ const Parameters params_;
+};
+
+} // namespace frc971::control_loops::swerve
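
Summarizing the per-cycle computation in RunController() above, with f the continuous dynamics, X the current state, and U_ff the feedforward input:

    A = df/dX at (X, U_ff),  B = df/dU at (X, U_ff)   (numerical jacobians)
    (A_d, B_d) = C2D(A, B, dt)                        (discretize over the control period)
    K = dlqr(A_d, B_d, Q, R)                          (infinite-horizon discrete LQR gain)
    U = U_ff + K * (goal - X)

The feedback_contributions debug matrix simply splits the K * (goal - X) product out column-by-column, so the share of the feedback input attributable to each state's error can be inspected.
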
diff --git a/frc971/control_loops/swerve/linearized_controller_test.cc b/frc971/control_loops/swerve/linearized_controller_test.cc
new file mode 100644
index 0000000..d5e7aee
--- /dev/null
+++ b/frc971/control_loops/swerve/linearized_controller_test.cc
@@ -0,0 +1,87 @@
+#include "frc971/control_loops/swerve/linearized_controller.h"
+
+#include "gtest/gtest.h"
+
+namespace frc971::control_loops::swerve::test {
+
+class LinearizedControllerTest : public ::testing::Test {
+ protected:
+ typedef LinearizedController<2> Controller;
+ typedef Controller::State State;
+ typedef Controller::Input Input;
+ struct LinearDynamics : public Controller::Dynamics {
+ Eigen::Vector2d operator()(
+ const Eigen::Vector2d &X,
+ const Eigen::Matrix<double, 8, 1> &U) const override {
+ return Eigen::Matrix2d{{0.0, 1.0}, {0.0, -0.01}} * X +
+ Eigen::Matrix<double, 2, 8>{
+ {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
+ {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}} *
+ U;
+ }
+ };
+ static Eigen::Matrix<double, kNumInputs, kNumInputs> MakeR() {
+ Eigen::Matrix<double, kNumInputs, kNumInputs> R;
+ R.setIdentity();
+ return R;
+ }
+ LinearizedControllerTest()
+ : controller_({.Q = Eigen::Matrix<double, 2, 2>{{1.0, 0.0}, {0.0, 1.0}},
+ .R = MakeR(),
+ .dt = std::chrono::milliseconds(10),
+ .dynamics = std::make_unique<LinearDynamics>()}) {}
+
+ Controller controller_;
+};
+
+// Sanity check that the dynamics linearization is working correctly.
+TEST_F(LinearizedControllerTest, LinearizedDynamics) {
+ auto dynamics = controller_.LinearizeDynamics(State::Zero(), Input::Zero());
+ EXPECT_EQ(0.0, dynamics.A(0, 0));
+ EXPECT_EQ(0.0, dynamics.A(1, 0));
+ EXPECT_EQ(1.0, dynamics.A(0, 1));
+ EXPECT_EQ(-0.01, dynamics.A(1, 1));
+ // All elements of B except for (1, 0) should be exactly 0.
+ EXPECT_EQ(1.0, dynamics.B(1, 0));
+ EXPECT_EQ(1.0, dynamics.B.norm());
+}
+
+// Confirm that the generated LQR controller is able to generate correct
+// inputs when state and goal are at zero.
+TEST_F(LinearizedControllerTest, ControllerResultAtZero) {
+ auto result =
+ controller_.RunController(State::Zero(), State::Zero(), Input::Zero());
+ EXPECT_EQ(0.0, result.U.norm());
+ EXPECT_EQ(0.0, result.debug.U_ff.norm());
+ EXPECT_EQ(0.0, result.debug.U_feedback.norm());
+}
+
+// Confirm that the generated LQR controller is able to generate correct
+// inputs when state is zero and the goal is non-zero.
+TEST_F(LinearizedControllerTest, ControlToZero) {
+ auto result = controller_.RunController(State::Zero(), State{{1.0}, {0.0}},
+ Input::Zero());
+ EXPECT_LT(0.0, result.U(0, 0));
+ // All other U inputs should be zero.
+ EXPECT_EQ(0.0, result.U.bottomRows<7>().norm());
+ EXPECT_EQ(0.0, result.debug.U_ff.norm());
+ EXPECT_EQ(0.0,
+ (result.U - (result.debug.U_ff + result.debug.U_feedback)).norm());
+}
+
+// Confirm that the generated LQR controller is able to pass through the
+// feedforwards when we have no difference between the goal and the current
+// state.
+TEST_F(LinearizedControllerTest, ControlToNonzeroState) {
+ const State state{{1.0}, {1.0}};
+ auto result = controller_.RunController(
+ state, state,
+ Input{{1.0}, {0.0}, {0.0}, {0.0}, {0.0}, {0.0}, {0.0}, {0.0}});
+ EXPECT_EQ(1.0, result.U(0, 0));
+ // All other U inputs should be zero.
+ EXPECT_EQ(0.0, result.U.bottomRows<7>().norm());
+ EXPECT_EQ(0.0, result.debug.U_feedback.norm());
+ EXPECT_EQ(0.0,
+ (result.U - (result.debug.U_ff + result.debug.U_feedback)).norm());
+}
+} // namespace frc971::control_loops::swerve::test
diff --git a/frc971/control_loops/swerve/multi_experience_collector.py b/frc971/control_loops/swerve/multi_experience_collector.py
new file mode 100644
index 0000000..d0551e1
--- /dev/null
+++ b/frc971/control_loops/swerve/multi_experience_collector.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+from absl import app
+from absl import flags
+import sys
+from multiprocessing.pool import ThreadPool
+import pathlib
+import subprocess
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('outdir', '/tmp/swerve', "Directory to write results to.")
+flags.DEFINE_integer('num_actors', 20, 'Number of actors to run in parallel.')
+flags.DEFINE_integer('num_solutions', 100,
+ 'Number of random problems to solve.')
+
+
+def collect_experience(agent_number):
+ filename = f'{agent_number}'
+ if FLAGS.outdir:
+ subdir = pathlib.Path(FLAGS.outdir) / filename
+ else:
+ subdir = pathlib.Path(filename)
+ subdir.mkdir(parents=True, exist_ok=True)
+
+ with open(f'{subdir.resolve()}/log', 'w') as output:
+ subprocess.check_call(
+ args=[
+ sys.executable,
+ "frc971/control_loops/swerve/experience_collector",
+ f"--seed={agent_number}",
+ f"--outputdir={subdir.resolve()}",
+ "--quiet",
+ f"--num_solutions={FLAGS.num_solutions}",
+ ],
+ stdout=output,
+ stderr=output,
+ )
+
+
+def main(argv):
+ # Load a simple problem first so we compile with less system load. This
+ # makes it faster on a processor with frequency boosting.
+ subprocess.check_call(args=[
+ sys.executable,
+ "frc971/control_loops/swerve/experience_collector",
+ "--compileonly",
+ ])
+
+ # Try a bunch of goals now
+ with ThreadPool(FLAGS.num_actors) as pool:
+ pool.starmap(collect_experience, zip(range(FLAGS.num_actors), ))
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/frc971/control_loops/swerve/simplified_dynamics.h b/frc971/control_loops/swerve/simplified_dynamics.h
new file mode 100644
index 0000000..9afe636
--- /dev/null
+++ b/frc971/control_loops/swerve/simplified_dynamics.h
@@ -0,0 +1,372 @@
+#ifndef FRC971_CONTROL_LOOPS_SWERVE_SIMPLIFIED_DYNAMICS_H_
+#define FRC971_CONTROL_LOOPS_SWERVE_SIMPLIFIED_DYNAMICS_H_
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+
+#include "aos/util/math.h"
+#include "frc971/control_loops/swerve/auto_diff_jacobian.h"
+#include "frc971/control_loops/swerve/dynamics.h"
+#include "frc971/control_loops/swerve/motors.h"
+
+namespace frc971::control_loops::swerve {
+
+// Provides a simplified set of physics representing a swerve drivetrain.
+// Broadly speaking, these dynamics model:
+// * Standard motors on the drive and steer axes, with no coupling between
+// the motors.
+// * Assume that the steer direction of each module is only influenced by
+// the inertia of the motor rotor plus some small aligning force.
+// * Assume that torque from the drive motor is transferred directly to the
+// carpet.
+// * Assume that a lateral force on the wheel is generated proportional to the
+// slip angle of the wheel.
+//
+// This class is templated on a Scalar that is used to determine whether you
+// want to use double or single-precision floats for the calculations here.
+//
+// Several individual methods on this class are also templated on a LocalScalar
+// type. This is provided to allow those methods to be called with ceres Jets to
+// do autodifferentiation within various solvers/jacobian calculators.
+template <typename Scalar = double>
+class SimplifiedDynamics {
+ public:
+ struct ModuleParams {
+ // Module position relative to the center of mass of the robot.
+ Eigen::Matrix<Scalar, 2, 1> position;
+ // Coefficient dictating how much sideways force is generated at a given
+ // slip angle. Units are effectively Newtons / radian of slip.
+ Scalar slip_angle_coefficient;
+    // Coefficient dictating how much the steer wheel is forced into alignment
+ // by the motion of the wheel over the ground (i.e., we are assuming that
+ // if you push the robot along it will cause the wheels to eventually
+ // align with the direction of motion).
+ // In radians / sec^2 / radians of slip angle.
+ Scalar slip_angle_alignment_coefficient;
+ // Parameters for the steer and drive motors.
+ Motor steer_motor;
+ Motor drive_motor;
+ // radians of module steering = steer_ratio * radians of motor shaft
+ Scalar steer_ratio;
+ // meters of driving = drive_ratio * radians of motor shaft
+ Scalar drive_ratio;
+ };
+ struct Parameters {
+ // Mass of the robot, in kg.
+ Scalar mass;
+ // Moment of inertia of the robot about the yaw axis, in kg * m^2.
+ Scalar moment_of_inertia;
+    // Note: While this technically would support an arbitrary number of
+    // modules, the statically-sized state vectors limit us to 4 modules
+    // currently, and other pieces of code should not be counted on to
+    // support non-4-module swerves.
+ std::vector<ModuleParams> modules;
+ };
+ enum States {
+ // Thetas* and Omegas* are the yaw and yaw rate of the indicated modules.
+ // (note that we do not actually need to track drive speed per module,
+ // as with current control we have the ability to directly command torque
+ // to those motors; however, if we wished to account for the saturation
+ // limits on the motor, then we would need to have access to those states,
+ // although they can be fully derived from the robot vx, vy, theta, and
+ // omega).
+ kThetas0 = 0,
+ kOmegas0 = 1,
+ kThetas1 = 2,
+ kOmegas1 = 3,
+ kThetas2 = 4,
+ kOmegas2 = 5,
+ kThetas3 = 6,
+ kOmegas3 = 7,
+ // Robot yaw, in radians.
+ kTheta = 8,
+ // Robot speed in the global frame, in meters / sec.
+ kVx = 9,
+ kVy = 10,
+ // Robot yaw rate, in radians / sec.
+ kOmega = 11,
+ kNumVelocityStates = 12,
+ // Augmented states for doing position control.
+ // Robot X position in the global frame.
+ kX = 12,
+ // Robot Y position in the global frame.
+ kY = 13,
+ kNumPositionStates = 14,
+ };
+ using Inputs = InputStates::States;
+
+ template <typename ScalarT = Scalar>
+ using VelocityState = Eigen::Matrix<ScalarT, kNumVelocityStates, 1>;
+ template <typename ScalarT = Scalar>
+ using PositionState = Eigen::Matrix<ScalarT, kNumPositionStates, 1>;
+ template <typename ScalarT = Scalar>
+ using VelocityStateSquare =
+ Eigen::Matrix<ScalarT, kNumVelocityStates, kNumVelocityStates>;
+ template <typename ScalarT = Scalar>
+ using PositionStateSquare =
+ Eigen::Matrix<ScalarT, kNumPositionStates, kNumPositionStates>;
+ template <typename ScalarT = Scalar>
+ using PositionBMatrix =
+ Eigen::Matrix<ScalarT, kNumPositionStates, kNumInputs>;
+ template <typename ScalarT = Scalar>
+ using Input = Eigen::Matrix<ScalarT, kNumInputs, 1>;
+
+  SimplifiedDynamics(const Parameters &params) : params_(params) {
+ for (size_t module_index = 0; module_index < params_.modules.size();
+ ++module_index) {
+ module_dynamics_.emplace_back(params_, module_index);
+ }
+ }
+
+ // Returns the derivative of state for the given state and input.
+ template <typename LocalScalar>
+ PositionState<LocalScalar> Dynamics(const PositionState<LocalScalar> &state,
+ const Input<LocalScalar> &input) const {
+ PositionState<LocalScalar> Xdot = PositionState<LocalScalar>::Zero();
+
+ for (const ModuleDynamics &module : module_dynamics_) {
+ Xdot += module.PartialDynamics(state, input);
+ }
+
+ // And finally catch the global states:
+ Xdot(kX) = state(kVx);
+ Xdot(kY) = state(kVy);
+ Xdot(kTheta) = state(kOmega);
+
+ return Xdot;
+ }
+
+ template <typename LocalScalar>
+ VelocityState<LocalScalar> VelocityDynamics(
+ const VelocityState<LocalScalar> &state,
+ const Input<LocalScalar> &input) const {
+ PositionState<LocalScalar> input_state = PositionState<LocalScalar>::Zero();
+ input_state.template topRows<kNumVelocityStates>() = state;
+ return Dynamics(input_state, input).template topRows<kNumVelocityStates>();
+ }
+
+ std::pair<PositionStateSquare<>, PositionBMatrix<>> LinearizedDynamics(
+ const PositionState<> &state, const Input<> &input) {
+ DynamicsFunctor functor(*this);
+ Eigen::Matrix<Scalar, kNumPositionStates + kNumInputs, 1> parameters;
+ parameters.template topRows<kNumPositionStates>() = state;
+ parameters.template bottomRows<kNumInputs>() = input;
+ const Eigen::Matrix<Scalar, kNumPositionStates,
+ kNumPositionStates + kNumInputs>
+ jacobian =
+ AutoDiffJacobian<Scalar, DynamicsFunctor,
+ kNumPositionStates + kNumInputs,
+ kNumPositionStates>::Jacobian(functor, parameters);
+ return {
+ jacobian.template block<kNumPositionStates, kNumPositionStates>(0, 0),
+ jacobian.template block<kNumPositionStates, kNumInputs>(
+ 0, kNumPositionStates)};
+ }
+
+ private:
+  // Wrapper to provide an operator() for the dynamics class that allows it to
+ // be used by the auto-differentiation code.
+ class DynamicsFunctor {
+ public:
+ DynamicsFunctor(const SimplifiedDynamics &dynamics) : dynamics_(dynamics) {}
+
+ template <typename LocalScalar>
+ Eigen::Matrix<LocalScalar, kNumPositionStates, 1> operator()(
+ const Eigen::Map<const Eigen::Matrix<
+ LocalScalar, kNumPositionStates + kNumInputs, 1>>
+ input) const {
+ return dynamics_.Dynamics(
+ PositionState<LocalScalar>(
+ input.template topRows<kNumPositionStates>()),
+ Input<LocalScalar>(input.template bottomRows<kNumInputs>()));
+ }
+
+ private:
+ const SimplifiedDynamics &dynamics_;
+ };
+
+ // Represents the dynamics of an individual module.
+ class ModuleDynamics {
+ public:
+ ModuleDynamics(const Parameters &robot_params, const size_t module_index)
+ : robot_params_(robot_params), module_index_(module_index) {
+ CHECK_LT(module_index_, robot_params_.modules.size());
+ }
+
+ // This returns the portion of the derivative of the state that is due to
+ // this individual module. The results from this function can be naively
+ // summed across all of the modules, plus some global dynamics (which take
+ // care of, e.g., xdot = vx), to give the overall dynamics of the system.
+ template <typename LocalScalar>
+ PositionState<LocalScalar> PartialDynamics(
+ const PositionState<LocalScalar> &state,
+ const Input<LocalScalar> &input) const {
+ PositionState<LocalScalar> Xdot = PositionState<LocalScalar>::Zero();
+
+ Xdot(ThetasIdx()) = state(OmegasIdx());
+
+ // Steering dynamics for an individual module assume ~zero friction,
+ // and thus ~the only inertia is from the motor rotor itself.
+ // torque_motor = stall_torque / stall_current * current
+ // accel_motor = torque_motor / motor_inertia
+ // accel_steer = accel_motor * steer_ratio
+ const Motor &steer_motor = module_params().steer_motor;
+ const LocalScalar steer_motor_accel =
+ input(IsIdx()) *
+ static_cast<Scalar>(
+ module_params().steer_ratio * steer_motor.stall_torque /
+ (steer_motor.stall_current * steer_motor.motor_inertia));
+
+ // For the impacts of the modules on the overall robot
+ // dynamics (X, Y, and theta acceleration), we calculate the forces
+ // generated by the module and then apply them. These forces come from
+ // two effects in this model:
+ // 1. Slip angle of the module (dependent on the current robot velocity &
+ // module steer angle).
+ // 2. Drive torque from the module (dependent on the current drive
+ // current and module steer angle).
+ // We assume no torque is generated from e.g. the wheel resisting the
+ // steering motion.
+ //
+ // clang-format off
+ //
+ // For slip angle we have:
+ // wheel_velocity = R(-theta - theta_steer) * (
+ // robot_vel + omega.cross(R(theta) * module_position))
+ // slip_angle = -atan2(wheel_velocity.y(), wheel_velocity.x())
+ // slip_force = slip_angle_coefficient * slip_angle
+ // slip_force_direction = theta + theta_steer + pi / 2
+ // force_x = slip_force * cos(slip_force_direction)
+ // force_y = slip_force * sin(slip_force_direction)
+ // accel_* = force_* / mass
+ // # And now calculate torque from slip angle.
+ // torque_vec = module_position.cross([slip_force * cos(theta_steer + pi / 2),
+ // slip_force * sin(theta_steer + pi / 2),
+ // 0.0])
+ // torque_vec = module_position.cross([slip_force * -sin(theta_steer),
+ // slip_force * cos(theta_steer),
+ // 0.0])
+ // robot_torque = torque_vec.z()
+ //
+ // For drive torque we have:
+ // drive_force = (drive_current * stall_torque / stall_current) / drive_ratio
+ // drive_force_direction = theta + theta_steer
+ // force_x = drive_force * cos(drive_force_direction)
+ // force_y = drive_force * sin(drive_force_direction)
+ // torque_vec = drive_force * module_position.cross([cos(theta_steer),
+ // sin(theta_steer),
+ // 0.0])
+ // torque = torque_vec.z()
+ //
+ // clang-format on
+
+ const Eigen::Matrix<Scalar, 3, 1> module_position{
+ {module_params().position.x()},
+ {module_params().position.y()},
+ {0.0}};
+ const LocalScalar theta = state(kTheta);
+ const LocalScalar theta_steer = state(ThetasIdx());
+ const Eigen::Matrix<LocalScalar, 3, 1> wheel_velocity_in_global_frame =
+ Eigen::Matrix<LocalScalar, 3, 1>(state(kVx), state(kVy),
+ static_cast<LocalScalar>(0.0)) +
+ (Eigen::Matrix<LocalScalar, 3, 1>(static_cast<LocalScalar>(0.0),
+ static_cast<LocalScalar>(0.0),
+ state(kOmega))
+ .cross(Eigen::AngleAxis<LocalScalar>(
+ theta, Eigen::Matrix<LocalScalar, 3, 1>::UnitZ()) *
+ module_position));
+ const Eigen::Matrix<LocalScalar, 3, 1> wheel_velocity_in_wheel_frame =
+ Eigen::AngleAxis<LocalScalar>(
+ -theta - theta_steer, Eigen::Matrix<LocalScalar, 3, 1>::UnitZ()) *
+ wheel_velocity_in_global_frame;
+ // The complicated dynamics use some obnoxious-looking functions to
+ // try to approximate how the slip angle behaves at low speeds to better
+ // condition the dynamics. Because I couldn't be bothered to copy those
+ // dynamics, we instead just bias the slip angle to zero at low speeds.
+ const LocalScalar wheel_speed = wheel_velocity_in_wheel_frame.norm();
+ const Scalar start_speed = 0.1;
+ const LocalScalar heading_truth_proportion = -expm1(
+ /*arbitrary large number=*/static_cast<Scalar>(-100.0) *
+ (wheel_speed - start_speed));
+ const LocalScalar wheel_heading =
+ (wheel_speed < start_speed)
+ ? static_cast<LocalScalar>(0.0)
+ : heading_truth_proportion *
+ atan2(wheel_velocity_in_wheel_frame.y(),
+ wheel_velocity_in_wheel_frame.x());
+
+ // We wrap slip_angle with a sin() not because there is actually a sin()
+ // in the real math but rather because we need to smoothly and correctly
+ // handle slip angles between pi / 2 and 3 * pi / 2.
+ const LocalScalar slip_angle = sin(-wheel_heading);
+ const LocalScalar slip_force =
+ module_params().slip_angle_coefficient * slip_angle;
+ const LocalScalar slip_force_direction =
+ theta + theta_steer + static_cast<Scalar>(M_PI_2);
+ const Eigen::Matrix<LocalScalar, 3, 1> slip_force_vec =
+ slip_force * UnitYawVector<LocalScalar>(slip_force_direction);
+ const LocalScalar slip_torque =
+ module_position
+ .cross(slip_force *
+ UnitYawVector<LocalScalar>(theta_steer +
+ static_cast<Scalar>(M_PI_2)))
+ .z();
+
+ // drive torque calculations
+ const Motor &drive_motor = module_params().drive_motor;
+ const LocalScalar drive_force =
+ input(IdIdx()) * static_cast<Scalar>(drive_motor.stall_torque /
+ drive_motor.stall_current /
+ module_params().drive_ratio);
+ const Eigen::Matrix<LocalScalar, 3, 1> drive_force_vec =
+ drive_force * UnitYawVector<LocalScalar>(theta + theta_steer);
+ const LocalScalar drive_torque =
+ drive_force *
+ module_position.cross(UnitYawVector<LocalScalar>(theta_steer)).z();
+ // We add in an aligning force on the wheels primarily to help provide a
+ // bit of impetus to the controllers/solvers to discourage aggressive
+ // slip angles. If we do not include this, then the dynamics make it look
+ // like there are no losses to using extremely aggressive slip angles.
+ const LocalScalar wheel_alignment_accel =
+ -module_params().slip_angle_alignment_coefficient * slip_angle;
+
+ Xdot(OmegasIdx()) = steer_motor_accel + wheel_alignment_accel;
+ // Sum up all the forces.
+ Xdot(kVx) =
+ (slip_force_vec.x() + drive_force_vec.x()) / robot_params_.mass;
+ Xdot(kVy) =
+ (slip_force_vec.y() + drive_force_vec.y()) / robot_params_.mass;
+ Xdot(kOmega) =
+ (slip_torque + drive_torque) / robot_params_.moment_of_inertia;
+
+ return Xdot;
+ }
+
+ private:
+ template <typename LocalScalar>
+ Eigen::Matrix<LocalScalar, 3, 1> UnitYawVector(LocalScalar yaw) const {
+ return Eigen::Matrix<LocalScalar, 3, 1>{
+ {static_cast<LocalScalar>(cos(yaw))},
+ {static_cast<LocalScalar>(sin(yaw))},
+ {static_cast<LocalScalar>(0.0)}};
+ }
+ size_t ThetasIdx() const { return kThetas0 + 2 * module_index_; }
+ size_t OmegasIdx() const { return kOmegas0 + 2 * module_index_; }
+ size_t IsIdx() const { return Inputs::kIs0 + 2 * module_index_; }
+ size_t IdIdx() const { return Inputs::kId0 + 2 * module_index_; }
+
+ const ModuleParams &module_params() const {
+ return robot_params_.modules[module_index_];
+ }
+
+ const Parameters robot_params_;
+
+ const size_t module_index_;
+ };
+
+ Parameters params_;
+ std::vector<ModuleDynamics> module_dynamics_;
+};
+
+} // namespace frc971::control_loops::swerve
+#endif // FRC971_CONTROL_LOOPS_SWERVE_SIMPLIFIED_DYNAMICS_H_
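
As a cross-check on the commented slip-angle math above, here is a minimal standalone sketch in plain Python/numpy. The mass, slip-angle coefficient, and module position are illustrative stand-ins (not the robot's real constants), and the low-speed heading blending is omitted for brevity.

import numpy as np

MASS = 60.0                     # kg; illustrative, not the robot's real value
SLIP_ANGLE_COEFFICIENT = 200.0  # N/rad; illustrative
MODULE_POSITION = np.array([1.0, 1.0, 0.0])  # module offset from robot center


def rot_z(angle):
    """Rotation matrix about +Z by `angle` radians."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def slip_accels(vx, vy, omega, theta, theta_steer):
    # wheel_velocity = R(-theta - theta_steer) * (
    #     robot_vel + omega.cross(R(theta) * module_position))
    robot_vel = np.array([vx, vy, 0.0])
    omega_vec = np.array([0.0, 0.0, omega])
    vel_global = robot_vel + np.cross(omega_vec, rot_z(theta) @ MODULE_POSITION)
    vel_wheel = rot_z(-theta - theta_steer) @ vel_global
    # sin() wrapping, as in the header, to handle slip angles past pi / 2.
    slip_angle = np.sin(-np.arctan2(vel_wheel[1], vel_wheel[0]))
    slip_force = SLIP_ANGLE_COEFFICIENT * slip_angle
    # The force acts perpendicular to the wheel heading; the torque comes from
    # the cross product with the module position.
    direction = theta + theta_steer + np.pi / 2.0
    force = slip_force * np.array([np.cos(direction), np.sin(direction), 0.0])
    torque = np.cross(
        MODULE_POSITION,
        slip_force * np.array([-np.sin(theta_steer), np.cos(theta_steer), 0.0]))
    return force[:2] / MASS, torque[2]


# Driving sideways with the wheel pointed straight decelerates the robot:
print(slip_accels(vx=0.0, vy=1.0, omega=0.0, theta=0.0, theta_steer=0.0))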
diff --git a/frc971/control_loops/swerve/simplified_dynamics_test.cc b/frc971/control_loops/swerve/simplified_dynamics_test.cc
new file mode 100644
index 0000000..894a85b
--- /dev/null
+++ b/frc971/control_loops/swerve/simplified_dynamics_test.cc
@@ -0,0 +1,244 @@
+#include "frc971/control_loops/swerve/simplified_dynamics.h"
+
+#include <functional>
+
+#include "absl/log/log.h"
+#include "gtest/gtest.h"
+
+#include "aos/time/time.h"
+#include "frc971/control_loops/jacobian.h"
+
+namespace frc971::control_loops::swerve::testing {
+class SimplifiedDynamicsTest : public ::testing::Test {
+ protected:
+ using Dynamics = SimplifiedDynamics<double>;
+ using PositionState = Dynamics::PositionState<double>;
+ using States = Dynamics::States;
+ using Inputs = Dynamics::Inputs;
+ using Input = Dynamics::Input<double>;
+ using ModuleParams = Dynamics::ModuleParams;
+ using Parameters = Dynamics::Parameters;
+ static ModuleParams MakeModule(const Eigen::Vector2d &position,
+ bool wheel_alignment) {
+ return ModuleParams{
+ .position = position,
+ .slip_angle_coefficient = 200.0,
+ .slip_angle_alignment_coefficient = wheel_alignment ? 1.0 : 0.0,
+ .steer_motor = KrakenFOC(),
+ .drive_motor = KrakenFOC(),
+ .steer_ratio = 0.1,
+ .drive_ratio = 0.01};
+ }
+ static Parameters MakeParams(bool wheel_alignment) {
+ return {.mass = 60,
+ .moment_of_inertia = 2,
+ .modules = {
+ MakeModule({1.0, 1.0}, wheel_alignment),
+ MakeModule({-1.0, 1.0}, wheel_alignment),
+ MakeModule({-1.0, -1.0}, wheel_alignment),
+ MakeModule({1.0, -1.0}, wheel_alignment),
+ }};
+ }
+ SimplifiedDynamicsTest() : dynamics_(MakeParams(false)) {}
+
+ PositionState ValidateDynamics(const PositionState &state,
+ const Input &input) {
+ const PositionState Xdot = dynamics_.Dynamics(state, input);
+ // Sanity check simple invariants:
+ EXPECT_EQ(Xdot(Dynamics::kX), state(Dynamics::kVx));
+ EXPECT_EQ(Xdot(Dynamics::kY), state(Dynamics::kVy));
+ EXPECT_EQ(Xdot(Dynamics::kTheta), state(Dynamics::kOmega));
+
+ // Check that the dynamics linearization produces numbers that match numeric
+ // differentiation of the dynamics.
+ aos::monotonic_clock::time_point start_time = aos::monotonic_clock::now();
+ const auto linearized_dynamics = dynamics_.LinearizedDynamics(state, input);
+ const aos::monotonic_clock::duration auto_diff_time =
+ aos::monotonic_clock::now() - start_time;
+
+ start_time = aos::monotonic_clock::now();
+ auto numerical_A = NumericalJacobianX(
+ std::bind(&Dynamics::Dynamics<double>, &dynamics_,
+ std::placeholders::_1, std::placeholders::_2),
+ state, input);
+ auto numerical_B = NumericalJacobianU(
+ std::bind(&Dynamics::Dynamics<double>, &dynamics_,
+ std::placeholders::_1, std::placeholders::_2),
+ state, input);
+ const aos::monotonic_clock::duration numerical_time =
+ aos::monotonic_clock::now() - start_time;
+ VLOG(1) << "Autodifferentiation took " << auto_diff_time
+ << " while numerical differentiation took " << numerical_time;
+ EXPECT_LT((numerical_A - linearized_dynamics.first).norm(), 1e-6)
+ << "Numerical result:\n"
+ << numerical_A << "\nAuto-diff result:\n"
+ << linearized_dynamics.first;
+ EXPECT_LT((numerical_B - linearized_dynamics.second).norm(), 1e-6)
+ << "Numerical result:\n"
+ << numerical_B << "\nAuto-diff result:\n"
+ << linearized_dynamics.second;
+ return Xdot;
+ }
+
+ Dynamics dynamics_;
+};
+
+// Test that if all states and inputs are at zero, the robot won't move.
+TEST_F(SimplifiedDynamicsTest, ZeroIsZero) {
+ EXPECT_EQ(PositionState::Zero(),
+ ValidateDynamics(PositionState::Zero(), Input::Zero()));
+}
+
+// Test that if we are travelling straight forwards with zero inputs, we
+// just coast.
+TEST_F(SimplifiedDynamicsTest, CoastForwards) {
+ PositionState state = PositionState::Zero();
+ state(States::kVx) = 1.0;
+ PositionState expected = PositionState::Zero();
+ expected(States::kX) = 1.0;
+ EXPECT_EQ(expected, ValidateDynamics(state, Input::Zero()));
+}
+
+// Tests that we can accelerate the robot.
+TEST_F(SimplifiedDynamicsTest, AccelerateStraight) {
+ // Check that the drive currents behave as anticipated and accelerate the
+ // robot.
+ Input input{{0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}};
+ PositionState state = PositionState::Zero();
+ state(States::kVx) = 0.0;
+ PositionState result = ValidateDynamics(state, input);
+ EXPECT_EQ(result.norm(), result(States::kVx));
+ EXPECT_LT(0.1, result(States::kVx));
+}
+
+// Test that if we are driving straight sideways (so our wheels are at 90
+// degrees) we experience a force slowing us down.
+TEST_F(SimplifiedDynamicsTest, ForceWheelsSideways) {
+ PositionState state = PositionState::Zero();
+ state(States::kVy) = 1.0;
+ PositionState result = ValidateDynamics(state, Input::Zero());
+ EXPECT_LT(result.topRows<States::kVy>().norm(), 1e-10)
+ << ": All derivatives prior to the vy state should be ~zero.\n"
+ << result;
+ EXPECT_LT(result(States::kVy), -1.0)
+ << ": expected non-trivial deceleration.";
+ EXPECT_EQ(result(States::kOmega), 0.0);
+}
+
+// Tests that we can make the robot spin in place by orienting all the wheels
+// tangent to a circle around the center of the robot.
+TEST_F(SimplifiedDynamicsTest, SpinInPlaceNoSlip) {
+ PositionState state = PositionState::Zero();
+ state(States::kThetas0) = 3.0 * M_PI / 4.0;
+ state(States::kThetas1) = 5.0 * M_PI / 4.0;
+ state(States::kThetas2) = 7.0 * M_PI / 4.0;
+ state(States::kThetas3) = 1.0 * M_PI / 4.0;
+ state(States::kOmega) = 1.0;
+ PositionState result = ValidateDynamics(state, Input::Zero());
+ EXPECT_NEAR(result.norm(), 1.0, 1e-10)
+ << ": Only non-zero state should be kTheta, which should be exactly 1.0.";
+ EXPECT_EQ(result(States::kTheta), 1.0);
+
+ // Sanity check that when we then apply drive torque to the wheels, it
+ // spins the robot more.
+ Input input{{0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}, {0.0}, {1.0}};
+ result = ValidateDynamics(state, input);
+ EXPECT_EQ(result(States::kTheta), 1.0);
+ EXPECT_LT(1.0, result(States::kOmega));
+ // Everything else should be ~zero.
+ result(States::kTheta) = 0.0;
+ result(States::kOmega) = 0.0;
+ EXPECT_LT(result.norm(), 1e-10);
+}
+
+// Tests that we can spin in place when skid-steering (i.e., all wheels stay
+// pointed straight, but we still attempt to spin the robot).
+TEST_F(SimplifiedDynamicsTest, SpinInPlaceSkidSteer) {
+ PositionState state = PositionState::Zero();
+ state(States::kThetas0) = -M_PI;
+ state(States::kThetas1) = -M_PI;
+ state(States::kThetas2) = 0.0;
+ state(States::kThetas3) = 0.0;
+ state(States::kOmega) = 1.0;
+ PositionState result = ValidateDynamics(state, Input::Zero());
+ EXPECT_EQ(result(States::kTheta), 1.0);
+ EXPECT_LT(result(States::kOmega), -1.0)
+ << "We should be aggressively decelerrating when slipping wheels.";
+ // Everything else should be ~zero.
+ result(States::kTheta) = 0.0;
+ result(States::kOmega) = 0.0;
+ EXPECT_LT(result.norm(), 1e-10);
+
+ // Sanity check that when we then apply drive torque to the wheels we
+ // can counteract the spin.
+ Input input{{0.0}, {100.0}, {0.0}, {100.0}, {0.0}, {100.0}, {0.0}, {100.0}};
+ result = ValidateDynamics(state, input);
+ EXPECT_EQ(result(States::kTheta), 1.0);
+ EXPECT_LT(1.0, result(States::kOmega));
+ // Everything else should be ~zero.
+ result(States::kTheta) = 0.0;
+ result(States::kOmega) = 0.0;
+ EXPECT_LT(result.norm(), 1e-10);
+}
+
+// Tests that we can spin in place when skid-steering backwards (ensures that
+// the slip angle calculations and such handle the sign changes correctly).
+TEST_F(SimplifiedDynamicsTest, SpinInPlaceSkidSteerBackwards) {
+ PositionState state = PositionState::Zero();
+ state(States::kThetas0) = 0.0;
+ state(States::kThetas1) = 0.0;
+ state(States::kThetas2) = M_PI;
+ state(States::kThetas3) = M_PI;
+ state(States::kOmega) = 1.0;
+ PositionState result = ValidateDynamics(state, Input::Zero());
+ EXPECT_EQ(result(States::kTheta), 1.0);
+ EXPECT_LT(result(States::kOmega), -1.0)
+ << "We should be aggressively decelerrating when slipping wheels.";
+ // Everything else should be ~zero.
+ result(States::kTheta) = 0.0;
+ result(States::kOmega) = 0.0;
+ EXPECT_LT(result.norm(), 1e-10);
+
+ // Sanity check that when we then apply drive torque to the wheels we
+ // can counteract the spin.
+ Input input{{0.0}, {-100.0}, {0.0}, {-100.0},
+ {0.0}, {-100.0}, {0.0}, {-100.0}};
+ result = ValidateDynamics(state, input);
+ EXPECT_EQ(result(States::kTheta), 1.0);
+ EXPECT_LT(1.0, result(States::kOmega));
+ // Everything else should be ~zero.
+ result(States::kTheta) = 0.0;
+ result(States::kOmega) = 0.0;
+ EXPECT_LT(result.norm(), 1e-10);
+}
+
+// Test that if we turn on the wheel alignment forces that it results in forces
+// that cause the wheels to align to straight over time.
+TEST_F(SimplifiedDynamicsTest, WheelAlignmentForces) {
+ dynamics_ = Dynamics(MakeParams(true));
+ PositionState state = PositionState::Zero();
+ state(States::kThetas0) = -0.1;
+ state(States::kThetas1) = -0.1;
+ state(States::kThetas2) = -0.1;
+ state(States::kThetas3) = -0.1;
+ state(States::kVx) = 1.0;
+ PositionState result = ValidateDynamics(state, Input::Zero());
+ EXPECT_LT(1e-2, result(States::kOmegas0));
+ EXPECT_LT(1e-2, result(States::kOmegas1));
+ EXPECT_LT(1e-2, result(States::kOmegas2));
+ EXPECT_LT(1e-2, result(States::kOmegas3));
+}
+
+// Do some fuzz testing of the jacobian calculations.
+TEST_F(SimplifiedDynamicsTest, Fuzz) {
+ for (size_t state_index = 0; state_index < States::kNumPositionStates;
+ ++state_index) {
+ SCOPED_TRACE(state_index);
+ PositionState state = PositionState::Zero();
+ for (const double value : {-1.0, 0.0, 1.0}) {
+ SCOPED_TRACE(value);
+ state(state_index) = value;
+ ValidateDynamics(state, Input::Zero());
+ }
+ }
+}
+} // namespace frc971::control_loops::swerve::testing
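
The autodiff-versus-numerical cross-check that ValidateDynamics() performs above generalizes to any dynamics function. Here is a sketch of the same idea in JAX on a toy pendulum-like system; the dynamics, step size, and 1e-6 tolerance are illustrative.

import jax
import jax.numpy as jnp

# Use float64 so the central-difference roundoff stays below the tolerance.
jax.config.update('jax_enable_x64', True)


def dynamics(x, u):
    # Toy nonlinear dynamics standing in for SimplifiedDynamics::Dynamics().
    return jnp.array([x[1], -jnp.sin(x[0]) + u[0]])


def numerical_jacobian_x(f, x, u, eps=1e-6):
    # Central differences, column by column, like NumericalJacobianX().
    cols = [(f(x + jnp.zeros_like(x).at[i].set(eps), u) -
             f(x - jnp.zeros_like(x).at[i].set(eps), u)) / (2.0 * eps)
            for i in range(x.shape[0])]
    return jnp.stack(cols, axis=1)


x = jnp.array([0.3, -0.2])
u = jnp.array([0.5])
A_autodiff = jax.jacfwd(dynamics, argnums=0)(x, u)
A_numerical = numerical_jacobian_x(dynamics, x, u)
assert jnp.linalg.norm(A_autodiff - A_numerical) < 1e-6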
diff --git a/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs b/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs
index eab0d0d..7cbb11c 100644
--- a/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs
+++ b/frc971/control_loops/swerve/swerve_drivetrain_goal.fbs
@@ -1,3 +1,6 @@
+include "frc971/math/matrix.fbs";
+include "frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs";
+
namespace frc971.control_loops.swerve;
// States what translation control type goal we will care about
@@ -22,11 +25,20 @@
translation_speed:double (id: 3);
}
+attribute "static_length";
+
+table LinearVelocityGoal {
+ state:frc971.fbs.Matrix (id: 0);
+ input:frc971.fbs.Matrix (id: 1);
+}
+
table Goal {
front_left_goal:SwerveModuleGoal (id: 0);
front_right_goal:SwerveModuleGoal (id: 1);
back_left_goal:SwerveModuleGoal (id: 2);
back_right_goal:SwerveModuleGoal (id: 3);
+ linear_velocity_goal:LinearVelocityGoal (id: 4);
+ joystick_goal:JoystickGoal (id: 5);
}
root_type Goal;
diff --git a/frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs b/frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs
new file mode 100644
index 0000000..bf44cdd
--- /dev/null
+++ b/frc971/control_loops/swerve/swerve_drivetrain_joystick_goal.fbs
@@ -0,0 +1,9 @@
+namespace frc971.control_loops.swerve;
+
+table JoystickGoal {
+ vx:float (id: 0);
+ vy:float (id: 1);
+ omega:float (id: 2);
+}
+
+root_type JoystickGoal;
diff --git a/frc971/control_loops/swerve/velocity_controller/BUILD b/frc971/control_loops/swerve/velocity_controller/BUILD
index 7d376bd..a0078b6 100644
--- a/frc971/control_loops/swerve/velocity_controller/BUILD
+++ b/frc971/control_loops/swerve/velocity_controller/BUILD
@@ -44,6 +44,8 @@
"@pip//matplotlib",
"@pip//numpy",
"@pip//tensorflow",
+ "@pip//tensorflow_probability",
+ "@pip//tf_keras",
],
)
@@ -59,6 +61,30 @@
)
py_binary(
+ name = "plot",
+ srcs = [
+ "model.py",
+ "plot.py",
+ ],
+ deps = [
+ ":experience_buffer",
+ ":physics",
+ "//frc971/control_loops/swerve:jax_dynamics",
+ "@pip//absl_py",
+ "@pip//flashbax",
+ "@pip//flax",
+ "@pip//jax",
+ "@pip//jaxtyping",
+ "@pip//matplotlib",
+ "@pip//numpy",
+ "@pip//pygobject",
+ "@pip//tensorflow",
+ "@pip//tensorflow_probability",
+ "@pip//tf_keras",
+ ],
+)
+
+py_binary(
name = "lqr_plot",
srcs = [
"lqr_plot.py",
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index bd6674c..a63e100 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -47,9 +47,26 @@
# Container for the data.
Data = collections.namedtuple('Data', [
- 't', 'X', 'X_lqr', 'U', 'U_lqr', 'cost', 'cost_lqr', 'q1_grid', 'q2_grid',
- 'q_grid', 'target_q_grid', 'lqr_grid', 'pi_grid_U', 'lqr_grid_U', 'grid_X',
- 'grid_Y', 'reward', 'reward_lqr', 'step'
+ 't',
+ 'X',
+ 'X_lqr',
+ 'U',
+ 'U_lqr',
+ 'cost',
+ 'cost_lqr',
+ 'q1_grid',
+ 'q2_grid',
+ 'q_grid',
+ 'target_q_grid',
+ 'lqr_grid',
+ 'pi_grid_U',
+ 'lqr_grid_U',
+ 'grid_X',
+ 'grid_Y',
+ 'reward',
+ 'reward_lqr',
+ 'step',
+ 'stdev',
])
FLAGS = absl.flags.FLAGS
@@ -87,8 +104,13 @@
rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
- state = create_train_state(init_rng, problem, FLAGS.q_learning_rate,
- FLAGS.pi_learning_rate)
+ state = create_train_state(
+ init_rng,
+ problem,
+ q_learning_rate=FLAGS.q_learning_rate,
+ pi_learning_rate=FLAGS.pi_learning_rate,
+ alpha_learning_rate=FLAGS.alpha_learning_rate,
+ )
state = restore_checkpoint(state, FLAGS.workdir)
if step is not None and state.step == step:
@@ -156,12 +178,23 @@
def compute_pi_U(X, Y):
x = jax.numpy.array([X, Y])
- U, _, _, _ = state.pi_apply(rng,
- state.params,
- observation=state.problem.unwrap_angles(x),
- R=goal,
- deterministic=True)
- return U[0]
+ U, _, _ = state.pi_apply(rng,
+ state.params,
+ observation=state.problem.unwrap_angles(x),
+ R=goal,
+ deterministic=True)
+ return U[0] * problem.action_limit
+
+ def compute_pi_stdev(X, Y):
+ x = jax.numpy.array([X, Y])
+ _, _, std = state.pi_apply(rng,
+ state.params,
+ observation=state.problem.unwrap_angles(x),
+ R=goal,
+ deterministic=True)
+ return std[0]
+
+ std_grid = jax.vmap(jax.vmap(compute_pi_stdev))(grid_X, grid_Y)
lqr_cost_U = jax.vmap(jax.vmap(compute_lqr_U))(grid_X, grid_Y)
pi_cost_U = jax.vmap(jax.vmap(compute_pi_U))(grid_X, grid_Y)
@@ -173,22 +206,25 @@
X, X_lqr, data, params = val
t = data.t.at[i].set(i * problem.dt)
- U, _, _, _ = state.pi_apply(rng,
- params,
- observation=state.problem.unwrap_angles(X),
- R=goal,
- deterministic=True)
+ normalized_U, _, _ = state.pi_apply(
+ rng,
+ params,
+ observation=state.problem.unwrap_angles(X),
+ R=goal,
+ deterministic=True)
U_lqr = problem.F @ (goal - X_lqr)
cost = jax.numpy.minimum(
state.q1_apply(params,
observation=state.problem.unwrap_angles(X),
R=goal,
- action=U),
+ action=normalized_U),
state.q2_apply(params,
observation=state.problem.unwrap_angles(X),
R=goal,
- action=U))
+ action=normalized_U))
+
+ U = normalized_U * problem.action_limit
U_plot = data.U.at[i, :].set(U)
U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
@@ -200,8 +236,9 @@
X = problem.A @ X + problem.B @ U
X_lqr = problem.A @ X_lqr + problem.B @ U_lqr
- reward = data.reward - state.problem.cost(X, U, goal)
- reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
+ reward = data.reward + state.problem.reward(X, normalized_U, goal)
+ reward_lqr = data.reward_lqr + state.problem.reward(
+ X_lqr, U_lqr / problem.action_limit, goal)
return X, X_lqr, data._replace(
t=t,
@@ -242,6 +279,7 @@
reward=0.0,
reward_lqr=0.0,
step=state.step,
+ stdev=std_grid,
), X, X_lqr, state.params)
logging.info('Finished integrating, reward of %f, lqr reward of %f',
@@ -265,6 +303,7 @@
lqr_grid_U=numpy.array(data.lqr_grid_U),
grid_X=numpy.array(data.grid_X),
grid_Y=numpy.array(data.grid_Y),
+ stdev=numpy.array(data.stdev),
reward=float(data.reward),
reward_lqr=float(data.reward_lqr),
step=data.step,
@@ -317,9 +356,10 @@
self.Ufig = pyplot.figure(figsize=pyplot.figaspect(0.5))
self.Uax = [
- self.Ufig.add_subplot(1, 3, 1, projection='3d'),
- self.Ufig.add_subplot(1, 3, 2, projection='3d'),
- self.Ufig.add_subplot(1, 3, 3, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 1, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 2, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 3, projection='3d'),
+ self.Ufig.add_subplot(2, 2, 4, projection='3d'),
]
self.last_trajectory_step = 0
@@ -331,6 +371,7 @@
return
self.last_trajectory_step = data.step
logging.info('Updating trajectory plots')
+ self.fig0.suptitle(f'Step {data.step}')
# Put data in the trajectory plots.
self.x.set_data(data.t, data.X[:, 0])
@@ -362,6 +403,7 @@
return
logging.info('Updating cost plots')
self.last_cost_step = data.step
+ self.costfig.suptitle(f'Step {data.step}')
# Put data in the cost plots.
if hasattr(self, 'costsurf'):
for surf in self.costsurf:
@@ -398,6 +440,7 @@
return
self.last_U_step = data.step
logging.info('Updating U plots')
+ self.Ufig.suptitle(f'Step {data.step}')
# Put data in the controller plots.
if hasattr(self, 'Usurf'):
for surf in self.Usurf:
@@ -407,6 +450,7 @@
(data.lqr_grid_U, 'lqr'),
(data.pi_grid_U, 'pi'),
((data.lqr_grid_U - data.pi_grid_U), 'error'),
+ (data.stdev, 'stdev'),
]
self.Usurf = [
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
index a0d248d..391b0fe 100644
--- a/frc971/control_loops/swerve/velocity_controller/main.py
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -39,6 +39,9 @@
flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+flags.DEFINE_bool('swerve', True,
+ 'If true, train the swerve model, otherwise do the turret.')
+
def main(argv):
if len(argv) > 1:
@@ -64,7 +67,12 @@
FLAGS.workdir,
)
- problem = physics.TurretProblem()
+ if FLAGS.swerve:
+ physics_constants = jax_dynamics.Coefficients()
+ problem = physics.SwerveProblem(physics_constants)
+ else:
+ problem = physics.TurretProblem()
+
state = train.train(FLAGS.workdir, problem)
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index 1394463..7f01eef 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -14,9 +14,12 @@
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from frc971.control_loops.swerve import jax_dynamics
-from frc971.control_loops.swerve import dynamics
from frc971.control_loops.swerve.velocity_controller import physics
from frc971.control_loops.swerve.velocity_controller import experience_buffer
+from tensorflow_probability.substrates import jax as tfp
+
+tfd = tfp.distributions
+tfb = tfp.bijectors
from flax.typing import PRNGKey
@@ -29,6 +32,12 @@
)
absl.flags.DEFINE_float(
+ 'alpha_learning_rate',
+ default=0.004,
+ help='Training learning rate for entropy.',
+)
+
+absl.flags.DEFINE_float(
'q_learning_rate',
default=0.002,
help='Training learning rate.',
@@ -52,6 +61,13 @@
help='Fraction of --pi_learning_rate to reduce by, by the end.',
)
+absl.flags.DEFINE_float(
+ 'target_entropy_scalar',
+ default=1.0,
+ help=
+ 'Scalar multiplier on the target entropy when using automatic temperature adjustment.',
+)
+
absl.flags.DEFINE_integer(
'replay_size',
default=2000000,
@@ -64,6 +80,30 @@
help='Batch size for learning Q and pi',
)
+absl.flags.DEFINE_boolean(
+ 'skip_layer',
+ default=False,
+ help='If true, add skip layer connections to the Q network.',
+)
+
+absl.flags.DEFINE_boolean(
+ 'rmsnorm',
+ default=False,
+ help='If true, use rmsnorm instead of layer norm.',
+)
+
+absl.flags.DEFINE_boolean(
+ 'dreamer_solver',
+ default=False,
+ help='If true, use the solver from dreamer v3 instead of adam.',
+)
+
+absl.flags.DEFINE_float(
+ 'initial_logalpha',
+ default=0.0,
+ help='The initial value to set logalpha to.',
+)
+
HIDDEN_WEIGHTS = 256
LOG_STD_MIN = -20
@@ -127,36 +167,25 @@
if rng is None:
rng = self.make_rng('pi')
- # Grab a random sample
- random_sample = jax.random.normal(rng, shape=std.shape)
+ pi_distribution = tfd.TransformedDistribution(
+ distribution=tfd.Normal(loc=mu, scale=std),
+ bijector=tfb.Tanh(),
+ )
if deterministic:
# We are testing the optimal policy, just use the mean.
- pi_action = mu
+ pi_action = flax.linen.activation.tanh(mu)
else:
- # Use the reparameterization trick. Adjust the unit gausian with
- # something we can solve for to get the desired noise.
- pi_action = random_sample * std + mu
+ pi_action = pi_distribution.sample(shape=(1, ), seed=rng)
- logp_pi = gaussian_likelihood(random_sample, log_std)
- # Adjustment to log prob
- # NOTE: This formula is a little bit magic. To get an understanding of where it
- # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
- # appendix C. This is a more numerically-stable equivalent to Eq 21.
- delta = (2.0 * (jax.numpy.log(2.0) - pi_action -
- flax.linen.softplus(-2.0 * pi_action)))
+ logp_pi = pi_distribution.log_prob(pi_action)
- if len(delta.shape) > 1:
- delta = jax.numpy.sum(delta, keepdims=True, axis=1)
+ if len(logp_pi.shape) > 1:
+ logp_pi = jax.numpy.sum(logp_pi, keepdims=True, axis=1)
else:
- delta = jax.numpy.sum(delta, keepdims=True)
+ logp_pi = jax.numpy.sum(logp_pi, keepdims=True)
- logp_pi = logp_pi - delta
-
- # Now, saturate the action to the limit using tanh
- pi_action = self.action_limit * flax.linen.activation.tanh(pi_action)
-
- return pi_action, logp_pi, self.action_limit * std, random_sample
+ return pi_action, logp_pi, self.action_limit * std
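
As a sanity check on the swap above from the hand-written appendix-C correction to tensorflow_probability, here is a standalone sketch (with arbitrary mu/std/z values) showing that the Tanh-transformed distribution's log_prob matches the manual change-of-variables computation.

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

mu, std = 0.3, 0.5  # arbitrary illustrative values
squashed = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=mu, scale=std), bijector=tfb.Tanh())

z = 0.1          # pre-squash sample
a = jnp.tanh(z)  # the action the policy would emit
# Change of variables: log p(a) = log N(z; mu, std) - log|d tanh(z)/dz|.
manual = tfd.Normal(loc=mu, scale=std).log_prob(z) - jnp.log1p(-jnp.tanh(z)**2)
assert jnp.allclose(squashed.log_prob(a), manual, atol=1e-5)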
class MLPQFunction(nn.Module):
@@ -173,10 +202,23 @@
# Estimate Q with a simple multi layer dense network.
x = jax.numpy.hstack((observation, R, action))
for i, hidden_size in enumerate(self.hidden_sizes):
+ # Add d2rl skip layer connections if requested.
+ # Idea from D2RL: https://arxiv.org/pdf/2010.09163.
+ if FLAGS.skip_layer and i != 0:
+ x = jax.numpy.hstack((x, observation, R, action))
+
x = nn.Dense(
name=f'denselayer{i}',
features=hidden_size,
)(x)
+
+ if FLAGS.rmsnorm:
+ # Idea from Dreamerv3: https://arxiv.org/pdf/2301.04104v2.
+ x = nn.RMSNorm(name=f'rmsnorm{i}')(x)
+ else:
+ # Layernorm also improves stability.
+ # Idea from RLPD: https://arxiv.org/pdf/2302.02948.
+ x = nn.LayerNorm(name=f'layernorm{i}')(x)
x = self.activation(x)
x = nn.Dense(
@@ -369,7 +411,7 @@
q_opt_state=q_opt_state,
alpha_tx=alpha_tx,
alpha_opt_state=alpha_opt_state,
- target_entropy=-problem.num_states,
+ target_entropy=-problem.num_outputs * FLAGS.target_entropy_scalar,
mesh=mesh,
sharding=sharding,
replicated_sharding=replicated_sharding,
@@ -380,14 +422,15 @@
def create_train_state(rng: PRNGKey, problem: Problem, q_learning_rate,
- pi_learning_rate):
+ pi_learning_rate, alpha_learning_rate):
"""Creates initial `TrainState`."""
- pi = SquashedGaussianMLPActor(activation=nn.activation.gelu,
+ pi = SquashedGaussianMLPActor(activation=nn.activation.silu,
action_space=problem.num_outputs,
action_limit=problem.action_limit)
# We want q1 and q2 to have different network architectures so they pick up different things.
- q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
- q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
+ # SiLu is used in DreamerV3 so we use it: https://arxiv.org/pdf/2301.04104v2.
+ q1 = MLPQFunction(activation=nn.activation.silu, hidden_sizes=[128, 256])
+ q2 = MLPQFunction(activation=nn.activation.silu, hidden_sizes=[256, 128])
@jax.jit
def init_params(rng):
@@ -412,7 +455,7 @@
)['params']
if FLAGS.alpha < 0.0:
- logalpha = 0.0
+ logalpha = FLAGS.initial_logalpha
else:
logalpha = jax.numpy.log(FLAGS.alpha)
@@ -423,9 +466,14 @@
'logalpha': logalpha,
}
- pi_tx = optax.sgd(learning_rate=pi_learning_rate)
- q_tx = optax.sgd(learning_rate=q_learning_rate)
- alpha_tx = optax.sgd(learning_rate=q_learning_rate)
+ if FLAGS.dreamer_solver:
+ pi_tx = create_dreamer_solver(learning_rate=pi_learning_rate)
+ q_tx = create_dreamer_solver(learning_rate=q_learning_rate)
+ alpha_tx = create_dreamer_solver(learning_rate=alpha_learning_rate)
+ else:
+ pi_tx = optax.adam(learning_rate=pi_learning_rate)
+ q_tx = optax.adam(learning_rate=q_learning_rate)
+ alpha_tx = optax.adam(learning_rate=alpha_learning_rate)
result = TrainState.create(
problem=problem,
@@ -441,6 +489,89 @@
return result
+# Solver from dreamer v3: https://arxiv.org/pdf/2301.04104v2.
+# TODO(austin): How many of these pieces are actually in optax already?
+def scale_by_rms(beta=0.999, eps=1e-8):
+
+ def init_fn(params):
+ nu = jax.tree_util.tree_map(
+ lambda t: jax.numpy.zeros_like(t, jax.numpy.float32), params)
+ step = jax.numpy.zeros((), jax.numpy.int32)
+ return (step, nu)
+
+ def update_fn(updates, state, params=None):
+ step, nu = state
+ step = optax.safe_int32_increment(step)
+ nu = jax.tree_util.tree_map(
+ lambda v, u: beta * v + (1 - beta) * (u * u), nu, updates)
+ nu_hat = optax.bias_correction(nu, beta, step)
+ updates = jax.tree_util.tree_map(
+ lambda u, v: u / (jax.numpy.sqrt(v) + eps), updates, nu_hat)
+ return updates, (step, nu)
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_agc(clip=0.03, pmin=1e-3):
+
+ def init_fn(params):
+ return ()
+
+ def update_fn(updates, state, params=None):
+
+ def fn(param, update):
+ unorm = jax.numpy.linalg.norm(update.flatten(), 2)
+ pnorm = jax.numpy.linalg.norm(param.flatten(), 2)
+ upper = clip * jax.numpy.maximum(pmin, pnorm)
+ return update * (1 / jax.numpy.maximum(1.0, unorm / upper))
+
+ updates = jax.tree_util.tree_map(fn, params, updates)
+ return updates, ()
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_momentum(beta=0.9, nesterov=False):
+
+ def init_fn(params):
+ mu = jax.tree_util.tree_map(
+ lambda t: jax.numpy.zeros_like(t, jax.numpy.float32), params)
+ step = jax.numpy.zeros((), jax.numpy.int32)
+ return (step, mu)
+
+ def update_fn(updates, state, params=None):
+ step, mu = state
+ step = optax.safe_int32_increment(step)
+ mu = optax.update_moment(updates, mu, beta, 1)
+ if nesterov:
+ mu_nesterov = optax.update_moment(updates, mu, beta, 1)
+ mu_hat = optax.bias_correction(mu_nesterov, beta, step)
+ else:
+ mu_hat = optax.bias_correction(mu, beta, step)
+ return mu_hat, (step, mu)
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+
+def create_dreamer_solver(
+ learning_rate,
+ agc: float = 0.3,
+ pmin: float = 1e-3,
+ beta1: float = 0.9,
+ beta2: float = 0.999,
+ eps: float = 1e-20,
+ nesterov: bool = False,
+) -> optax.base.GradientTransformation:
+ # From dreamer v3.
+ return optax.chain(
+ # Adaptive gradient clipping.
+ scale_by_agc(agc, pmin),
+ scale_by_rms(beta2, eps),
+ scale_by_momentum(beta1, nesterov),
+ optax.scale_by_learning_rate(learning_rate),
+ )
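
On the TODO above: optax does ship related building blocks (optax.adaptive_grad_clip and optax.scale_by_rms, at least), though with slightly different knobs than the DreamerV3 recipe. Either way, the chained transform behaves like any other optax optimizer; a minimal usage sketch with illustrative shapes and learning rate:

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
tx = create_dreamer_solver(learning_rate=1e-4)
opt_state = tx.init(params)

def loss_fn(p):
    return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

grads = jax.grad(loss_fn)(params)
# Note: params must be passed through so scale_by_agc can see them.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)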
+
+
def create_learning_rate_fn(
base_learning_rate: float,
final_learning_rate: float,
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index b0da0b4..accd0d6 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -55,13 +55,17 @@
def random_states(self, rng: PRNGKey, dimensions=None):
raise NotImplemented("random_states not implemented")
- def random_actions(self, rng: PRNGKey, dimensions=None):
+ def random_actions(self,
+ rng: PRNGKey,
+ X: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike,
+ dimensions=None):
"""Produces a uniformly random action in the action space."""
return jax.random.uniform(
rng,
(dimensions or FLAGS.num_agents, self.num_outputs),
- minval=-self.action_limit,
- maxval=self.action_limit,
+ minval=-1.0,
+ maxval=1.0,
)
def random_goals(self, rng: PRNGKey, dimensions=None):
@@ -94,12 +98,13 @@
A_continuous = jax.numpy.array([[0., 1.], [0., -36.85154548]])
B_continuous = jax.numpy.array([[0.], [56.08534375]])
+ U = U * self.action_limit
return A_continuous @ X + B_continuous @ U
- def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
- goal: jax.typing.ArrayLike):
- return (X - goal).T @ jax.numpy.array(
- self.Q) @ (X - goal) + U.T @ jax.numpy.array(self.R) @ U
+ def reward(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike):
+ return -(X - goal).T @ jax.numpy.array(
+ self.Q) @ (X - goal) - U.T @ jax.numpy.array(self.R) @ U
def random_states(self, rng: PRNGKey, dimensions=None):
rng1, rng2 = jax.random.split(rng)
@@ -120,3 +125,123 @@
maxval=0.1),
jax.numpy.zeros((dimensions or FLAGS.num_agents, 1)),
))
+
+
+class SwerveProblem(Problem):
+
+ def __init__(self, coefficients: jax_dynamics.CoefficientsType):
+ super().__init__(num_states=jax_dynamics.NUM_VELOCITY_STATES,
+ num_unwrapped_states=17,
+ num_outputs=8,
+ num_goals=3,
+ action_limit=40.0)
+
+ self.coefficients = coefficients
+
+ def random_actions(self,
+ rng: PRNGKey,
+ X: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike,
+ dimensions=None):
+ """Produces a uniformly random action in the action space."""
+ return jax.random.uniform(
+ rng,
+ (dimensions or FLAGS.num_agents, self.num_outputs),
+ minval=-1.0,
+ maxval=1.0,
+ )
+
+ def unwrap_angles(self, X: jax.typing.ArrayLike):
+ return jax.numpy.stack([
+ jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS0]),
+ jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS0]),
+ X[..., jax_dynamics.VELOCITY_STATE_OMEGAS0],
+ jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS1]),
+ jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS1]),
+ X[..., jax_dynamics.VELOCITY_STATE_OMEGAS1],
+ jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS2]),
+ jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS2]),
+ X[..., jax_dynamics.VELOCITY_STATE_OMEGAS2],
+ jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETAS3]),
+ jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETAS3]),
+ X[..., jax_dynamics.VELOCITY_STATE_OMEGAS3],
+ jax.numpy.cos(X[..., jax_dynamics.VELOCITY_STATE_THETA]),
+ jax.numpy.sin(X[..., jax_dynamics.VELOCITY_STATE_THETA]),
+ X[..., jax_dynamics.VELOCITY_STATE_VX],
+ X[..., jax_dynamics.VELOCITY_STATE_VY],
+ X[..., jax_dynamics.VELOCITY_STATE_OMEGA],
+ ],
+ axis=-1)
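
The cos/sin "unwrapping" above exists because raw angles alias: theta and theta + 2*pi are the same physical configuration but very different network inputs. A toy check of the encoding:

import jax.numpy as jnp

def encode(theta):
    # Same trick as unwrap_angles(): feed the network a point on the unit
    # circle instead of the raw angle.
    return jnp.stack([jnp.cos(theta), jnp.sin(theta)])

theta = jnp.array(0.1)
assert jnp.allclose(encode(theta), encode(theta + 2 * jnp.pi), atol=1e-5)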
+
+ def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+ return jax_dynamics.velocity_dynamics(self.coefficients, X,
+ self.action_limit * U)
+
+ def reward(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike):
+ return -jax_dynamics.mpc_cost(coefficients=self.coefficients,
+ X=X,
+ U=self.action_limit * U,
+ goal=goal)
+
+ def random_states(self, rng: PRNGKey, dimensions=None):
+ rng, rng1, rng2, rng3, rng4, rng5, rng6, rng7, rng8, rng9, rng10, rng11 = jax.random.split(
+ rng, num=12)
+
+ return jax.numpy.hstack((
+ # VELOCITY_STATE_THETAS0 = 0
+ self._random_angle(rng1, dimensions),
+ # VELOCITY_STATE_OMEGAS0 = 1
+ self._random_module_velocity(rng2, dimensions),
+ # VELOCITY_STATE_THETAS1 = 2
+ self._random_angle(rng3, dimensions),
+ # VELOCITY_STATE_OMEGAS1 = 3
+ self._random_module_velocity(rng4, dimensions),
+ # VELOCITY_STATE_THETAS2 = 4
+ self._random_angle(rng5, dimensions),
+ # VELOCITY_STATE_OMEGAS2 = 5
+ self._random_module_velocity(rng6, dimensions),
+ # VELOCITY_STATE_THETAS3 = 6
+ self._random_angle(rng7, dimensions),
+ # VELOCITY_STATE_OMEGAS3 = 7
+ self._random_module_velocity(rng8, dimensions),
+ # VELOCITY_STATE_THETA = 8
+ self._random_angle(rng9, dimensions),
+ # VELOCITY_STATE_VX = 9
+ # VELOCITY_STATE_VY = 10
+ self._random_robot_velocity(rng10, dimensions),
+ # VELOCITY_STATE_OMEGA = 11
+ self._random_robot_angular_velocity(rng11, dimensions),
+ ))
+
+ def random_goals(self, rng: PRNGKey, dimensions=None):
+ """Produces a random goal in the goal space."""
+ return jax.numpy.hstack((
+ jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+ minval=1.0,
+ maxval=1.0),
+ jax.numpy.zeros((dimensions or FLAGS.num_agents, 2)),
+ ))
+
+ MODULE_VELOCITY = 1.0
+ ROBOT_ANGULAR_VELOCITY = 0.5
+
+ def _random_angle(self, rng: PRNGKey, dimensions=None):
+ return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+ minval=-0.1 * jax.numpy.pi,
+ maxval=0.1 * jax.numpy.pi)
+
+ def _random_module_velocity(self, rng: PRNGKey, dimensions=None):
+ return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+ minval=-self.MODULE_VELOCITY,
+ maxval=self.MODULE_VELOCITY)
+
+ def _random_robot_velocity(self, rng: PRNGKey, dimensions=None):
+ return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 2),
+ minval=0.9,
+ maxval=1.1)
+
+ def _random_robot_angular_velocity(self, rng: PRNGKey, dimensions=None):
+ return jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+ minval=-self.ROBOT_ANGULAR_VELOCITY,
+ maxval=self.ROBOT_ANGULAR_VELOCITY)
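
A toy sketch of the normalized-action convention these changes adopt: the policy, replay buffer, and Q networks all see actions in [-1, 1], and physical units (amps here) only appear where action_limit scales them at the dynamics/reward boundary. The A and B matrices below are illustrative, not the robot's:

import jax.numpy as jnp

ACTION_LIMIT = 40.0  # amps, matching SwerveProblem above

def xdot(X, U_normalized):
    U = U_normalized * ACTION_LIMIT  # convert to physical amps at the boundary
    A = jnp.array([[0.0, 1.0], [0.0, -1.0]])  # toy dynamics matrices
    B = jnp.array([[0.0], [1.0]])
    return A @ X + B @ U

U_net = jnp.array([0.5])  # what the tanh-squashed policy emits
print(xdot(jnp.array([0.0, 1.0]), U_net))  # the dynamics see 20 A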
diff --git a/frc971/control_loops/swerve/velocity_controller/plot.py b/frc971/control_loops/swerve/velocity_controller/plot.py
new file mode 100644
index 0000000..502a3e2
--- /dev/null
+++ b/frc971/control_loops/swerve/velocity_controller/plot.py
@@ -0,0 +1,430 @@
+#!/usr/bin/env python3
+
+import os
+
+os.environ['XLA_FLAGS'] = ' '.join([
+ # Teach it where to find CUDA
+ '--xla_gpu_cuda_data_dir=/usr/lib/cuda',
+ # Use up to 20 cores
+ '--xla_force_host_platform_device_count=20',
+ # Dump XLA to /tmp/foo to aid debugging
+ #'--xla_dump_to=/tmp/foo',
+ #'--xla_gpu_enable_command_buffer='
+])
+
+os.environ['JAX_PLATFORMS'] = 'cpu'
+# Don't pre-allocate memory
+os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+
+import absl
+from absl import logging
+from matplotlib.animation import FuncAnimation
+import matplotlib
+import numpy
+import scipy
+import time
+
+matplotlib.use("gtk3agg")
+
+from matplotlib import pylab
+from matplotlib import pyplot
+from flax.training import checkpoints
+import tensorflow as tf
+from jax.experimental.ode import odeint
+import threading
+import collections
+
+from frc971.control_loops.swerve.velocity_controller.model import *
+from frc971.control_loops.swerve.velocity_controller.physics import *
+
+# Container for the data.
+Data = collections.namedtuple('Data', [
+ 't', 'X', 'U', 'logp_pi', 'cost', 'q1_grid', 'q2_grid', 'q_grid',
+ 'target_q_grid', 'pi_grid_U', 'grid_X', 'grid_Y', 'reward', 'step',
+ 'rewards'
+])
+
+FLAGS = absl.flags.FLAGS
+
+absl.flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+
+absl.flags.DEFINE_integer(
+ 'horizon',
+ default=100,
+ help='MPC horizon',
+)
+
+numpy.set_printoptions(linewidth=200, )
+
+absl.flags.DEFINE_float(
+ 'alpha',
+ default=0.2,
+ help='Entropy. If negative, automatically solve for it.',
+)
+
+
+def restore_checkpoint(state: TrainState, workdir: str):
+ return checkpoints.restore_checkpoint(workdir, state)
+
+
+dt = 0.005
+
+
+def X0Full():
+ module_theta = 0.0
+ module_omega = 0.0
+ theta = 0.0
+ vx = 1.0
+ dtheta = 0.05
+ vy = 0.0
+ drive_omega = jax.numpy.hypot(vx, vy) / jax_dynamics.WHEEL_RADIUS
+ omega = 0.0
+ return jax.numpy.array([
+ module_theta + dtheta,
+ 0.0,
+ module_omega,
+ drive_omega,
+ module_theta + dtheta,
+ 0.0,
+ module_omega,
+ drive_omega,
+ module_theta - dtheta,
+ 0.0,
+ module_omega,
+ drive_omega,
+ module_theta - dtheta,
+ 0.0,
+ module_omega,
+ drive_omega,
+ 0.0,
+ 0.0,
+ theta,
+ vx,
+ vy,
+ omega,
+ 0.0,
+ 0.0,
+ 0.0,
+ ])
+
+
+def generate_data(step=None):
+ grid_X = numpy.arange(-1, 1, 0.1)
+ grid_Y = numpy.arange(-10, 10, 0.1)
+ grid_X, grid_Y = numpy.meshgrid(grid_X, grid_Y)
+ grid_X = jax.numpy.array(grid_X)
+ grid_Y = jax.numpy.array(grid_Y)
+ # Load the training state.
+ physics_constants = jax_dynamics.Coefficients()
+ problem = physics.SwerveProblem(physics_constants)
+
+ rng = jax.random.key(0)
+ rng, init_rng = jax.random.split(rng)
+
+ state = create_train_state(init_rng,
+ problem,
+ FLAGS.q_learning_rate,
+ FLAGS.pi_learning_rate,
+ alpha_learning_rate=0.001)
+
+ state = restore_checkpoint(state, FLAGS.workdir)
+
+ X = X0Full()
+ X_lqr = X.copy()
+ goal = jax.numpy.array([1.0, 0.0, 0.0])
+
+ logging.info('X: %s', X)
+ logging.info('goal: %s', goal)
+ logging.debug('params: %s', state.params)
+
+ # Now simulate the robot, accumulating up things to plot.
+ def loop(i, val):
+ X, data, params = val
+ t = data.t.at[i].set(i * problem.dt)
+
+ U, logp_pi, std = state.pi_apply(
+ rng,
+ params,
+ observation=state.problem.unwrap_angles(
+ jax_dynamics.to_velocity_state(X)),
+ R=goal,
+ deterministic=True)
+
+ logp_pi = logp_pi * jax.numpy.exp(params['logalpha'])
+
+ jax.debug.print('mu: {mu} std: {std}', mu=U, std=std)
+
+ step_reward = state.problem.q_reward(jax_dynamics.to_velocity_state(X),
+ U, goal)
+ reward = data.reward + step_reward
+
+ cost = jax.numpy.minimum(
+ state.q1_apply(params,
+ observation=state.problem.unwrap_angles(
+ jax_dynamics.to_velocity_state(X)),
+ R=goal,
+ action=U),
+ state.q2_apply(params,
+ observation=state.problem.unwrap_angles(
+ jax_dynamics.to_velocity_state(X)),
+ R=goal,
+ action=U))
+
+ U = U * problem.action_limit
+ U_plot = data.U.at[i, :].set(U)
+ rewards = data.rewards.at[i, :].set(step_reward)
+ X_plot = data.X.at[i, :].set(X)
+ cost_plot = data.cost.at[i, :].set(cost)
+ logp_pi_plot = data.logp_pi.at[i, :].set(logp_pi)
+
+ # TODO(austin): I'd really like to visualize the slip angle per wheel.
+ # Maybe also the force deviation, etc.
+ # I think that would help enormously in figuring out how good a specific solution is.
+
+ def fn(X, t):
+ return jax_dynamics.full_dynamics(problem.coefficients, X,
+ U).flatten()
+
+ X = odeint(fn, X, jax.numpy.array([0, problem.dt]))
+
+ return X[1, :], data._replace(
+ t=t,
+ U=U_plot,
+ X=X_plot,
+ logp_pi=logp_pi_plot,
+ cost=cost_plot,
+ reward=reward,
+ rewards=rewards,
+ ), params
+
+ # Do it.
+ @jax.jit
+ def integrate(data, X, params):
+ return jax.lax.fori_loop(0, FLAGS.horizon, loop, (X, data, params))
+
+ X, data, params = integrate(
+ Data(
+ t=jax.numpy.zeros((FLAGS.horizon, )),
+ X=jax.numpy.zeros((FLAGS.horizon, jax_dynamics.NUM_STATES)),
+ U=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
+ logp_pi=jax.numpy.zeros((FLAGS.horizon, 1)),
+ rewards=jax.numpy.zeros((FLAGS.horizon, 1)),
+ cost=jax.numpy.zeros((FLAGS.horizon, 1)),
+ q1_grid=jax.numpy.zeros(grid_X.shape),
+ q2_grid=jax.numpy.zeros(grid_X.shape),
+ q_grid=jax.numpy.zeros(grid_X.shape),
+ target_q_grid=jax.numpy.zeros(grid_X.shape),
+ pi_grid_U=jax.numpy.zeros(grid_X.shape),
+ grid_X=grid_X,
+ grid_Y=grid_Y,
+ reward=0.0,
+ step=state.step,
+ ), X, state.params)
+
+ logging.info('Reward: %s', float(data.reward))
+
+ # Convert back to numpy for plotting.
+ return Data(
+ t=numpy.array(data.t),
+ X=numpy.array(data.X),
+ U=numpy.array(data.U),
+ logp_pi=numpy.array(data.logp_pi),
+ cost=numpy.array(data.cost),
+ q1_grid=numpy.array(data.q1_grid),
+ q2_grid=numpy.array(data.q2_grid),
+ q_grid=numpy.array(data.q_grid),
+ target_q_grid=numpy.array(data.target_q_grid),
+ pi_grid_U=numpy.array(data.pi_grid_U),
+ grid_X=numpy.array(data.grid_X),
+ grid_Y=numpy.array(data.grid_Y),
+ rewards=numpy.array(data.rewards),
+ reward=float(data.reward),
+ step=data.step,
+ )
+
+
+class Plotter(object):
+
+ def __init__(self, data):
+ # Make all the plots and axis.
+ self.fig0, self.axs0 = pylab.subplots(3)
+ self.fig0.supxlabel('Seconds')
+
+ self.vx, = self.axs0[0].plot([], [], label="vx")
+ self.vy, = self.axs0[0].plot([], [], label="vy")
+ self.omega, = self.axs0[0].plot([], [], label="omega")
+ self.axs0[0].set_ylabel('Velocity')
+ self.axs0[0].legend()
+ self.axs0[0].grid()
+
+ self.steer0, = self.axs0[1].plot([], [], label="Steer0")
+ self.steer1, = self.axs0[1].plot([], [], label="Steer1")
+ self.steer2, = self.axs0[1].plot([], [], label="Steer2")
+ self.steer3, = self.axs0[1].plot([], [], label="Steer3")
+ self.axs0[1].set_ylabel('Amps')
+ self.axs0[1].legend()
+ self.axs0[1].grid()
+
+ self.drive0, = self.axs0[2].plot([], [], label="Drive0")
+ self.drive1, = self.axs0[2].plot([], [], label="Drive1")
+ self.drive2, = self.axs0[2].plot([], [], label="Drive2")
+ self.drive3, = self.axs0[2].plot([], [], label="Drive3")
+ self.axs0[2].set_ylabel('Amps')
+ self.axs0[2].legend()
+ self.axs0[2].grid()
+
+ self.fig1, self.axs1 = pylab.subplots(3)
+ self.fig1.supxlabel('Seconds')
+
+ self.theta0, = self.axs1[0].plot([], [], label='steer position0')
+ self.theta1, = self.axs1[0].plot([], [], label='steer position1')
+ self.theta2, = self.axs1[0].plot([], [], label='steer position2')
+ self.theta3, = self.axs1[0].plot([], [], label='steer position3')
+ self.axs1[0].set_ylabel('Radians')
+ self.axs1[0].legend()
+ self.omega0, = self.axs1[1].plot([], [], label='steer velocity0')
+ self.omega1, = self.axs1[1].plot([], [], label='steer velocity1')
+ self.omega2, = self.axs1[1].plot([], [], label='steer velocity2')
+ self.omega3, = self.axs1[1].plot([], [], label='steer velocity3')
+ self.axs1[1].set_ylabel('Radians/second')
+ self.axs1[1].legend()
+
+ self.logp_axis = self.axs1[2].twinx()
+ self.cost, = self.axs1[2].plot([], [], label='cost')
+ self.reward, = self.axs1[2].plot([], [], label='reward')
+ self.axs1[2].set_ylabel('Cost / reward')
+ self.axs1[2].legend()
+
+ self.logp_pi, = self.logp_axis.plot([], [],
+ label='logp_pi',
+ color='C2')
+ self.logp_axis.set_ylabel('log(likelihood)*alpha')
+ self.logp_axis.legend()
+
+ self.last_robot_step = 0
+ self.last_steer_step = 0
+
+ def update_robot_plot(self, data):
+ if data.step is not None and data.step == self.last_robot_step:
+ return
+ self.last_robot_step = data.step
+ logging.info('Updating robot plots')
+ self.fig0.suptitle(f'Step {data.step}')
+
+ self.vx.set_data(data.t, data.X[:, jax_dynamics.STATE_VX])
+ self.vy.set_data(data.t, data.X[:, jax_dynamics.STATE_VY])
+ self.omega.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGA])
+
+ self.axs0[0].relim()
+ self.axs0[0].autoscale_view()
+
+ self.steer0.set_data(data.t, data.U[:, 0])
+ self.steer1.set_data(data.t, data.U[:, 2])
+ self.steer2.set_data(data.t, data.U[:, 4])
+ self.steer3.set_data(data.t, data.U[:, 6])
+ self.axs0[1].relim()
+ self.axs0[1].autoscale_view()
+
+ self.drive0.set_data(data.t, data.U[:, 1])
+ self.drive1.set_data(data.t, data.U[:, 3])
+ self.drive2.set_data(data.t, data.U[:, 5])
+ self.drive3.set_data(data.t, data.U[:, 7])
+ self.axs0[2].relim()
+ self.axs0[2].autoscale_view()
+
+ return (self.vx, self.vy, self.omega, self.steer0, self.steer1,
+ self.steer2, self.steer3, self.drive0, self.drive1,
+ self.drive2, self.drive3)
+
+ def update_steer_plot(self, data):
+ if data.step == self.last_steer_step:
+ return
+ self.last_steer_step = data.step
+ logging.info('Updating steer plots')
+ self.fig1.suptitle(f'Step {data.step}')
+
+ self.theta0.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS0])
+ self.theta1.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS1])
+ self.theta2.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS2])
+ self.theta3.set_data(data.t, data.X[:, jax_dynamics.STATE_THETAS3])
+ self.axs1[0].relim()
+ self.axs1[0].autoscale_view()
+
+ self.omega0.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS0])
+ self.omega1.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS1])
+ self.omega2.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS2])
+ self.omega3.set_data(data.t, data.X[:, jax_dynamics.STATE_OMEGAS3])
+ self.axs1[1].relim()
+ self.axs1[1].autoscale_view()
+
+ self.cost.set_data(data.t, data.cost)
+ self.reward.set_data(data.t, data.rewards)
+ self.logp_pi.set_data(data.t, data.logp_pi)
+ self.axs1[2].relim()
+ self.axs1[2].autoscale_view()
+ self.logp_axis.relim()
+ self.logp_axis.autoscale_view()
+
+ return (self.theta0, self.theta1, self.theta2, self.theta3,
+ self.omega0, self.omega1, self.omega2, self.omega3, self.cost,
+ self.logp_pi, self.reward)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise absl.app.UsageError('Too many command-line arguments.')
+
+ tf.config.experimental.set_visible_devices([], 'GPU')
+
+ lock = threading.Lock()
+
+ # Load data.
+ data = generate_data()
+
+ plotter = Plotter(data)
+
+ # Event for shutting down the thread.
+ shutdown = threading.Event()
+
+ # Thread to grab new data periodically.
+ def do_update():
+ while True:
+ nonlocal data
+
+ my_data = generate_data(data.step)
+
+ if my_data is not None:
+ with lock:
+ data = my_data
+
+ if shutdown.wait(timeout=3):
+ return
+
+ update_thread = threading.Thread(target=do_update)
+ update_thread.start()
+
+ # Now, update each of the plots every second with the new data.
+ def update0(frame):
+ with lock:
+ my_data = data
+
+ return plotter.update_robot_plot(my_data)
+
+ def update1(frame):
+ with lock:
+ my_data = data
+
+ return plotter.update_steer_plot(my_data)
+
+ animation0 = FuncAnimation(plotter.fig0, update0, interval=1000)
+ animation1 = FuncAnimation(plotter.fig1, update1, interval=1000)
+
+ pyplot.show()
+
+ shutdown.set()
+ update_thread.join()
+
+
+if __name__ == '__main__':
+ absl.flags.mark_flags_as_required(['workdir'])
+ absl.app.run(main)
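
The generate_data() loop above leans on a standard JAX pattern: preallocate the plot arrays, carry them through jax.lax.fori_loop, and fill them with functional .at[i].set() updates so the whole rollout can be jitted. A stripped-down sketch with a toy one-state integrator:

import jax
import jax.numpy as jnp

HORIZON, DT = 100, 0.005

def step(i, val):
    x, xs = val
    x = x + DT * (-x)           # toy dynamics: xdot = -x
    return x, xs.at[i].set(x)   # functional in-place update

@jax.jit
def integrate(x0):
    return jax.lax.fori_loop(0, HORIZON, step, (x0, jnp.zeros(HORIZON)))

x_final, trajectory = integrate(jnp.array(1.0))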
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index cf54bc4..34d904a 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -1,3 +1,6 @@
+# Machine learning based on Soft Actor-Critic (SAC), which was initially proposed in https://arxiv.org/pdf/1801.01290.
+# Our implementation is heavily based on OpenAI's Spinning Up reference implementation: https://spinningup.openai.com/en/latest/algorithms/sac.html.
+
import absl
import time
import collections
@@ -35,7 +38,7 @@
)
absl.flags.DEFINE_integer(
- 'start_steps',
+ 'random_sample_steps',
default=10000,
help='Number of steps to randomly sample before using the policy',
)
@@ -76,6 +79,13 @@
help='If true, explode on any NaNs found, and print them.',
)
+absl.flags.DEFINE_bool(
+ 'maximum_entropy_q',
+ default=True,
+ help=
+ 'If false, do not add the maximum entropy term to the Bellman backup for Q.',
+)
+
def save_checkpoint(state: TrainState, workdir: str):
"""Saves a checkpoint in the workdir."""
@@ -123,10 +133,10 @@
R = data['goals']
# Compute the ending actions from the current network.
- action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
- params=params,
- observation=observations2,
- R=R)
+ action2, logp_pi2, _ = state.pi_apply(rng=rng,
+ params=params,
+ observation=observations2,
+ R=R)
# Compute target network Q values
q1_pi_target = state.q1_apply(state.target_params,
@@ -142,8 +152,13 @@
alpha = jax.numpy.exp(params['logalpha'])
# Now we can compute the Bellman backup
- bellman_backup = jax.lax.stop_gradient(rewards + FLAGS.gamma *
- (q_pi_target - alpha * logp_pi2))
+ # Max entropy SAC is based on https://arxiv.org/pdf/1812.05905.
+ if FLAGS.maximum_entropy_q:
+ bellman_backup = jax.lax.stop_gradient(
+ rewards + FLAGS.gamma * (q_pi_target - alpha * logp_pi2))
+ else:
+ bellman_backup = jax.lax.stop_gradient(rewards +
+ FLAGS.gamma * q_pi_target)
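
With illustrative numbers, the two backups above differ only by the entropy bonus from arXiv 1812.05905:

import jax.numpy as jnp

rewards, gamma, alpha = 1.0, 0.99, 0.2
q1_pi_target, q2_pi_target = 5.0, 4.8  # illustrative target-network values
logp_pi2 = -1.3                        # log-likelihood of the next action

q_pi_target = jnp.minimum(q1_pi_target, q2_pi_target)
max_entropy_backup = rewards + gamma * (q_pi_target - alpha * logp_pi2)
plain_backup = rewards + gamma * q_pi_target
# max_entropy_backup exceeds plain_backup by gamma * alpha * (-logp_pi2).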
# Compute the starting Q values from the Q network being optimized.
q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)
@@ -156,19 +171,6 @@
@jax.jit
-def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
- data: ArrayLike):
-
- def bound_compute_loss_q(rng, data):
- return compute_loss_q(state, rng, params, data)
-
- return jax.vmap(bound_compute_loss_q)(
- jax.random.split(rng, FLAGS.num_agents),
- data,
- ).mean()
-
-
-@jax.jit
def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
"""Computes the Soft Actor-Critic loss for pi."""
observations1 = data['observations1']
@@ -176,10 +178,10 @@
# TODO(austin): We've got differentiable policy and differentiable physics. Can we use those here? Have Q learn the future, not the current step?
# Compute the action
- pi, logp_pi, _, _ = state.pi_apply(rng=rng,
- params=params,
- observation=observations1,
- R=R)
+ pi, logp_pi, _ = state.pi_apply(rng=rng,
+ params=params,
+ observation=observations1,
+ R=R)
q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
observation=observations1,
R=R,
@@ -199,6 +201,32 @@
@jax.jit
+def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+ """Computes the Soft Actor-Critic loss for alpha."""
+ observations1 = data['observations1']
+ R = data['goals']
+ pi, logp_pi, _ = jax.lax.stop_gradient(
+ state.pi_apply(rng=rng, params=params, observation=observations1, R=R))
+
+ return (-jax.numpy.exp(params['logalpha']) *
+ (logp_pi + state.target_entropy)).mean(), logp_pi.mean()
+
+
+@jax.jit
+def compute_batched_loss_q(state: TrainState, rng: PRNGKey, params,
+ data: ArrayLike):
+
+ def bound_compute_loss_q(rng, data):
+ return compute_loss_q(state, rng, params, data)
+
+ return jax.vmap(bound_compute_loss_q)(
+ jax.random.split(rng, FLAGS.num_agents),
+ data,
+ ).mean()
+
+
+@jax.jit
def compute_batched_loss_pi(state: TrainState, rng: PRNGKey, params,
data: ArrayLike):
@@ -212,33 +240,21 @@
@jax.jit
-def compute_loss_alpha(state: TrainState, rng: PRNGKey, params,
- data: ArrayLike):
- """Computes the Soft Actor-Critic loss for alpha."""
- observations1 = data['observations1']
- R = data['goals']
- pi, logp_pi, _, _ = jax.lax.stop_gradient(
- state.pi_apply(rng=rng, params=params, R=R, observation=observations1))
-
- return (-jax.numpy.exp(params['logalpha']) *
- (logp_pi + state.target_entropy)).mean()
-
-
-@jax.jit
def compute_batched_loss_alpha(state: TrainState, rng: PRNGKey, params,
data: ArrayLike):
def bound_compute_loss_alpha(rng, data):
return compute_loss_alpha(state, rng, params, data)
- return jax.vmap(bound_compute_loss_alpha)(
+ loss, entropy = jax.vmap(bound_compute_loss_alpha)(
jax.random.split(rng, FLAGS.num_agents),
data,
- ).mean()
+ )
+ return (loss.mean(), entropy.mean())
@jax.jit
-def train_step(state: TrainState, data, action_data, update_rng: PRNGKey,
+def train_step(state: TrainState, data, update_rng: PRNGKey,
step: int) -> TrainState:
"""Updates the parameters for Q, Pi, target Q, and alpha."""
update_rng, q_grad_rng = jax.random.split(update_rng)
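All of the compute_batched_loss_* wrappers above share one pattern: split the RNG into one key per agent, then jax.vmap the per-agent loss over the leading axis of the batch. A minimal self-contained sketch of that pattern (the toy loss and names here are illustrative, not from this tree):

import jax
import jax.numpy as jnp

def per_agent_loss(rng, data):
    # Toy stand-in for compute_loss_q: a noise-perturbed squared error.
    noise = jax.random.normal(rng, data.shape)
    return jnp.mean((data + 0.01 * noise) ** 2)

rng = jax.random.PRNGKey(0)
num_agents = 4
data = jnp.ones((num_agents, 8))  # one batch row per agent
losses = jax.vmap(per_agent_loss)(jax.random.split(rng, num_agents), data)
print(losses.mean())

Under jit, vmap turns the per-agent losses into one batched computation rather than a Python loop, which is what keeps the num_agents dimension cheap.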
@@ -253,16 +269,11 @@
state = state.q_apply_gradients(step=step, grads=q_grads)
- update_rng, pi_grad_rng = jax.random.split(update_rng)
-
# Update pi
+ update_rng, pi_grad_rng = jax.random.split(update_rng)
pi_grad_fn = jax.value_and_grad(lambda params: compute_batched_loss_pi(
- state, pi_grad_rng, params, action_data))
+ state, pi_grad_rng, params, data))
pi_loss, pi_grads = pi_grad_fn(state.params)
-
- print_nan(step, pi_loss)
- print_nan(step, pi_grads)
-
state = state.pi_apply_gradients(step=step, grads=pi_grads)
update_rng, alpha_grad_rng = jax.random.split(update_rng)
@@ -271,15 +282,18 @@
# Update alpha
alpha_grad_fn = jax.value_and_grad(
lambda params: compute_batched_loss_alpha(state, alpha_grad_rng,
- params, data))
- alpha_loss, alpha_grads = alpha_grad_fn(state.params)
+ params, data),
+ has_aux=True,
+ )
+ (alpha_loss, entropy), alpha_grads = alpha_grad_fn(state.params)
print_nan(step, alpha_loss)
print_nan(step, alpha_grads)
state = state.alpha_apply_gradients(step=step, grads=alpha_grads)
else:
+ entropy = 0.0
alpha_loss = 0.0
- return state, q_loss, pi_loss, alpha_loss
+ return state, q_loss, pi_loss, alpha_loss, entropy
@jax.jit
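The alpha update above switches jax.value_and_grad to has_aux=True so the entropy estimate can be returned alongside the loss without a second forward pass. A minimal sketch of the calling convention (toy function, illustrative names):

import jax
import jax.numpy as jnp

def loss_with_aux(params, x):
    loss = jnp.sum((params * x) ** 2)
    return loss, jnp.mean(x)  # auxiliary value carried next to the loss

grad_fn = jax.value_and_grad(loss_with_aux, has_aux=True)
(loss, aux), grads = grad_fn(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))

The ((value, aux), grads) structure is exactly why the hunk above unpacks (alpha_loss, entropy), alpha_grads.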
@@ -304,20 +318,22 @@
def true_fn(i):
# We are at the beginning of the process, pick a random action.
- return state.problem.random_actions(action_rng, FLAGS.num_agents)
+ return state.problem.random_actions(action_rng,
+ X=observation,
+ goal=R,
+ dimensions=FLAGS.num_agents)
def false_fn(i):
# We are past the beginning of the process, use the trained network.
- pi_action, logp_pi, std, random_sample = state.pi_apply(
- rng=action_rng,
- params=state.params,
- observation=observation,
- R=R,
- deterministic=False)
+ pi_action, _, _ = state.pi_apply(rng=action_rng,
+ params=state.params,
+ observation=observation,
+ R=R,
+ deterministic=False)
return pi_action
pi_action = jax.lax.cond(
- step <= FLAGS.start_steps,
+ step <= FLAGS.random_sample_steps,
true_fn,
false_fn,
i,
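Both branches of a jax.lax.cond must be traceable and must return values with identical shapes and dtypes, which is why true_fn and false_fn above are shaped to each return only the action. A standalone sketch of the same select-by-step pattern (toy shapes, illustrative names):

import jax
import jax.numpy as jnp

def select_action(step, rng, threshold=100):
    # Both branches take the same operand and return a float (2,) array.
    return jax.lax.cond(
        step <= threshold,
        lambda r: jax.random.uniform(r, (2,), minval=-1.0, maxval=1.0),
        lambda r: jnp.zeros((2,)),  # stand-in for the trained policy's action
        rng,
    )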
@@ -328,11 +344,9 @@
lambda o, pi: state.problem.integrate_dynamics(o, pi),
in_axes=(0, 0))(observation, pi_action)
- # Soft Actor-Critic is designed to maximize reward. LQR minimizes
- # cost. There is nothing which assumes anything about the sign of
- # the reward, so use the negative of the cost.
- reward = -jax.vmap(state.problem.cost)(
- X=observation2, U=pi_action, goal=R)
+ reward = jax.vmap(state.problem.reward)(X=observation2,
+ U=pi_action,
+ goal=R)
replay_buffer_state = state.replay_buffer.add(
replay_buffer_state, {
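The comment removed above explained the original sign convention: SAC maximizes reward while LQR minimizes cost, so the cost was negated. That negation now presumably lives inside problem.reward; a hypothetical wrapper consistent with the old behavior (the body here is an assumption, not taken from this tree):

def reward(self, X, U, goal):
    # Hypothetical: negate the LQR cost so SAC, which maximizes
    # expected return, can optimize it directly.
    return -self.cost(X=X, U=U, goal=goal)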
@@ -357,10 +371,8 @@
step: int):
rng, sample_rng = jax.random.split(rng)
- action_data = state.replay_buffer.sample(replay_buffer_state, sample_rng)
-
def update_iteration(i, val):
- rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = val
+ rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = val
rng, sample_rng, update_rng = jax.random.split(rng, 3)
batch = state.replay_buffer.sample(replay_buffer_state, sample_rng)
@@ -368,22 +380,18 @@
print_nan(i, replay_buffer_state)
print_nan(i, batch)
- state, q_loss, pi_loss, alpha_loss = train_step(
- state,
- data=batch.experience,
- action_data=batch.experience,
- update_rng=update_rng,
- step=i)
+ state, q_loss, pi_loss, alpha_loss, entropy = train_step(
+ state, data=batch.experience, update_rng=update_rng, step=i)
- return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data
+ return rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy
- rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, action_data = jax.lax.fori_loop(
+ rng, state, q_loss, pi_loss, alpha_loss, replay_buffer_state, entropy = jax.lax.fori_loop(
step, step + FLAGS.horizon + 1, update_iteration,
- (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, action_data))
+ (rng, state, 0.0, 0.0, 0.0, replay_buffer_state, 0.0))
state = state.target_apply_gradients(step=state.step)
- return rng, state, q_loss, pi_loss, alpha_loss
+ return rng, state, q_loss, pi_loss, alpha_loss, entropy
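update_iteration runs under jax.lax.fori_loop, so everything that changes across iterations (the RNG, the train state, the logged losses, and the entropy added here) must be threaded through the carried tuple, and the initial carry has to match the types the body returns. A minimal sketch of that carry discipline (toy body, illustrative names):

import jax
import jax.numpy as jnp

def body(i, val):
    rng, total = val
    rng, sub = jax.random.split(rng)
    # Stand-in for one update step; accumulate a per-iteration value.
    return rng, total + jax.random.uniform(sub)

rng, total = jax.lax.fori_loop(0, 10, body,
                               (jax.random.PRNGKey(0), jnp.zeros(())))

This type-matching requirement is also why the initial entropy in the carry should be a float, matching what train_step returns.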
def train(workdir: str, problem: Problem) -> train_state.TrainState:
@@ -396,6 +404,8 @@
run['hparams'] = {
'q_learning_rate': FLAGS.q_learning_rate,
'pi_learning_rate': FLAGS.pi_learning_rate,
+ 'alpha_learning_rate': FLAGS.alpha_learning_rate,
+ 'random_sample_steps': FLAGS.random_sample_steps,
'batch_size': FLAGS.batch_size,
'horizon': FLAGS.horizon,
'warmup_steps': FLAGS.warmup_steps,
@@ -422,11 +432,11 @@
problem,
q_learning_rate=q_learning_rate,
pi_learning_rate=pi_learning_rate,
+ alpha_learning_rate=FLAGS.alpha_learning_rate,
)
state = restore_checkpoint(state, workdir)
- state_sharding = nn.get_sharding(state, state.mesh)
- logging.info(state_sharding)
+ logging.debug(nn.get_sharding(state, state.mesh))
replay_buffer_state = state.replay_buffer.init({
'observations1':
@@ -438,12 +448,10 @@
'rewards':
jax.numpy.zeros((1, )),
'goals':
- jax.numpy.zeros((problem.num_states, )),
+ jax.numpy.zeros((problem.num_goals, )),
})
- replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
- state.mesh)
- logging.info(replay_buffer_state_sharding)
+ logging.debug(nn.get_sharding(replay_buffer_state, state.mesh))
# Number of gradients to accumulate before doing descent.
update_after = FLAGS.batch_size // FLAGS.num_agents
@@ -461,24 +469,25 @@
)
def nop(rng, state, replay_buffer_state, step):
- return rng, state.update_step(step=step), 0.0, 0.0, 0.0
+ return rng, state.update_step(step=step), 0.0, 0.0, 0.0, 0.0
# Train
- rng, state, q_loss, pi_loss, alpha_loss = jax.lax.cond(
+ rng, state, q_loss, pi_loss, alpha_loss, entropy = jax.lax.cond(
step >= update_after, update_gradients, nop, rng, state,
replay_buffer_state, step)
- return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss
+ return state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy
+ last_time = time.time()
for step in range(0, FLAGS.steps, FLAGS.horizon):
- state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss = train_loop(
+ state, replay_buffer_state, rng, q_loss, pi_loss, alpha_loss, entropy = train_loop(
state, replay_buffer_state, rng, step)
if FLAGS.debug_nan and has_nan(state.params):
logging.fatal('Nan params, aborting')
logging.info(
- 'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s',
+ 'Step %s: q_loss=%s, pi_loss=%s, alpha_loss=%s, q_learning_rate=%s, pi_learning_rate=%s, alpha=%s, entropy=%s, random=%s',
step,
q_loss,
pi_loss,
@@ -486,6 +495,8 @@
q_learning_rate(step),
pi_learning_rate(step),
jax.numpy.exp(state.params['logalpha']),
+ entropy,
+ step <= FLAGS.random_sample_steps,
)
run.track(
@@ -493,12 +504,14 @@
'q_loss': float(q_loss),
'pi_loss': float(pi_loss),
'alpha_loss': float(alpha_loss),
- 'alpha': float(jax.numpy.exp(state.params['logalpha']))
+ 'alpha': float(jax.numpy.exp(state.params['logalpha'])),
+ 'entropy': float(entropy),
},
step=step)
- if step % 1000 == 0 and step > update_after:
+ if time.time() > last_time + 3.0 and step > update_after:
# TODO(austin): Simulate a rollout and accumulate the reward. How good are we doing?
save_checkpoint(state, workdir)
+ last_time = time.time()
return state
diff --git a/frc971/imu_fdcan/README.md b/frc971/imu_fdcan/README.md
index 8e82f73..7bf7328 100644
--- a/frc971/imu_fdcan/README.md
+++ b/frc971/imu_fdcan/README.md
@@ -23,7 +23,8 @@
* The main code lives in [`Dual_IMU/Core/Src`](/Dual_IMU/Core/Src/). Make sure your changes happen inside sections marked `/* USER CODE BEGIN ... */` `/* USER CODE END ... */`. Code outside these markers will be overwritten by CubeIDE when generating code after changes to the `.ioc` file.
3) Build + Run:
* Option 1: Open CubeIDE GUI to build, debug, or run.
- * Option 2:
+ <!-- TODO(sindy): fix this build script -->
+ * Option 2 (DO NOT USE. NOT SAFE):
1) SSH onto the build server.
2) Run `bazel build -c opt --config=cortex-m4f-imu //frc971/imu_fdcan/Dual_IMU/Core:main.elf`. The output .elf file should be in bazel-bin/frc971/imu_fdcan/Dual_IMU/Core.
3) (If deploying code locally) Move the file to a local directory. For example: `scp <username>@build.frc971.org:<path/to/main.elf> <local/path/to/save/file>`. A good spot to put this locally is ./Dual_IMU/Debug/.
diff --git a/frc971/orin/set_orin_clock.sh b/frc971/orin/set_orin_clock.sh
index bc9dd8f..fd615c5 100755
--- a/frc971/orin/set_orin_clock.sh
+++ b/frc971/orin/set_orin_clock.sh
@@ -12,7 +12,7 @@
for orin in $ORIN_LIST; do
echo "========================================================"
- echo "Setting clock for ${ROBOT_PREFIX}71.10${orin}"
+ echo "Setting clock for 10.${ROBOT_PREFIX}.71.10${orin}"
echo "========================================================"
current_time=`sudo hwclock`
IFS="."
diff --git a/frc971/vision/image_logger.cc b/frc971/vision/image_logger.cc
index cb8fc4e..01ceaf4 100644
--- a/frc971/vision/image_logger.cc
+++ b/frc971/vision/image_logger.cc
@@ -75,6 +75,8 @@
});
}
+ LOG(INFO) << "Starting image_logger; will wait until the joystick is "
+ "enabled to start logging";
event_loop.OnRun([]() {
errno = 0;
setpriority(PRIO_PROCESS, 0, -20);
diff --git a/frc971/wpilib/talonfx.cc b/frc971/wpilib/talonfx.cc
index 4338cfd..0da7603 100644
--- a/frc971/wpilib/talonfx.cc
+++ b/frc971/wpilib/talonfx.cc
@@ -58,11 +58,12 @@
void TalonFX::WriteConfigs() {
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.StatorCurrentLimit = stator_current_limit_;
+ current_limits.StatorCurrentLimit =
+ units::current::ampere_t{stator_current_limit_};
current_limits.StatorCurrentLimitEnable = true;
- current_limits.SupplyCurrentLimit = supply_current_limit_;
+ current_limits.SupplyCurrentLimit =
+ units::current::ampere_t{supply_current_limit_};
current_limits.SupplyCurrentLimitEnable = true;
- current_limits.SupplyTimeThreshold = 0.0;
ctre::phoenix6::configs::MotorOutputConfigs output_configs;
output_configs.NeutralMode = neutral_mode_;
diff --git a/frc971/zeroing/continuous_absolute_encoder.cc b/frc971/zeroing/continuous_absolute_encoder.cc
index 3381833..733d46d 100644
--- a/frc971/zeroing/continuous_absolute_encoder.cc
+++ b/frc971/zeroing/continuous_absolute_encoder.cc
@@ -165,4 +165,14 @@
return builder.Finish();
}
+void ContinuousAbsoluteEncoderZeroingEstimator::GetEstimatorState(
+ AbsoluteEncoderEstimatorStateStatic *fbs) const {
+ errors_.ToStaticFlatbuffer(fbs->add_errors());
+
+ fbs->set_error(error_);
+ fbs->set_zeroed(zeroed_);
+ fbs->set_position(position_);
+ fbs->set_absolute_position(filtered_absolute_encoder_);
+}
+
} // namespace frc971::zeroing
diff --git a/frc971/zeroing/continuous_absolute_encoder.h b/frc971/zeroing/continuous_absolute_encoder.h
index 4994280..5e700ee 100644
--- a/frc971/zeroing/continuous_absolute_encoder.h
+++ b/frc971/zeroing/continuous_absolute_encoder.h
@@ -47,6 +47,8 @@
virtual flatbuffers::Offset<State> GetEstimatorState(
flatbuffers::FlatBufferBuilder *fbb) const override;
+ void GetEstimatorState(AbsoluteEncoderEstimatorStateStatic *fbs) const;
+
private:
struct PositionStruct {
PositionStruct(const AbsolutePosition &position_buffer)
diff --git a/frc971/zeroing/zeroing.h b/frc971/zeroing/zeroing.h
index 68ef38f..01c2c61 100644
--- a/frc971/zeroing/zeroing.h
+++ b/frc971/zeroing/zeroing.h
@@ -10,7 +10,7 @@
#include "flatbuffers/flatbuffers.h"
#include "frc971/constants.h"
-#include "frc971/control_loops/control_loops_generated.h"
+#include "frc971/control_loops/control_loops_static.h"
// TODO(pschrader): Flag an error if encoder index pulse is not n revolutions
// away from the last one (i.e. got extra counts from noise, etc..)
diff --git a/scouting/scouting_test.cy.js b/scouting/scouting_test.cy.js
index 990e69f..7157880 100644
--- a/scouting/scouting_test.cy.js
+++ b/scouting/scouting_test.cy.js
@@ -119,7 +119,7 @@
' Ended Match; stageType: kHARMONY, trapNote: false, spotlight: false '
);
// Ensure that the penalties action is only submitted once.
- cy.get('#review_data li').contains('Penalties').should('have.length', 1);
+ cy.get('#review_data li:contains("Penalties")').its('length').should('eq', 1);
clickButton('Submit');
headerShouldBe(teamNumber + ' Success ');
diff --git a/y2020/wpilib_interface.cc b/y2020/wpilib_interface.cc
index 2213c16..8c8b480 100644
--- a/y2020/wpilib_interface.cc
+++ b/y2020/wpilib_interface.cc
@@ -123,9 +123,11 @@
void WriteConfigs(ctre::phoenix6::hardware::TalonFX *talon,
double stator_current_limit, double supply_current_limit) {
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.StatorCurrentLimit = stator_current_limit;
+ current_limits.StatorCurrentLimit =
+ units::current::ampere_t{stator_current_limit};
current_limits.StatorCurrentLimitEnable = true;
- current_limits.SupplyCurrentLimit = supply_current_limit;
+ current_limits.SupplyCurrentLimit =
+ units::current::ampere_t{supply_current_limit};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::TalonFXConfiguration configuration;
@@ -454,7 +456,8 @@
::std::unique_ptr<::ctre::phoenix6::hardware::TalonFX> t) {
climber_falcon_ = ::std::move(t);
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.SupplyCurrentLimit = Values::kClimberSupplyCurrentLimit();
+ current_limits.SupplyCurrentLimit =
+ units::current::ampere_t{Values::kClimberSupplyCurrentLimit()};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::TalonFXConfiguration configuration;
diff --git a/y2021_bot3/wpilib_interface.cc b/y2021_bot3/wpilib_interface.cc
index 4006736..f3c956e 100644
--- a/y2021_bot3/wpilib_interface.cc
+++ b/y2021_bot3/wpilib_interface.cc
@@ -119,9 +119,11 @@
void WriteConfigs(ctre::phoenix6::hardware::TalonFX *talon,
double stator_current_limit, double supply_current_limit) {
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.StatorCurrentLimit = stator_current_limit;
+ current_limits.StatorCurrentLimit =
+ units::current::ampere_t{stator_current_limit};
current_limits.StatorCurrentLimitEnable = true;
- current_limits.SupplyCurrentLimit = supply_current_limit;
+ current_limits.SupplyCurrentLimit =
+ units::current::ampere_t{supply_current_limit};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::TalonFXConfiguration configuration;
@@ -249,7 +251,8 @@
::std::unique_ptr<::ctre::phoenix6::hardware::TalonFX> t) {
climber_falcon_ = ::std::move(t);
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.SupplyCurrentLimit = Values::kClimberSupplyCurrentLimit();
+ current_limits.SupplyCurrentLimit =
+ units::current::ampere_t{Values::kClimberSupplyCurrentLimit()};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::TalonFXConfiguration configuration;
diff --git a/y2022/wpilib_interface.cc b/y2022/wpilib_interface.cc
index 1f82e91..469c93b 100644
--- a/y2022/wpilib_interface.cc
+++ b/y2022/wpilib_interface.cc
@@ -129,9 +129,11 @@
void WriteConfigs(ctre::phoenix6::hardware::TalonFX *talon,
double stator_current_limit, double supply_current_limit) {
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.StatorCurrentLimit = stator_current_limit;
+ current_limits.StatorCurrentLimit =
+ units::current::ampere_t{stator_current_limit};
current_limits.StatorCurrentLimitEnable = true;
- current_limits.SupplyCurrentLimit = supply_current_limit;
+ current_limits.SupplyCurrentLimit =
+ units::current::ampere_t{supply_current_limit};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::TalonFXConfiguration configuration;
@@ -483,10 +485,10 @@
for (auto &falcon : {catapult_falcon_1_can_, catapult_falcon_2_can_}) {
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
current_limits.StatorCurrentLimit =
- Values::kIntakeRollerStatorCurrentLimit();
+ units::current::ampere_t{Values::kIntakeRollerStatorCurrentLimit()};
current_limits.StatorCurrentLimitEnable = true;
current_limits.SupplyCurrentLimit =
- Values::kIntakeRollerSupplyCurrentLimit();
+ units::current::ampere_t{Values::kIntakeRollerSupplyCurrentLimit()};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::TalonFXConfiguration configuration;
diff --git a/y2023/wpilib_interface.cc b/y2023/wpilib_interface.cc
index bc14311..536f7bb 100644
--- a/y2023/wpilib_interface.cc
+++ b/y2023/wpilib_interface.cc
@@ -166,11 +166,11 @@
inverted_ = invert;
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.StatorCurrentLimit =
- constants::Values::kDrivetrainStatorCurrentLimit();
+ current_limits.StatorCurrentLimit = units::current::ampere_t{
+ constants::Values::kDrivetrainStatorCurrentLimit()};
current_limits.StatorCurrentLimitEnable = true;
- current_limits.SupplyCurrentLimit =
- constants::Values::kDrivetrainSupplyCurrentLimit();
+ current_limits.SupplyCurrentLimit = units::current::ampere_t{
+ constants::Values::kDrivetrainSupplyCurrentLimit()};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::MotorOutputConfigs output_configs;
@@ -196,11 +196,11 @@
void WriteRollerConfigs() {
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.StatorCurrentLimit =
- constants::Values::kRollerStatorCurrentLimit();
+ current_limits.StatorCurrentLimit = units::current::ampere_t{
+ constants::Values::kRollerStatorCurrentLimit()};
current_limits.StatorCurrentLimitEnable = true;
- current_limits.SupplyCurrentLimit =
- constants::Values::kRollerSupplyCurrentLimit();
+ current_limits.SupplyCurrentLimit = units::current::ampere_t{
+ constants::Values::kRollerSupplyCurrentLimit()};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::MotorOutputConfigs output_configs;
diff --git a/y2023_bot3/wpilib_interface.cc b/y2023_bot3/wpilib_interface.cc
index d8813be..0e43e10 100644
--- a/y2023_bot3/wpilib_interface.cc
+++ b/y2023_bot3/wpilib_interface.cc
@@ -147,11 +147,11 @@
inverted_ = invert;
ctre::phoenix6::configs::CurrentLimitsConfigs current_limits;
- current_limits.StatorCurrentLimit =
- constants::Values::kDrivetrainStatorCurrentLimit();
+ current_limits.StatorCurrentLimit = units::current::ampere_t{
+ constants::Values::kDrivetrainStatorCurrentLimit()};
current_limits.StatorCurrentLimitEnable = true;
- current_limits.SupplyCurrentLimit =
- constants::Values::kDrivetrainSupplyCurrentLimit();
+ current_limits.SupplyCurrentLimit = units::current::ampere_t{
+ constants::Values::kDrivetrainSupplyCurrentLimit()};
current_limits.SupplyCurrentLimitEnable = true;
ctre::phoenix6::configs::MotorOutputConfigs output_configs;
diff --git a/y2024/vision/viewer.cc b/y2024/vision/viewer.cc
index c72c08e..c741445 100644
--- a/y2024/vision/viewer.cc
+++ b/y2024/vision/viewer.cc
@@ -86,7 +86,9 @@
frc971::constants::ConstantsFetcher<y2024::Constants> constants_fetcher(
&event_loop);
- CHECK(absl::GetFlag(FLAGS_channel).length() == 8);
+ CHECK(absl::GetFlag(FLAGS_channel).length() == 8)
+ << " channel should be of the form '/cameraN' for viewing images from "
+ "camera N";
int camera_id = std::stoi(absl::GetFlag(FLAGS_channel).substr(7, 1));
const auto *calibration_data = FindCameraCalibration(
constants_fetcher.constants(), event_loop.node()->name()->string_view(),
diff --git a/y2024_swerve/y2024_swerve_roborio.json b/y2024_swerve/y2024_swerve_roborio.json
index 705370d..adf528d 100644
--- a/y2024_swerve/y2024_swerve_roborio.json
+++ b/y2024_swerve/y2024_swerve_roborio.json
@@ -305,8 +305,7 @@
"name": "wpilib_interface",
"executable_name": "wpilib_interface",
"args": [
- "--nodie_on_malloc",
- "--ctre_diag_server"
+ "--nodie_on_malloc"
],
"nodes": [
"roborio"