Provide a goal to our SAC implementation
Plumb a goal R through the policy and both Q networks, sample random
goals while collecting experience, and make the quadratic cost penalize
(X - goal) instead of X. This sets us up to go to multiple destinations
nicely instead of only stabilizing at the origin.
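
For reference, a minimal sketch of the pattern this introduces: the
goal is concatenated onto the network input, so the policy becomes
pi(observation, R) and the critics become Q(observation, R, action).
The class name, layer sizes, and shapes below are illustrative only,
not the actual networks in model.py:

import jax
import flax.linen as nn


class GoalConditionedQ(nn.Module):
    """Toy goal-conditioned critic for illustration."""
    hidden_sizes: tuple[int, ...] = (32, 32)

    @nn.compact
    def __call__(self, observation, R, action):
        # The goal is just another input feature next to the
        # observation and action.
        x = jax.numpy.hstack((observation, R, action))
        for i, hidden_size in enumerate(self.hidden_sizes):
            x = nn.Dense(name=f'denselayer{i}', features=hidden_size)(x)
            x = nn.activation.tanh(x)
        return nn.Dense(name='q', features=1)(x)


q = GoalConditionedQ()
variables = q.init(
    jax.random.PRNGKey(0),
    observation=jax.numpy.ones([2]),
    R=jax.numpy.ones([2]),  # sized like problem.num_goals
    action=jax.numpy.ones([1]),
)
# Q for a state at the origin and a goal of 0.1 position, 0 velocity.
print(
    q.apply(variables, jax.numpy.zeros([2]), jax.numpy.array([0.1, 0.]),
            jax.numpy.zeros([1])))

The actor does the same thing minus the action input, and
random_goals() supplies the R values stored in the replay buffer
alongside each transition.
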
Change-Id: Ie54cd037b710776fe8ae14ae395a9776d37f0296
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index f8e9137..8df956d 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -122,6 +122,7 @@
return state.q1_apply(
state.params,
observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+ R=goal,
action=jax.numpy.array([0.]),
)[0]
@@ -129,11 +130,12 @@
return state.q2_apply(
state.params,
observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+ R=goal,
action=jax.numpy.array([0.]),
)[0]
def lqr_cost(X, Y):
- x = jax.numpy.array([X, Y])
+ x = jax.numpy.array([X, Y]) - goal
return -x.T @ jax.numpy.array(P) @ x
def compute_q(params, x, y):
@@ -143,11 +145,13 @@
state.q1_apply(
params,
observation=X,
+ R=goal,
action=jax.numpy.array([0.]),
)[0],
state.q2_apply(
params,
observation=X,
+ R=goal,
action=jax.numpy.array([0.]),
)[0])
@@ -160,10 +164,6 @@
grid_Y)
lqr_cost_grid = jax.vmap(jax.vmap(lqr_cost))(grid_X, grid_Y)
- # TODO(austin): Stuff to figure out:
- # 3: Make it converge faster. Use both GPUs better?
- # 4: Can we feed in a reference position and get it to learn to stabilize there?
-
# Now compute the two controller surfaces.
def compute_lqr_U(X, Y):
x = jax.numpy.array([X, Y])
@@ -174,6 +174,7 @@
U, _, _, _ = state.pi_apply(rng,
state.params,
observation=state.problem.unwrap_angles(x),
+ R=goal,
deterministic=True)
return U[0]
@@ -190,15 +191,18 @@
U, _, _, _ = state.pi_apply(rng,
params,
observation=state.problem.unwrap_angles(X),
+ R=goal,
deterministic=True)
U_lqr = F @ (goal - X_lqr)
cost = jax.numpy.minimum(
state.q1_apply(params,
observation=state.problem.unwrap_angles(X),
+ R=goal,
action=U),
state.q2_apply(params,
observation=state.problem.unwrap_angles(X),
+ R=goal,
action=U))
U_plot = data.U.at[i, :].set(U)
@@ -289,24 +293,32 @@
self.fig0, self.axs0 = pylab.subplots(3)
self.fig0.supxlabel('Seconds')
- self.x, = self.axs0[0].plot([], [], label="x")
- self.v, = self.axs0[0].plot([], [], label="v")
- self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
- self.v_lqr, = self.axs0[0].plot([], [], label="v_lqr")
+ self.axs_velocity = self.axs0[0].twinx()
- self.axs0[0].set_ylabel('Velocity')
+ self.x, = self.axs0[0].plot([], [], label="x")
+ self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
+
+ self.axs0[0].set_ylabel('Position')
self.axs0[0].legend()
+ self.axs0[0].grid()
+
+ self.v, = self.axs_velocity.plot([], [], label="v", color='C2')
+ self.v_lqr, = self.axs_velocity.plot([], [], label="v_lqr", color='C3')
+ self.axs_velocity.set_ylabel('Velocity')
+ self.axs_velocity.legend()
self.uaxis, = self.axs0[1].plot([], [], label="U")
self.uaxis_lqr, = self.axs0[1].plot([], [], label="U_lqr")
self.axs0[1].set_ylabel('Amps')
self.axs0[1].legend()
+ self.axs0[1].grid()
self.costaxis, = self.axs0[2].plot([], [], label="cost")
self.costlqraxis, = self.axs0[2].plot([], [], label="cost lqr")
self.axs0[2].set_ylabel('Cost')
self.axs0[2].legend()
+ self.axs0[2].grid()
self.costfig = pyplot.figure(figsize=pyplot.figaspect(0.5))
self.cost3dax = [
@@ -349,6 +361,9 @@
self.axs0[0].relim()
self.axs0[0].autoscale_view()
+ self.axs_velocity.relim()
+ self.axs_velocity.autoscale_view()
+
self.axs0[1].relim()
self.axs0[1].autoscale_view()
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index cb6d9ba..0d8a410 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -97,9 +97,10 @@
@nn.compact
def __call__(self,
observation: ArrayLike,
+ R: ArrayLike,
deterministic: bool = False,
rng: PRNGKey | None = None):
- x = observation
+ x = jax.numpy.hstack((observation, R))
# Apply the dense layers
for i, hidden_size in enumerate(self.hidden_sizes):
x = nn.Dense(
@@ -168,9 +169,10 @@
activation: Callable = nn.activation.tanh
@nn.compact
- def __call__(self, observation: ArrayLike, action: ArrayLike):
+ def __call__(self, observation: ArrayLike, R: ArrayLike,
+ action: ArrayLike):
# Estimate Q with a simple multi layer dense network.
- x = jax.numpy.hstack((observation, action))
+ x = jax.numpy.hstack((observation, R, action))
for i, hidden_size in enumerate(self.hidden_sizes):
x = nn.Dense(
name=f'denselayer{i}',
@@ -222,25 +224,29 @@
rng: PRNGKey,
params: flax.core.FrozenDict[str, typing.Any],
observation: ArrayLike,
+ R: ArrayLike,
deterministic: bool = False):
return self.pi_apply_fn(
{'params': params['pi']},
observation=self.problem.unwrap_angles(observation),
+ R=R,
deterministic=deterministic,
rngs={'pi': rng})
def q1_apply(self, params: flax.core.FrozenDict[str, typing.Any],
- observation: ArrayLike, action: ArrayLike):
+ observation: ArrayLike, R: ArrayLike, action: ArrayLike):
return self.q1_apply_fn(
{'params': params['q1']},
observation=self.problem.unwrap_angles(observation),
+ R=R,
action=action)
def q2_apply(self, params: flax.core.FrozenDict[str, typing.Any],
- observation: ArrayLike, action: ArrayLike):
+ observation: ArrayLike, R: ArrayLike, action: ArrayLike):
return self.q2_apply_fn(
{'params': params['q2']},
observation=self.problem.unwrap_angles(observation),
+ R=R,
action=action)
def pi_apply_gradients(self, step: int, grads):
@@ -389,15 +395,18 @@
pi_params = pi.init(
pi_rng,
observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ R=jax.numpy.ones([problem.num_goals]),
)['params']
q1_params = q1.init(
q1_rng,
observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ R=jax.numpy.ones([problem.num_goals]),
action=jax.numpy.ones([problem.num_outputs]),
)['params']
q2_params = q2.init(
q2_rng,
observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ R=jax.numpy.ones([problem.num_goals]),
action=jax.numpy.ones([problem.num_outputs]),
)['params']
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index d7bf7b4..450b692 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -44,20 +44,16 @@
def unwrap_angles(self, X: jax.typing.ArrayLike):
return X
- @partial(jax.jit, static_argnums=[0])
def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
raise NotImplemented("xdot not implemented")
- @partial(jax.jit, static_argnums=[0])
def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
goal: jax.typing.ArrayLike):
raise NotImplemented("cost not implemented")
- @partial(jax.jit, static_argnums=[0])
def random_states(self, rng: PRNGKey, dimensions=None):
raise NotImplemented("random_states not implemented")
- #@partial(jax.jit, static_argnums=[0])
def random_actions(self, rng: PRNGKey, dimensions=None):
"""Produces a uniformly random action in the action space."""
return jax.random.uniform(
@@ -67,6 +63,10 @@
maxval=self.action_limit,
)
+ def random_goals(self, rng: PRNGKey, dimensions=None):
+ """Produces a random goal in the goal space."""
+ raise NotImplemented("random_goals not implemented")
+
class TurretProblem(Problem):
@@ -77,22 +77,19 @@
num_goals=2,
action_limit=30.0)
- @partial(jax.jit, static_argnums=[0])
def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
B = jax.numpy.array([[0.], [56.08534375]])
return A @ X + B @ U
- @partial(jax.jit, static_argnums=[0])
def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
goal: jax.typing.ArrayLike):
Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
R = jax.numpy.array([[0.00694444]])
- return X.T @ Q @ X + U.T @ R @ U
+ return (X - goal).T @ Q @ (X - goal) + U.T @ R @ U
- #@partial(jax.jit, static_argnums=[0])
def random_states(self, rng: PRNGKey, dimensions=None):
rng1, rng2 = jax.random.split(rng)
@@ -103,3 +100,12 @@
jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
minval=-10.0,
maxval=10.0)))
+
+ def random_goals(self, rng: PRNGKey, dimensions=None):
+ """Produces a random goal in the goal space."""
+ return jax.numpy.hstack((
+ jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+ minval=-0.1,
+ maxval=0.1),
+ jax.numpy.zeros((dimensions or FLAGS.num_agents, 1)),
+ ))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index bbe2f5d..f5ba092 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -121,18 +121,22 @@
actions = data['actions']
rewards = data['rewards']
observations2 = data['observations2']
+ R = data['goals']
# Compute the ending actions from the current network.
action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
params=params,
- observation=observations2)
+ observation=observations2,
+ R=R)
# Compute target network Q values
q1_pi_target = state.q1_apply(state.target_params,
observation=observations2,
+ R=R,
action=action2)
q2_pi_target = state.q2_apply(state.target_params,
observation=observations2,
+ R=R,
action=action2)
q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
@@ -143,8 +147,8 @@
(q_pi_target - alpha * logp_pi2))
# Compute the starting Q values from the Q network being optimized.
- q1 = state.q1_apply(params, observation=observations1, action=actions)
- q2 = state.q2_apply(params, observation=observations1, action=actions)
+ q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)
+ q2 = state.q2_apply(params, observation=observations1, R=R, action=actions)
# Mean squared error loss against Bellman backup
q1_loss = ((q1 - bellman_backup)**2).mean()
@@ -169,17 +173,21 @@
def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
"""Computes the Soft Actor-Critic loss for pi."""
observations1 = data['observations1']
+ R = data['goals']
# TODO(austin): We've got differentiable policy and differentiable physics. Can we use those here? Have Q learn the future, not the current step?
# Compute the action
pi, logp_pi, _, _ = state.pi_apply(rng=rng,
params=params,
- observation=observations1)
+ observation=observations1,
+ R=R)
q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
observation=observations1,
+ R=R,
action=pi)
q2_pi = state.q2_apply(jax.lax.stop_gradient(params),
observation=observations1,
+ R=R,
action=pi)
# And compute the Q of that action.
@@ -286,12 +294,7 @@
state.problem.random_states(initialization_rng, FLAGS.num_agents),
state.sharding)
- R = jax.numpy.hstack((
- jax.random.uniform(goal_rng, (FLAGS.num_agents, 1),
- minval=-0.1,
- maxval=0.1),
- jax.numpy.zeros((FLAGS.num_agents, 1)),
- ))
+ R = state.problem.random_goals(goal_rng, FLAGS.num_agents)
def loop(i, val):
"""Runs 1 step of our simulation."""
@@ -309,6 +312,7 @@
rng=action_rng,
params=state.params,
observation=observation,
+ R=R,
deterministic=False)
return pi_action
@@ -336,6 +340,7 @@
'observations2': observation2,
'actions': pi_action,
'rewards': reward.reshape((FLAGS.num_agents, 1)),
+ 'goals': R,
})
return observation2, pi_rng, replay_buffer_state
@@ -432,6 +437,8 @@
jax.numpy.zeros((problem.num_outputs, )),
'rewards':
jax.numpy.zeros((1, )),
+ 'goals':
+ jax.numpy.zeros((problem.num_goals, )),
})
replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,