Prep lqr_plot for plotting swerve
There were a lot of constants that were duplicated between lqr_plot and the
physics; it makes more sense to pull them into the physics class to keep the
logic cleaner.
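
The shape of the change is roughly the following (a minimal sketch;
"TurretProblem" is a stand-in name for the physics problem class, and
controls.dlqr is used with the same arguments as in physics.py):

    import numpy
    from frc971.control_loops.python import controls

    class TurretProblem:
        """Owns the turret model constants so plot scripts don't duplicate them."""

        def __init__(self):
            self.dt = 0.005
            # Discrete-time turret dynamics.
            self.A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
            self.B = numpy.matrix([[0.00065992], [0.25610763]])
            self.Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
            self.R = numpy.matrix([[0.00694444]])
            # Optimal LQR gain F and cost-to-go P, computed once here.
            self.F, self.P = controls.dlqr(self.A, self.B, self.Q, self.R,
                                           optimal_cost_function=True)

    # lqr_plot.py then pulls everything off the problem instead of redefining it:
    problem = TurretProblem()
    goal = numpy.matrix([[0.], [0.]])
    X_lqr = numpy.matrix([[1.], [0.]])
    U_lqr = problem.F @ (goal - X_lqr)
    X_lqr = problem.A @ X_lqr + problem.B @ U_lqr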
Change-Id: I04cede2008dae1035de2f820a3d588e07c1573dd
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 8df956d..bd6674c 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -75,18 +75,6 @@
return checkpoints.restore_checkpoint(workdir, state)
-# Hard-coded simulation parameters for the turret.
-dt = 0.005
-A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
-B = numpy.matrix([[0.00065992], [0.25610763]])
-
-Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
-R = numpy.matrix([[0.00694444]])
-
-# Compute the optimal LQR cost + controller.
-F, P = controls.dlqr(A, B, Q, R, optimal_cost_function=True)
-
-
def generate_data(step=None):
grid_X = numpy.arange(-1, 1, 0.1)
grid_Y = numpy.arange(-10, 10, 0.1)
@@ -106,9 +94,6 @@
if step is not None and state.step == step:
return None
- print('F:', F)
- print('P:', P)
-
X = jax.numpy.array([1.0, 0.0])
X_lqr = X.copy()
goal = jax.numpy.array([0.0, 0.0])
@@ -136,7 +121,7 @@
def lqr_cost(X, Y):
x = jax.numpy.array([X, Y]) - goal
- return -x.T @ jax.numpy.array(P) @ x
+ return -x.T @ jax.numpy.array(problem.P) @ x
def compute_q(params, x, y):
X = state.problem.unwrap_angles(jax.numpy.array([x, y]))
@@ -167,7 +152,7 @@
# Now compute the two controller surfaces.
def compute_lqr_U(X, Y):
x = jax.numpy.array([X, Y])
- return (-jax.numpy.array(F.reshape((2, ))) @ x)[0]
+ return (-jax.numpy.array(problem.F.reshape((2, ))) @ x)[0]
def compute_pi_U(X, Y):
x = jax.numpy.array([X, Y])
@@ -186,14 +171,14 @@
# Now simulate the robot, accumulating up things to plot.
def loop(i, val):
X, X_lqr, data, params = val
- t = data.t.at[i].set(i * dt)
+ t = data.t.at[i].set(i * problem.dt)
U, _, _, _ = state.pi_apply(rng,
params,
observation=state.problem.unwrap_angles(X),
R=goal,
deterministic=True)
- U_lqr = F @ (goal - X_lqr)
+ U_lqr = problem.F @ (goal - X_lqr)
cost = jax.numpy.minimum(
state.q1_apply(params,
@@ -212,8 +197,8 @@
cost_plot = data.cost.at[i, :].set(cost)
cost_lqr_plot = data.cost_lqr.at[i, :].set(lqr_cost(*X_lqr))
- X = A @ X + B @ U
- X_lqr = A @ X_lqr + B @ U_lqr
+ X = problem.A @ X + problem.B @ U
+ X_lqr = problem.A @ X_lqr + problem.B @ U_lqr
reward = data.reward - state.problem.cost(X, U, goal)
reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index 450b692..d6515fe 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -1,8 +1,8 @@
-import jax
+import jax, numpy
from functools import partial
-from frc971.control_loops.swerve import dynamics
from absl import logging
-from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve import dynamics, jax_dynamics
+from frc971.control_loops.python import controls
from flax.typing import PRNGKey
@@ -15,11 +15,12 @@
self.num_outputs = num_outputs
self.num_goals = num_goals
self.action_limit = action_limit
+ self.dt = 0.005
def integrate_dynamics(self, X: jax.typing.ArrayLike,
U: jax.typing.ArrayLike):
m = 2 # RK4 steps per interval
- dt = 0.005 / m
+ dt = self.dt / m
def iteration(i, X):
weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0],
@@ -76,19 +77,29 @@
num_outputs=1,
num_goals=2,
action_limit=30.0)
+ self.A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
+ self.B = numpy.matrix([[0.00065992], [0.25610763]])
+
+ self.Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
+ self.R = numpy.matrix([[0.00694444]])
+
+ # Compute the optimal LQR cost + controller.
+ self.F, self.P = controls.dlqr(self.A,
+ self.B,
+ self.Q,
+ self.R,
+ optimal_cost_function=True)
def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
- A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
- B = jax.numpy.array([[0.], [56.08534375]])
+ A_continuous = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+ B_continuous = jax.numpy.array([[0.], [56.08534375]])
- return A @ X + B @ U
+ return A_continuous @ X + B_continuous @ U
def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
goal: jax.typing.ArrayLike):
- Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
- R = jax.numpy.array([[0.00694444]])
-
- return (X - goal).T @ Q @ (X - goal) + U.T @ R @ U
+ return (X - goal).T @ jax.numpy.array(
+ self.Q) @ (X - goal) + U.T @ jax.numpy.array(self.R) @ U
def random_states(self, rng: PRNGKey, dimensions=None):
rng1, rng2 = jax.random.split(rng)