Merge changes I04cede20,Ie54cd037,I6756d4ba,I6a01afa5 into main
* changes:
Prep lqr_plot for plotting swerve
Provide a goal to our SAC implementation
Split Soft Actor-Critic problem out into a class
Add plotly and tensorflow_probability
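For reviewers, a minimal usage sketch of the new goal-conditioned Problem
interface (illustrative only, not part of the patch; it assumes only the
TurretProblem methods added in physics.py below):

    from frc971.control_loops.swerve.velocity_controller import physics
    import jax

    problem = physics.TurretProblem()
    X = jax.numpy.array([1.0, 0.0])     # [position, velocity]
    U = jax.numpy.array([0.0])          # control, bounded by problem.action_limit
    goal = jax.numpy.array([0.0, 0.0])  # R, passed alongside every observation

    # One RK4 step of problem.dt seconds, and the step cost
    # (X - goal)^T Q (X - goal) + U^T R U that the SAC reward negates.
    X_next = problem.integrate_dynamics(X, U)
    step_cost = problem.cost(X_next, U, goal)

    # With a restored TrainState `state`, the actor and critics now take the
    # goal explicitly, e.g.:
    #   state.pi_apply(rng, state.params,
    #                  observation=problem.unwrap_angles(X), R=goal,
    #                  deterministic=True)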
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 30f5630..bd6674c 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -75,18 +75,6 @@
return checkpoints.restore_checkpoint(workdir, state)
-# Hard-coded simulation parameters for the turret.
-dt = 0.005
-A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
-B = numpy.matrix([[0.00065992], [0.25610763]])
-
-Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
-R = numpy.matrix([[0.00694444]])
-
-# Compute the optimal LQR cost + controller.
-F, P = controls.dlqr(A, B, Q, R, optimal_cost_function=True)
-
-
def generate_data(step=None):
grid_X = numpy.arange(-1, 1, 0.1)
grid_Y = numpy.arange(-10, 10, 0.1)
@@ -94,21 +82,18 @@
grid_X = jax.numpy.array(grid_X)
grid_Y = jax.numpy.array(grid_Y)
# Load the training state.
- physics_constants = jax_dynamics.Coefficients()
+ problem = physics.TurretProblem()
rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
- state = create_train_state(init_rng, physics_constants,
- FLAGS.q_learning_rate, FLAGS.pi_learning_rate)
+ state = create_train_state(init_rng, problem, FLAGS.q_learning_rate,
+ FLAGS.pi_learning_rate)
state = restore_checkpoint(state, FLAGS.workdir)
if step is not None and state.step == step:
return None
- print('F:', F)
- print('P:', P)
-
X = jax.numpy.array([1.0, 0.0])
X_lqr = X.copy()
goal = jax.numpy.array([0.0, 0.0])
@@ -121,34 +106,38 @@
def compute_q1(X, Y):
return state.q1_apply(
state.params,
- unwrap_angles(jax.numpy.array([X, Y])),
- jax.numpy.array([0.]),
+ observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+ R=goal,
+ action=jax.numpy.array([0.]),
)[0]
def compute_q2(X, Y):
return state.q2_apply(
state.params,
- unwrap_angles(jax.numpy.array([X, Y])),
- jax.numpy.array([0.]),
+ observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+ R=goal,
+ action=jax.numpy.array([0.]),
)[0]
def lqr_cost(X, Y):
- x = jax.numpy.array([X, Y])
- return -x.T @ jax.numpy.array(P) @ x
+ x = jax.numpy.array([X, Y]) - goal
+ return -x.T @ jax.numpy.array(problem.P) @ x
def compute_q(params, x, y):
- X = unwrap_angles(jax.numpy.array([x, y]))
+ X = state.problem.unwrap_angles(jax.numpy.array([x, y]))
return jax.numpy.minimum(
state.q1_apply(
params,
- X,
- jax.numpy.array([0.]),
+ observation=X,
+ R=goal,
+ action=jax.numpy.array([0.]),
)[0],
state.q2_apply(
params,
- X,
- jax.numpy.array([0.]),
+ observation=X,
+ R=goal,
+ action=jax.numpy.array([0.]),
)[0])
cost_grid1 = jax.vmap(jax.vmap(compute_q1))(grid_X, grid_Y)
@@ -160,20 +149,17 @@
grid_Y)
lqr_cost_grid = jax.vmap(jax.vmap(lqr_cost))(grid_X, grid_Y)
- # TODO(austin): Stuff to figure out:
- # 3: Make it converge faster. Use both GPUs better?
- # 4: Can we feed in a reference position and get it to learn to stabilize there?
-
# Now compute the two controller surfaces.
def compute_lqr_U(X, Y):
x = jax.numpy.array([X, Y])
- return (-jax.numpy.array(F.reshape((2, ))) @ x)[0]
+ return (-jax.numpy.array(problem.F.reshape((2, ))) @ x)[0]
def compute_pi_U(X, Y):
x = jax.numpy.array([X, Y])
U, _, _, _ = state.pi_apply(rng,
state.params,
- observation=unwrap_angles(x),
+ observation=state.problem.unwrap_angles(x),
+ R=goal,
deterministic=True)
return U[0]
@@ -185,16 +171,24 @@
# Now simulate the robot, accumulating up things to plot.
def loop(i, val):
X, X_lqr, data, params = val
- t = data.t.at[i].set(i * dt)
+ t = data.t.at[i].set(i * problem.dt)
U, _, _, _ = state.pi_apply(rng,
params,
- observation=unwrap_angles(X),
+ observation=state.problem.unwrap_angles(X),
+ R=goal,
deterministic=True)
- U_lqr = F @ (goal - X_lqr)
+ U_lqr = problem.F @ (goal - X_lqr)
- cost = jax.numpy.minimum(state.q1_apply(params, unwrap_angles(X), U),
- state.q2_apply(params, unwrap_angles(X), U))
+ cost = jax.numpy.minimum(
+ state.q1_apply(params,
+ observation=state.problem.unwrap_angles(X),
+ R=goal,
+ action=U),
+ state.q2_apply(params,
+ observation=state.problem.unwrap_angles(X),
+ R=goal,
+ action=U))
U_plot = data.U.at[i, :].set(U)
U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
@@ -203,11 +197,11 @@
cost_plot = data.cost.at[i, :].set(cost)
cost_lqr_plot = data.cost_lqr.at[i, :].set(lqr_cost(*X_lqr))
- X = A @ X + B @ U
- X_lqr = A @ X_lqr + B @ U_lqr
+ X = problem.A @ X + problem.B @ U
+ X_lqr = problem.A @ X_lqr + problem.B @ U_lqr
- reward = data.reward - physics.state_cost(X, U, goal)
- reward_lqr = data.reward_lqr - physics.state_cost(X_lqr, U_lqr, goal)
+ reward = data.reward - state.problem.cost(X, U, goal)
+ reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
return X, X_lqr, data._replace(
t=t,
@@ -230,10 +224,10 @@
X, X_lqr, data, params = integrate(
Data(
t=jax.numpy.zeros((FLAGS.horizon, )),
- X=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
- X_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
- U=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
- U_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+ X=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+ X_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+ U=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
+ U_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
cost=jax.numpy.zeros((FLAGS.horizon, 1)),
cost_lqr=jax.numpy.zeros((FLAGS.horizon, 1)),
q1_grid=cost_grid1,
@@ -284,24 +278,32 @@
self.fig0, self.axs0 = pylab.subplots(3)
self.fig0.supxlabel('Seconds')
- self.x, = self.axs0[0].plot([], [], label="x")
- self.v, = self.axs0[0].plot([], [], label="v")
- self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
- self.v_lqr, = self.axs0[0].plot([], [], label="v_lqr")
+ self.axs_velocity = self.axs0[0].twinx()
- self.axs0[0].set_ylabel('Velocity')
+ self.x, = self.axs0[0].plot([], [], label="x")
+ self.x_lqr, = self.axs0[0].plot([], [], label="x_lqr")
+
+ self.axs0[0].set_ylabel('Position')
self.axs0[0].legend()
+ self.axs0[0].grid()
+
+ self.v, = self.axs_velocity.plot([], [], label="v", color='C2')
+ self.v_lqr, = self.axs_velocity.plot([], [], label="v_lqr", color='C3')
+ self.axs_velocity.set_ylabel('Velocity')
+ self.axs_velocity.legend()
self.uaxis, = self.axs0[1].plot([], [], label="U")
self.uaxis_lqr, = self.axs0[1].plot([], [], label="U_lqr")
self.axs0[1].set_ylabel('Amps')
self.axs0[1].legend()
+ self.axs0[1].grid()
self.costaxis, = self.axs0[2].plot([], [], label="cost")
self.costlqraxis, = self.axs0[2].plot([], [], label="cost lqr")
self.axs0[2].set_ylabel('Cost')
self.axs0[2].legend()
+ self.axs0[2].grid()
self.costfig = pyplot.figure(figsize=pyplot.figaspect(0.5))
self.cost3dax = [
@@ -344,6 +346,9 @@
self.axs0[0].relim()
self.axs0[0].autoscale_view()
+ self.axs_velocity.relim()
+ self.axs_velocity.autoscale_view()
+
self.axs0[1].relim()
self.axs0[1].autoscale_view()
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
index 7ef687c..a0d248d 100644
--- a/frc971/control_loops/swerve/velocity_controller/main.py
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -24,6 +24,7 @@
import jax
import tensorflow as tf
from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve.velocity_controller import physics
jax._src.deprecations.accelerate('tracer-hash')
# Enable the compilation cache
@@ -63,8 +64,8 @@
FLAGS.workdir,
)
- physics_constants = jax_dynamics.Coefficients()
- state = train.train(FLAGS.workdir, physics_constants)
+ problem = physics.TurretProblem()
+ state = train.train(FLAGS.workdir, problem)
if __name__ == '__main__':
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index ae67efe..0d8a410 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -36,8 +36,8 @@
absl.flags.DEFINE_float(
'final_q_learning_rate',
- default=0.00002,
- help='Training learning rate.',
+ default=0.001,
+    help='Fraction of --q_learning_rate to reduce to by the end.',
)
absl.flags.DEFINE_float(
@@ -48,8 +48,8 @@
absl.flags.DEFINE_float(
'final_pi_learning_rate',
- default=0.00002,
- help='Training learning rate.',
+ default=0.001,
+    help='Fraction of --pi_learning_rate to reduce to by the end.',
)
absl.flags.DEFINE_integer(
@@ -92,14 +92,15 @@
activation: Callable = nn.activation.relu
# Max action we can apply
- action_limit: float = 30.0
+ action_limit: float = 1
@nn.compact
def __call__(self,
- observations: ArrayLike,
+ observation: ArrayLike,
+ R: ArrayLike,
deterministic: bool = False,
rng: PRNGKey | None = None):
- x = observations
+ x = jax.numpy.hstack((observation, R))
# Apply the dense layers
for i, hidden_size in enumerate(self.hidden_sizes):
x = nn.Dense(
@@ -168,9 +169,10 @@
activation: Callable = nn.activation.tanh
@nn.compact
- def __call__(self, observations, actions):
+ def __call__(self, observation: ArrayLike, R: ArrayLike,
+ action: ArrayLike):
        # Estimate Q with a simple multi-layer dense network.
- x = jax.numpy.hstack((observations, actions))
+ x = jax.numpy.hstack((observation, R, action))
for i, hidden_size in enumerate(self.hidden_sizes):
x = nn.Dense(
name=f'denselayer{i}',
@@ -185,8 +187,7 @@
class TrainState(flax.struct.PyTreeNode):
- physics_constants: jax_dynamics.CoefficientsType = flax.struct.field(
- pytree_node=False)
+ problem: Problem = flax.struct.field(pytree_node=False)
step: int | jax.Array = flax.struct.field(pytree_node=True)
substep: int | jax.Array = flax.struct.field(pytree_node=True)
@@ -221,23 +222,34 @@
def pi_apply(self,
rng: PRNGKey,
- params,
- observation,
+ params: flax.core.FrozenDict[str, typing.Any],
+ observation: ArrayLike,
+ R: ArrayLike,
deterministic: bool = False):
- return self.pi_apply_fn({'params': params['pi']},
- physics.unwrap_angles(observation),
- deterministic=deterministic,
- rngs={'pi': rng})
+ return self.pi_apply_fn(
+ {'params': params['pi']},
+ observation=self.problem.unwrap_angles(observation),
+ R=R,
+ deterministic=deterministic,
+ rngs={'pi': rng})
- def q1_apply(self, params, observation, action):
- return self.q1_apply_fn({'params': params['q1']},
- physics.unwrap_angles(observation), action)
+ def q1_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+ observation: ArrayLike, R: ArrayLike, action: ArrayLike):
+ return self.q1_apply_fn(
+ {'params': params['q1']},
+ observation=self.problem.unwrap_angles(observation),
+ R=R,
+ action=action)
- def q2_apply(self, params, observation, action):
- return self.q2_apply_fn({'params': params['q2']},
- physics.unwrap_angles(observation), action)
+ def q2_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+ observation: ArrayLike, R: ArrayLike, action: ArrayLike):
+ return self.q2_apply_fn(
+ {'params': params['q2']},
+ observation=self.problem.unwrap_angles(observation),
+ R=R,
+ action=action)
- def pi_apply_gradients(self, step, grads):
+ def pi_apply_gradients(self, step: int, grads):
updates, new_pi_opt_state = self.pi_tx.update(grads, self.pi_opt_state,
self.params)
new_params = optax.apply_updates(self.params, updates)
@@ -249,7 +261,7 @@
pi_opt_state=new_pi_opt_state,
)
- def q_apply_gradients(self, step, grads):
+ def q_apply_gradients(self, step: int, grads):
updates, new_q_opt_state = self.q_tx.update(grads, self.q_opt_state,
self.params)
new_params = optax.apply_updates(self.params, updates)
@@ -291,9 +303,8 @@
)
@classmethod
- def create(cls, *, physics_constants: jax_dynamics.CoefficientsType,
- params, pi_tx, q_tx, alpha_tx, pi_apply_fn, q1_apply_fn,
- q2_apply_fn, **kwargs):
+ def create(cls, *, problem: Problem, params, pi_tx, q_tx, alpha_tx,
+ pi_apply_fn, q1_apply_fn, q2_apply_fn, **kwargs):
"""Creates a new instance with ``step=0`` and initialized ``opt_state``."""
pi_tx = optax.multi_transform(
@@ -343,7 +354,7 @@
length=FLAGS.replay_size)
result = cls(
- physics_constants=physics_constants,
+ problem=problem,
step=0,
substep=0,
params=params,
@@ -357,7 +368,7 @@
q_opt_state=q_opt_state,
alpha_tx=alpha_tx,
alpha_opt_state=alpha_opt_state,
- target_entropy=-physics.NUM_STATES,
+ target_entropy=-problem.num_states,
mesh=mesh,
sharding=sharding,
replicated_sharding=replicated_sharding,
@@ -367,10 +378,12 @@
return jax.device_put(result, replicated_sharding)
-def create_train_state(rng, physics_constants: jax_dynamics.CoefficientsType,
- q_learning_rate, pi_learning_rate):
+def create_train_state(rng: PRNGKey, problem: Problem, q_learning_rate,
+ pi_learning_rate):
"""Creates initial `TrainState`."""
- pi = SquashedGaussianMLPActor(action_space=physics.NUM_OUTPUTS)
+ pi = SquashedGaussianMLPActor(activation=nn.activation.gelu,
+ action_space=problem.num_outputs,
+ action_limit=problem.action_limit)
    # We want q1 and q2 to have different network architectures so they pick up different things.
q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
@@ -381,17 +394,20 @@
pi_params = pi.init(
pi_rng,
- jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+ observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ R=jax.numpy.ones([problem.num_goals]),
)['params']
q1_params = q1.init(
q1_rng,
- jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
- jax.numpy.ones([physics.NUM_OUTPUTS]),
+ observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ R=jax.numpy.ones([problem.num_goals]),
+ action=jax.numpy.ones([problem.num_outputs]),
)['params']
q2_params = q2.init(
q2_rng,
- jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
- jax.numpy.ones([physics.NUM_OUTPUTS]),
+ observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ R=jax.numpy.ones([problem.num_goals]),
+ action=jax.numpy.ones([problem.num_outputs]),
)['params']
if FLAGS.alpha < 0.0:
@@ -411,7 +427,7 @@
alpha_tx = optax.sgd(learning_rate=q_learning_rate)
result = TrainState.create(
- physics_constants=physics_constants,
+ problem=problem,
params=init_params(rng),
pi_tx=pi_tx,
q_tx=q_tx,
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index 9454f14..d6515fe 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -1,69 +1,122 @@
-import jax
+import jax, numpy
from functools import partial
-from frc971.control_loops.swerve import dynamics
from absl import logging
-from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve import dynamics, jax_dynamics
+from frc971.control_loops.python import controls
+from flax.typing import PRNGKey
-@partial(jax.jit, static_argnums=[0])
-def xdot_physics(physics_constants: jax_dynamics.CoefficientsType, X, U):
- A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
- B = jax.numpy.array([[0.], [56.08534375]])
+class Problem(object):
- return A @ X + B @ U
+ def __init__(self, num_states: int, num_unwrapped_states: int,
+ num_outputs: int, num_goals: int, action_limit: float):
+ self.num_states = num_states
+ self.num_unwrapped_states = num_unwrapped_states
+ self.num_outputs = num_outputs
+ self.num_goals = num_goals
+ self.action_limit = action_limit
+ self.dt = 0.005
+ def integrate_dynamics(self, X: jax.typing.ArrayLike,
+ U: jax.typing.ArrayLike):
+ m = 2 # RK4 steps per interval
+ dt = self.dt / m
-def unwrap_angles(X):
- return X
+ def iteration(i, X):
+ weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0],
+ [1.0, 2.0, 2.0, 1.0]])
+ def rk_iteration(i, val):
+ kx_previous, weighted_sum = val
+ kx = self.xdot(X + dt * weights[0, i] * kx_previous, U)
+ weighted_sum += dt * weights[1, i] * kx / 6.0
+ return (kx, weighted_sum)
-@jax.jit
-def swerve_cost(X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
- goal: jax.typing.ArrayLike):
-
- Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
- R = jax.numpy.array([[0.00694444]])
-
- return (X).T @ Q @ (X) + U.T @ R @ U
-
-
-@partial(jax.jit, static_argnums=[0])
-def integrate_dynamics(physics_constants: jax_dynamics.CoefficientsType, X, U):
- m = 2 # RK4 steps per interval
- dt = 0.005 / m
-
- def iteration(i, X):
- weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0], [1.0, 2.0, 2.0, 1.0]])
-
- def rk_iteration(i, val):
- kx_previous, weighted_sum = val
- kx = xdot_physics(physics_constants,
- X + dt * weights[0, i] * kx_previous, U)
- weighted_sum += dt * weights[1, i] * kx / 6.0
- return (kx, weighted_sum)
+ return jax.lax.fori_loop(lower=0,
+ upper=4,
+ body_fun=rk_iteration,
+ init_val=(X, X))[1]
return jax.lax.fori_loop(lower=0,
- upper=4,
- body_fun=rk_iteration,
- init_val=(X, X))[1]
+ upper=m,
+ body_fun=iteration,
+ init_val=X)
- return jax.lax.fori_loop(lower=0, upper=m, body_fun=iteration, init_val=X)
+ def unwrap_angles(self, X: jax.typing.ArrayLike):
+ return X
+
+ def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+        raise NotImplementedError("xdot not implemented")
+
+ def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike):
+        raise NotImplementedError("cost not implemented")
+
+ def random_states(self, rng: PRNGKey, dimensions=None):
+        raise NotImplementedError("random_states not implemented")
+
+ def random_actions(self, rng: PRNGKey, dimensions=None):
+ """Produces a uniformly random action in the action space."""
+ return jax.random.uniform(
+ rng,
+ (dimensions or FLAGS.num_agents, self.num_outputs),
+ minval=-self.action_limit,
+ maxval=self.action_limit,
+ )
+
+ def random_goals(self, rng: PRNGKey, dimensions=None):
+ """Produces a random goal in the goal space."""
+        raise NotImplementedError("random_goals not implemented")
-state_cost = swerve_cost
-NUM_STATES = 2
-NUM_UNWRAPPED_STATES = 2
-NUM_OUTPUTS = 1
-ACTION_LIMIT = 10.0
+class TurretProblem(Problem):
+ def __init__(self):
+ super().__init__(num_states=2,
+ num_unwrapped_states=2,
+ num_outputs=1,
+ num_goals=2,
+ action_limit=30.0)
+ self.A = numpy.matrix([[1., 0.00456639], [0., 0.83172142]])
+ self.B = numpy.matrix([[0.00065992], [0.25610763]])
-def random_states(rng, dimensions=None):
- rng1, rng2 = jax.random.split(rng)
+ self.Q = numpy.matrix([[2.77777778, 0.], [0., 0.01]])
+ self.R = numpy.matrix([[0.00694444]])
- return jax.numpy.hstack(
- (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
- minval=-1.0,
- maxval=1.0),
- jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
- minval=-10.0,
- maxval=10.0)))
+ # Compute the optimal LQR cost + controller.
+ self.F, self.P = controls.dlqr(self.A,
+ self.B,
+ self.Q,
+ self.R,
+ optimal_cost_function=True)
+
+ def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+ A_continuous = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+ B_continuous = jax.numpy.array([[0.], [56.08534375]])
+
+ return A_continuous @ X + B_continuous @ U
+
+ def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike):
+ return (X - goal).T @ jax.numpy.array(
+ self.Q) @ (X - goal) + U.T @ jax.numpy.array(self.R) @ U
+
+ def random_states(self, rng: PRNGKey, dimensions=None):
+ rng1, rng2 = jax.random.split(rng)
+
+ return jax.numpy.hstack(
+ (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
+ minval=-1.0,
+ maxval=1.0),
+ jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
+ minval=-10.0,
+ maxval=10.0)))
+
+ def random_goals(self, rng: PRNGKey, dimensions=None):
+ """Produces a random goal in the goal space."""
+ return jax.numpy.hstack((
+ jax.random.uniform(rng, (dimensions or FLAGS.num_agents, 1),
+ minval=-0.1,
+ maxval=0.1),
+ jax.numpy.zeros((dimensions or FLAGS.num_agents, 1)),
+ ))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 5002eee..f5ba092 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -78,16 +78,6 @@
)
-def random_actions(rng, dimensions=None):
- """Produces a uniformly random action in the action space."""
- return jax.random.uniform(
- rng,
- (dimensions or FLAGS.num_agents, physics.NUM_OUTPUTS),
- minval=-physics.ACTION_LIMIT,
- maxval=physics.ACTION_LIMIT,
- )
-
-
def save_checkpoint(state: TrainState, workdir: str):
"""Saves a checkpoint in the workdir."""
# TODO(austin): use orbax directly.
@@ -131,15 +121,23 @@
actions = data['actions']
rewards = data['rewards']
observations2 = data['observations2']
+ R = data['goals']
# Compute the ending actions from the current network.
action2, logp_pi2, _, _ = state.pi_apply(rng=rng,
params=params,
- observation=observations2)
+ observation=observations2,
+ R=R)
# Compute target network Q values
- q1_pi_target = state.q1_apply(state.target_params, observations2, action2)
- q2_pi_target = state.q2_apply(state.target_params, observations2, action2)
+ q1_pi_target = state.q1_apply(state.target_params,
+ observation=observations2,
+ R=R,
+ action=action2)
+ q2_pi_target = state.q2_apply(state.target_params,
+ observation=observations2,
+ R=R,
+ action=action2)
q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
alpha = jax.numpy.exp(params['logalpha'])
@@ -149,8 +147,8 @@
(q_pi_target - alpha * logp_pi2))
# Compute the starting Q values from the Q network being optimized.
- q1 = state.q1_apply(params, observations1, actions)
- q2 = state.q2_apply(params, observations1, actions)
+ q1 = state.q1_apply(params, observation=observations1, R=R, action=actions)
+ q2 = state.q2_apply(params, observation=observations1, R=R, action=actions)
# Mean squared error loss against Bellman backup
q1_loss = ((q1 - bellman_backup)**2).mean()
@@ -175,14 +173,22 @@
def compute_loss_pi(state: TrainState, rng: PRNGKey, params, data: ArrayLike):
"""Computes the Soft Actor-Critic loss for pi."""
observations1 = data['observations1']
+ R = data['goals']
# TODO(austin): We've got differentiable policy and differentiable physics. Can we use those here? Have Q learn the future, not the current step?
# Compute the action
pi, logp_pi, _, _ = state.pi_apply(rng=rng,
params=params,
- observation=observations1)
- q1_pi = state.q1_apply(jax.lax.stop_gradient(params), observations1, pi)
- q2_pi = state.q2_apply(jax.lax.stop_gradient(params), observations1, pi)
+ observation=observations1,
+ R=R)
+ q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
+ observation=observations1,
+ R=R,
+ action=pi)
+ q2_pi = state.q2_apply(jax.lax.stop_gradient(params),
+ observation=observations1,
+ R=R,
+ action=pi)
# And compute the Q of that action.
q_pi = jax.numpy.minimum(q1_pi, q2_pi)
@@ -277,16 +283,19 @@
@jax.jit
-def collect_experience(state: TrainState, replay_buffer_state, R: ArrayLike,
+def collect_experience(state: TrainState, replay_buffer_state,
step_rng: PRNGKey, step):
"""Collects experience by simulating."""
pi_rng = jax.random.fold_in(step_rng, step)
pi_rng, initialization_rng = jax.random.split(pi_rng)
+ pi_rng, goal_rng = jax.random.split(pi_rng)
observation = jax.lax.with_sharding_constraint(
- physics.random_states(initialization_rng, FLAGS.num_agents),
+ state.problem.random_states(initialization_rng, FLAGS.num_agents),
state.sharding)
+ R = state.problem.random_goals(goal_rng, FLAGS.num_agents)
+
def loop(i, val):
"""Runs 1 step of our simulation."""
observation, pi_rng, replay_buffer_state = val
@@ -295,7 +304,7 @@
def true_fn(i):
# We are at the beginning of the process, pick a random action.
- return random_actions(action_rng, FLAGS.num_agents)
+ return state.problem.random_actions(action_rng, FLAGS.num_agents)
def false_fn(i):
# We are past the beginning of the process, use the trained network.
@@ -303,6 +312,7 @@
rng=action_rng,
params=state.params,
observation=observation,
+ R=R,
deterministic=False)
return pi_action
@@ -314,26 +324,23 @@
)
# Compute the destination state.
- observation2 = jax.vmap(lambda o, pi: physics.integrate_dynamics(
- state.physics_constants, o, pi),
- in_axes=(0, 0))(observation, pi_action)
+ observation2 = jax.vmap(
+ lambda o, pi: state.problem.integrate_dynamics(o, pi),
+ in_axes=(0, 0))(observation, pi_action)
# Soft Actor-Critic is designed to maximize reward. LQR minimizes
# cost. There is nothing which assumes anything about the sign of
# the reward, so use the negative of the cost.
- reward = (-jax.vmap(state_cost)(observation2, pi_action, R)).reshape(
- (FLAGS.num_agents, 1))
+ reward = -jax.vmap(state.problem.cost)(
+ X=observation2, U=pi_action, goal=R)
- logging.info('Observation shape: %s', observation.shape)
- logging.info('Observation2 shape: %s', observation2.shape)
- logging.info('Action shape: %s', pi_action.shape)
- logging.info('Action shape: %s', pi_action.shape)
replay_buffer_state = state.replay_buffer.add(
replay_buffer_state, {
'observations1': observation,
'observations2': observation2,
'actions': pi_action,
- 'rewards': reward,
+ 'rewards': reward.reshape((FLAGS.num_agents, 1)),
+ 'goals': R,
})
return observation2, pi_rng, replay_buffer_state
@@ -379,9 +386,7 @@
return rng, state, q_loss, pi_loss, alpha_loss
-def train(
- workdir: str, physics_constants: jax_dynamics.CoefficientsType
-) -> train_state.TrainState:
+def train(workdir: str, problem: Problem) -> train_state.TrainState:
"""Trains a Soft Actor-Critic controller."""
rng = jax.random.key(0)
rng, r_rng = jax.random.split(rng)
@@ -394,6 +399,14 @@
'batch_size': FLAGS.batch_size,
'horizon': FLAGS.horizon,
'warmup_steps': FLAGS.warmup_steps,
+ 'steps': FLAGS.steps,
+ 'replay_size': FLAGS.replay_size,
+ 'num_agents': FLAGS.num_agents,
+ 'polyak': FLAGS.polyak,
+ 'gamma': FLAGS.gamma,
+ 'alpha': FLAGS.alpha,
+ 'final_q_learning_rate': FLAGS.final_q_learning_rate,
+ 'final_pi_learning_rate': FLAGS.final_pi_learning_rate,
}
# Setup TrainState
@@ -406,7 +419,7 @@
final_learning_rate=FLAGS.final_pi_learning_rate)
state = create_train_state(
init_rng,
- physics_constants,
+ problem,
q_learning_rate=q_learning_rate,
pi_learning_rate=pi_learning_rate,
)
@@ -415,42 +428,23 @@
state_sharding = nn.get_sharding(state, state.mesh)
logging.info(state_sharding)
- # TODO(austin): Let the goal change.
- R = jax.numpy.hstack((
- jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=0.9, maxval=1.1),
- jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=-0.1,
- maxval=0.1),
- jax.numpy.zeros((FLAGS.num_agents, 1)),
- ))
-
replay_buffer_state = state.replay_buffer.init({
'observations1':
- jax.numpy.zeros((physics.NUM_STATES, )),
+ jax.numpy.zeros((problem.num_states, )),
'observations2':
- jax.numpy.zeros((physics.NUM_STATES, )),
+ jax.numpy.zeros((problem.num_states, )),
'actions':
- jax.numpy.zeros((physics.NUM_OUTPUTS, )),
+ jax.numpy.zeros((problem.num_outputs, )),
'rewards':
jax.numpy.zeros((1, )),
+ 'goals':
+ jax.numpy.zeros((problem.num_states, )),
})
replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
state.mesh)
logging.info(replay_buffer_state_sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['observations1'].shape,
- replay_buffer_state.experience['observations1'].sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['observations2'].shape,
- replay_buffer_state.experience['observations2'].sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['actions'].shape,
- replay_buffer_state.experience['actions'].sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['rewards'].shape,
- replay_buffer_state.experience['rewards'].sharding)
-
    # Number of gradients to accumulate before doing descent.
update_after = FLAGS.batch_size // FLAGS.num_agents
@@ -462,7 +456,6 @@
state, replay_buffer_state = collect_experience(
state,
replay_buffer_state,
- R,
step_rng,
step,
)
diff --git a/tools/python/requirements.lock.txt b/tools/python/requirements.lock.txt
index 4fb225e..d465241 100644
--- a/tools/python/requirements.lock.txt
+++ b/tools/python/requirements.lock.txt
@@ -20,6 +20,7 @@
# tensorflow
# tensorflow-datasets
# tensorflow-metadata
+ # tensorflow-probability
aim==3.24.0 \
--hash=sha256:1266fac2453f2d356d2e39c98d15b46ba7a71ecab14cdaa6eaa6d390f5add898 \
--hash=sha256:1322659dcecb701c264f5067375c74811faf3514bb316e1ca1e53b9de6b62766 \
@@ -400,6 +401,10 @@
# mkdocs
# tensorflow-datasets
# uvicorn
+cloudpickle==3.1.0 \
+ --hash=sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b \
+ --hash=sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e
+ # via tensorflow-probability
clu==0.0.12 \
--hash=sha256:0d183e7d25f7dd0700444510a264e24700e2f068bdabd199ed22866f7e54edba \
--hash=sha256:f71eaa1afbd30f57f7709257ba7e1feb8ad5c1c3dcae3606672a138678bb3ce4
@@ -510,6 +515,10 @@
--hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \
--hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c
# via matplotlib
+decorator==5.1.1 \
+ --hash=sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330 \
+ --hash=sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186
+ # via tensorflow-probability
dm-tree==0.1.8 \
--hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \
--hash=sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760 \
@@ -557,7 +566,9 @@
--hash=sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8 \
--hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \
--hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d
- # via tensorflow-datasets
+ # via
+ # tensorflow-datasets
+ # tensorflow-probability
etils[enp,epath,epy,etree]==1.5.2 \
--hash=sha256:6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b \
--hash=sha256:ba6a3e1aff95c769130776aa176c11540637f5dd881f3b79172a5149b6b1c446
@@ -649,7 +660,9 @@
gast==0.6.0 \
--hash=sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54 \
--hash=sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb
- # via tensorflow
+ # via
+ # tensorflow
+ # tensorflow-probability
ghp-import==2.1.0 \
--hash=sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619 \
--hash=sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343
@@ -1335,6 +1348,7 @@
# tensorboard
# tensorflow
# tensorflow-datasets
+ # tensorflow-probability
# tensorstore
nvidia-cublas-cu12==12.6.1.4 \
--hash=sha256:5dd125ece5469dbdceebe2e9536ad8fc4abd38aa394a7ace42fc8a930a1e81e3 \
@@ -1534,6 +1548,7 @@
# keras
# matplotlib
# mkdocs
+ # plotly
# tensorboard
# tensorflow
pandas==2.2.2 \
@@ -1666,6 +1681,10 @@
# via
# mkdocs-get-deps
# yapf
+plotly==5.24.1 \
+ --hash=sha256:dbc8ac8339d248a4bcc36e08a5659bacfe1b079390b8953533f4eb22169b4bae \
+ --hash=sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089
+ # via -r tools/python/requirements.txt
promise==2.3 \
--hash=sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0
# via tensorflow-datasets
@@ -2079,6 +2098,7 @@
# python-dateutil
# tensorboard
# tensorflow
+ # tensorflow-probability
sniffio==1.3.1 \
--hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \
--hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc
@@ -2148,6 +2168,10 @@
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
# via -r tools/python/requirements.txt
+tenacity==9.0.0 \
+ --hash=sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b \
+ --hash=sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539
+ # via plotly
tensorboard==2.17.1 \
--hash=sha256:253701a224000eeca01eee6f7e978aea7b408f60b91eb0babdb04e78947b773e
# via tensorflow
@@ -2173,7 +2197,9 @@
--hash=sha256:e8d26d6c24ccfb139db1306599257ca8f5cfe254ef2d023bfb667f374a17a64d \
--hash=sha256:ee18b4fcd627c5e872eabb25092af6c808b6ec77948662c88fc5c89a60eb0211 \
--hash=sha256:ef615c133cf4d592a073feda634ccbeb521a554be57de74f8c318d38febbeab5
- # via -r tools/python/requirements.txt
+ # via
+ # -r tools/python/requirements.txt
+ # tf-keras
tensorflow-datasets==4.9.3 \
--hash=sha256:09cd60eccab0d5a9d15f53e76ee0f1b530ee5aa3665e42be621a4810d9fa5db6 \
--hash=sha256:90390077dde2c9e4e240754ddfc5bb50b482946d421c8a34677c3afdb0463427
@@ -2199,6 +2225,9 @@
tensorflow-metadata==1.15.0 \
--hash=sha256:cb84d8e159128aeae7b3f6013ccd7969c69d2e6d1a7b255dbfa6f5344d962986
# via tensorflow-datasets
+tensorflow-probability==0.24.0 \
+ --hash=sha256:8c1774683e38359dbcaf3697e79b7e6a4e69b9c7b3679e78ee18f43e59e5759b
+ # via -r tools/python/requirements.txt
tensorstore==0.1.64 \
--hash=sha256:1a78aedbddccc09ea283b145496da03dbc7eb8693ae4e01074ed791d72b7eac2 \
--hash=sha256:24a4cebaf9d0e75d494342948f68edc971d6bb90e23192ddf8d98397fb1ff3cb \
@@ -2231,6 +2260,10 @@
# via
# tensorflow
# tensorflow-datasets
+tf-keras==2.17.0 \
+ --hash=sha256:cc97717e4dc08487f327b0740a984043a9e0123c7a4e21206711669d3ec41c88 \
+ --hash=sha256:fda97c18da30da0f72a5a7e80f3eee343b09f4c206dad6c57c944fb2cd18560e
+ # via -r tools/python/requirements.txt
toml==0.10.2 \
--hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \
--hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f
diff --git a/tools/python/requirements.txt b/tools/python/requirements.txt
index 241425a..342d8ab 100644
--- a/tools/python/requirements.txt
+++ b/tools/python/requirements.txt
@@ -41,9 +41,12 @@
tensorflow
tensorflow_datasets
+tensorflow_probability
+tf_keras
# Experience buffer for reinforcement learning
flashbax
# Experiment tracking
aim
+plotly
diff --git a/tools/python/whl_overrides.json b/tools/python/whl_overrides.json
index 249b856..c27c8ac 100644
--- a/tools/python/whl_overrides.json
+++ b/tools/python/whl_overrides.json
@@ -91,6 +91,10 @@
"sha256": "ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/click-8.1.7-py3-none-any.whl"
},
+ "cloudpickle==3.1.0": {
+ "sha256": "fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cloudpickle-3.1.0-py3-none-any.whl"
+ },
"clu==0.0.12": {
"sha256": "0d183e7d25f7dd0700444510a264e24700e2f068bdabd199ed22866f7e54edba",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/clu-0.0.12-py3-none-any.whl"
@@ -111,6 +115,10 @@
"sha256": "85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/cycler-0.12.1-py3-none-any.whl"
},
+ "decorator==5.1.1": {
+ "sha256": "b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/decorator-5.1.1-py3-none-any.whl"
+ },
"dm_tree==0.1.8": {
"sha256": "181c35521d480d0365f39300542cb6cd7fd2b77351bb43d7acfda15aef63b317",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/dm_tree-0.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -399,6 +407,10 @@
"sha256": "2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/platformdirs-4.2.2-py3-none-any.whl"
},
+ "plotly==5.24.1": {
+ "sha256": "f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/plotly-5.24.1-py3-none-any.whl"
+ },
"promise==2.3": {
"sha256": "d10acd69e1aed4de5840e3915edf51c877dfc7c7ae440fd081019edbf62820a4",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/promise-2.3-py3-none-any.whl"
@@ -515,6 +527,10 @@
"sha256": "024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tabulate-0.9.0-py3-none-any.whl"
},
+ "tenacity==9.0.0": {
+ "sha256": "93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tenacity-9.0.0-py3-none-any.whl"
+ },
"tensorboard==2.17.1": {
"sha256": "253701a224000eeca01eee6f7e978aea7b408f60b91eb0babdb04e78947b773e",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorboard-2.17.1-py3-none-any.whl"
@@ -539,6 +555,10 @@
"sha256": "cb84d8e159128aeae7b3f6013ccd7969c69d2e6d1a7b255dbfa6f5344d962986",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorflow_metadata-1.15.0-py3-none-any.whl"
},
+ "tensorflow_probability==0.24.0": {
+ "sha256": "8c1774683e38359dbcaf3697e79b7e6a4e69b9c7b3679e78ee18f43e59e5759b",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorflow_probability-0.24.0-py2.py3-none-any.whl"
+ },
"tensorstore==0.1.64": {
"sha256": "72f76231ce12bfd266358a096e9c6000a2d86c1f4f24c3891c29b2edfffc5df4",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tensorstore-0.1.64-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
@@ -547,6 +567,10 @@
"sha256": "9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/termcolor-2.4.0-py3-none-any.whl"
},
+ "tf_keras==2.17.0": {
+ "sha256": "cc97717e4dc08487f327b0740a984043a9e0123c7a4e21206711669d3ec41c88",
+ "url": "https://software.frc971.org/Build-Dependencies/wheelhouse/tf_keras-2.17.0-py3-none-any.whl"
+ },
"toml==0.10.2": {
"sha256": "806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b",
"url": "https://software.frc971.org/Build-Dependencies/wheelhouse/toml-0.10.2-py2.py3-none-any.whl"