Split Soft Actor-Critic problem out into a class
This sets us up to optimize multiple different problems with the same training code.
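
For context, the new physics.Problem base class is meant to be subclassed for
each plant we want to train, the same way TurretProblem is below. A minimal
sketch of what a second subclass might look like follows; the class name,
dynamics matrices, cost weights, and action limit here are hypothetical
placeholders, not part of this change.

    # Hypothetical example of a second Problem subclass; all constants below
    # are illustrative only.
    from functools import partial

    import jax
    from absl import flags

    from frc971.control_loops.swerve.velocity_controller import physics

    FLAGS = flags.FLAGS


    class ExampleLinearProblem(physics.Problem):

        def __init__(self):
            super().__init__(num_states=2,
                             num_unwrapped_states=2,
                             num_outputs=1,
                             num_goals=2,
                             action_limit=12.0)

        @partial(jax.jit, static_argnums=[0])
        def xdot(self, X, U):
            # Continuous-time dynamics Xdot = A X + B U (placeholder constants).
            A = jax.numpy.array([[0., 1.], [0., -5.0]])
            B = jax.numpy.array([[0.], [10.0]])
            return A @ X + B @ U

        @partial(jax.jit, static_argnums=[0])
        def cost(self, X, U, goal):
            # Quadratic LQR-style cost; the trainer negates this to get a reward.
            Q = jax.numpy.array([[1.0, 0.], [0., 0.1]])
            R = jax.numpy.array([[0.01]])
            return X.T @ Q @ X + U.T @ R @ U

        def random_states(self, rng, dimensions=None):
            # Uniformly sample an initial state for each simulated agent.
            return jax.random.uniform(
                rng, (dimensions or FLAGS.num_agents, self.num_states),
                minval=-1.0,
                maxval=1.0)

A problem like this would then be handed to training the same way main.py does
for TurretProblem, e.g. state = train.train(FLAGS.workdir, ExampleLinearProblem()).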
Change-Id: I6756d4ba5af411eba99825447611ead9de3c5dad
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
index 30f5630..f8e9137 100644
--- a/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
+++ b/frc971/control_loops/swerve/velocity_controller/lqr_plot.py
@@ -94,13 +94,13 @@
grid_X = jax.numpy.array(grid_X)
grid_Y = jax.numpy.array(grid_Y)
# Load the training state.
- physics_constants = jax_dynamics.Coefficients()
+ problem = physics.TurretProblem()
rng = jax.random.key(0)
rng, init_rng = jax.random.split(rng)
- state = create_train_state(init_rng, physics_constants,
- FLAGS.q_learning_rate, FLAGS.pi_learning_rate)
+ state = create_train_state(init_rng, problem, FLAGS.q_learning_rate,
+ FLAGS.pi_learning_rate)
state = restore_checkpoint(state, FLAGS.workdir)
if step is not None and state.step == step:
@@ -121,15 +121,15 @@
def compute_q1(X, Y):
return state.q1_apply(
state.params,
- unwrap_angles(jax.numpy.array([X, Y])),
- jax.numpy.array([0.]),
+ observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+ action=jax.numpy.array([0.]),
)[0]
def compute_q2(X, Y):
return state.q2_apply(
state.params,
- unwrap_angles(jax.numpy.array([X, Y])),
- jax.numpy.array([0.]),
+ observation=state.problem.unwrap_angles(jax.numpy.array([X, Y])),
+ action=jax.numpy.array([0.]),
)[0]
def lqr_cost(X, Y):
@@ -137,18 +137,18 @@
return -x.T @ jax.numpy.array(P) @ x
def compute_q(params, x, y):
- X = unwrap_angles(jax.numpy.array([x, y]))
+ X = state.problem.unwrap_angles(jax.numpy.array([x, y]))
return jax.numpy.minimum(
state.q1_apply(
params,
- X,
- jax.numpy.array([0.]),
+ observation=X,
+ action=jax.numpy.array([0.]),
)[0],
state.q2_apply(
params,
- X,
- jax.numpy.array([0.]),
+ observation=X,
+ action=jax.numpy.array([0.]),
)[0])
cost_grid1 = jax.vmap(jax.vmap(compute_q1))(grid_X, grid_Y)
@@ -173,7 +173,7 @@
x = jax.numpy.array([X, Y])
U, _, _, _ = state.pi_apply(rng,
state.params,
- observation=unwrap_angles(x),
+ observation=state.problem.unwrap_angles(x),
deterministic=True)
return U[0]
@@ -189,12 +189,17 @@
U, _, _, _ = state.pi_apply(rng,
params,
- observation=unwrap_angles(X),
+ observation=state.problem.unwrap_angles(X),
deterministic=True)
U_lqr = F @ (goal - X_lqr)
- cost = jax.numpy.minimum(state.q1_apply(params, unwrap_angles(X), U),
- state.q2_apply(params, unwrap_angles(X), U))
+ cost = jax.numpy.minimum(
+ state.q1_apply(params,
+ observation=state.problem.unwrap_angles(X),
+ action=U),
+ state.q2_apply(params,
+ observation=state.problem.unwrap_angles(X),
+ action=U))
U_plot = data.U.at[i, :].set(U)
U_lqr_plot = data.U_lqr.at[i, :].set(U_lqr)
@@ -206,8 +211,8 @@
X = A @ X + B @ U
X_lqr = A @ X_lqr + B @ U_lqr
- reward = data.reward - physics.state_cost(X, U, goal)
- reward_lqr = data.reward_lqr - physics.state_cost(X_lqr, U_lqr, goal)
+ reward = data.reward - state.problem.cost(X, U, goal)
+ reward_lqr = data.reward_lqr - state.problem.cost(X_lqr, U_lqr, goal)
return X, X_lqr, data._replace(
t=t,
@@ -230,10 +235,10 @@
X, X_lqr, data, params = integrate(
Data(
t=jax.numpy.zeros((FLAGS.horizon, )),
- X=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
- X_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_STATES)),
- U=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
- U_lqr=jax.numpy.zeros((FLAGS.horizon, physics.NUM_OUTPUTS)),
+ X=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+ X_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_states)),
+ U=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
+ U_lqr=jax.numpy.zeros((FLAGS.horizon, state.problem.num_outputs)),
cost=jax.numpy.zeros((FLAGS.horizon, 1)),
cost_lqr=jax.numpy.zeros((FLAGS.horizon, 1)),
q1_grid=cost_grid1,
diff --git a/frc971/control_loops/swerve/velocity_controller/main.py b/frc971/control_loops/swerve/velocity_controller/main.py
index 7ef687c..a0d248d 100644
--- a/frc971/control_loops/swerve/velocity_controller/main.py
+++ b/frc971/control_loops/swerve/velocity_controller/main.py
@@ -24,6 +24,7 @@
import jax
import tensorflow as tf
from frc971.control_loops.swerve import jax_dynamics
+from frc971.control_loops.swerve.velocity_controller import physics
jax._src.deprecations.accelerate('tracer-hash')
# Enable the compilation cache
@@ -63,8 +64,8 @@
FLAGS.workdir,
)
- physics_constants = jax_dynamics.Coefficients()
- state = train.train(FLAGS.workdir, physics_constants)
+ problem = physics.TurretProblem()
+ state = train.train(FLAGS.workdir, problem)
if __name__ == '__main__':
diff --git a/frc971/control_loops/swerve/velocity_controller/model.py b/frc971/control_loops/swerve/velocity_controller/model.py
index ae67efe..cb6d9ba 100644
--- a/frc971/control_loops/swerve/velocity_controller/model.py
+++ b/frc971/control_loops/swerve/velocity_controller/model.py
@@ -36,8 +36,8 @@
absl.flags.DEFINE_float(
'final_q_learning_rate',
- default=0.00002,
- help='Training learning rate.',
+ default=0.001,
+    help='Fraction of --q_learning_rate to reduce to by the end.',
)
absl.flags.DEFINE_float(
@@ -48,8 +48,8 @@
absl.flags.DEFINE_float(
'final_pi_learning_rate',
- default=0.00002,
- help='Training learning rate.',
+ default=0.001,
+    help='Fraction of --pi_learning_rate to reduce to by the end.',
)
absl.flags.DEFINE_integer(
@@ -92,14 +92,14 @@
activation: Callable = nn.activation.relu
# Max action we can apply
- action_limit: float = 30.0
+    action_limit: float = 1.0
@nn.compact
def __call__(self,
- observations: ArrayLike,
+ observation: ArrayLike,
deterministic: bool = False,
rng: PRNGKey | None = None):
- x = observations
+ x = observation
# Apply the dense layers
for i, hidden_size in enumerate(self.hidden_sizes):
x = nn.Dense(
@@ -168,9 +168,9 @@
activation: Callable = nn.activation.tanh
@nn.compact
- def __call__(self, observations, actions):
+ def __call__(self, observation: ArrayLike, action: ArrayLike):
# Estimate Q with a simple multi layer dense network.
- x = jax.numpy.hstack((observations, actions))
+ x = jax.numpy.hstack((observation, action))
for i, hidden_size in enumerate(self.hidden_sizes):
x = nn.Dense(
name=f'denselayer{i}',
@@ -185,8 +185,7 @@
class TrainState(flax.struct.PyTreeNode):
- physics_constants: jax_dynamics.CoefficientsType = flax.struct.field(
- pytree_node=False)
+ problem: Problem = flax.struct.field(pytree_node=False)
step: int | jax.Array = flax.struct.field(pytree_node=True)
substep: int | jax.Array = flax.struct.field(pytree_node=True)
@@ -221,23 +220,30 @@
def pi_apply(self,
rng: PRNGKey,
- params,
- observation,
+ params: flax.core.FrozenDict[str, typing.Any],
+ observation: ArrayLike,
deterministic: bool = False):
- return self.pi_apply_fn({'params': params['pi']},
- physics.unwrap_angles(observation),
- deterministic=deterministic,
- rngs={'pi': rng})
+ return self.pi_apply_fn(
+ {'params': params['pi']},
+ observation=self.problem.unwrap_angles(observation),
+ deterministic=deterministic,
+ rngs={'pi': rng})
- def q1_apply(self, params, observation, action):
- return self.q1_apply_fn({'params': params['q1']},
- physics.unwrap_angles(observation), action)
+ def q1_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+ observation: ArrayLike, action: ArrayLike):
+ return self.q1_apply_fn(
+ {'params': params['q1']},
+ observation=self.problem.unwrap_angles(observation),
+ action=action)
- def q2_apply(self, params, observation, action):
- return self.q2_apply_fn({'params': params['q2']},
- physics.unwrap_angles(observation), action)
+ def q2_apply(self, params: flax.core.FrozenDict[str, typing.Any],
+ observation: ArrayLike, action: ArrayLike):
+ return self.q2_apply_fn(
+ {'params': params['q2']},
+ observation=self.problem.unwrap_angles(observation),
+ action=action)
- def pi_apply_gradients(self, step, grads):
+ def pi_apply_gradients(self, step: int, grads):
updates, new_pi_opt_state = self.pi_tx.update(grads, self.pi_opt_state,
self.params)
new_params = optax.apply_updates(self.params, updates)
@@ -249,7 +255,7 @@
pi_opt_state=new_pi_opt_state,
)
- def q_apply_gradients(self, step, grads):
+ def q_apply_gradients(self, step: int, grads):
updates, new_q_opt_state = self.q_tx.update(grads, self.q_opt_state,
self.params)
new_params = optax.apply_updates(self.params, updates)
@@ -291,9 +297,8 @@
)
@classmethod
- def create(cls, *, physics_constants: jax_dynamics.CoefficientsType,
- params, pi_tx, q_tx, alpha_tx, pi_apply_fn, q1_apply_fn,
- q2_apply_fn, **kwargs):
+ def create(cls, *, problem: Problem, params, pi_tx, q_tx, alpha_tx,
+ pi_apply_fn, q1_apply_fn, q2_apply_fn, **kwargs):
"""Creates a new instance with ``step=0`` and initialized ``opt_state``."""
pi_tx = optax.multi_transform(
@@ -343,7 +348,7 @@
length=FLAGS.replay_size)
result = cls(
- physics_constants=physics_constants,
+ problem=problem,
step=0,
substep=0,
params=params,
@@ -357,7 +362,7 @@
q_opt_state=q_opt_state,
alpha_tx=alpha_tx,
alpha_opt_state=alpha_opt_state,
- target_entropy=-physics.NUM_STATES,
+ target_entropy=-problem.num_states,
mesh=mesh,
sharding=sharding,
replicated_sharding=replicated_sharding,
@@ -367,10 +372,12 @@
return jax.device_put(result, replicated_sharding)
-def create_train_state(rng, physics_constants: jax_dynamics.CoefficientsType,
- q_learning_rate, pi_learning_rate):
+def create_train_state(rng: PRNGKey, problem: Problem, q_learning_rate,
+ pi_learning_rate):
"""Creates initial `TrainState`."""
- pi = SquashedGaussianMLPActor(action_space=physics.NUM_OUTPUTS)
+ pi = SquashedGaussianMLPActor(activation=nn.activation.gelu,
+ action_space=problem.num_outputs,
+ action_limit=problem.action_limit)
    # We want q1 and q2 to have different network architectures so they pick up different things.
q1 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[128, 256])
q2 = MLPQFunction(activation=nn.activation.gelu, hidden_sizes=[256, 128])
@@ -381,17 +388,17 @@
pi_params = pi.init(
pi_rng,
- jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
+ observation=jax.numpy.ones([problem.num_unwrapped_states]),
)['params']
q1_params = q1.init(
q1_rng,
- jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
- jax.numpy.ones([physics.NUM_OUTPUTS]),
+ observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ action=jax.numpy.ones([problem.num_outputs]),
)['params']
q2_params = q2.init(
q2_rng,
- jax.numpy.ones([physics.NUM_UNWRAPPED_STATES]),
- jax.numpy.ones([physics.NUM_OUTPUTS]),
+ observation=jax.numpy.ones([problem.num_unwrapped_states]),
+ action=jax.numpy.ones([problem.num_outputs]),
)['params']
if FLAGS.alpha < 0.0:
@@ -411,7 +418,7 @@
alpha_tx = optax.sgd(learning_rate=q_learning_rate)
result = TrainState.create(
- physics_constants=physics_constants,
+ problem=problem,
params=init_params(rng),
pi_tx=pi_tx,
q_tx=q_tx,
diff --git a/frc971/control_loops/swerve/velocity_controller/physics.py b/frc971/control_loops/swerve/velocity_controller/physics.py
index 9454f14..d7bf7b4 100644
--- a/frc971/control_loops/swerve/velocity_controller/physics.py
+++ b/frc971/control_loops/swerve/velocity_controller/physics.py
@@ -3,67 +3,103 @@
from frc971.control_loops.swerve import dynamics
from absl import logging
from frc971.control_loops.swerve import jax_dynamics
+from flax.typing import PRNGKey
-@partial(jax.jit, static_argnums=[0])
-def xdot_physics(physics_constants: jax_dynamics.CoefficientsType, X, U):
- A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
- B = jax.numpy.array([[0.], [56.08534375]])
+class Problem(object):
- return A @ X + B @ U
+ def __init__(self, num_states: int, num_unwrapped_states: int,
+ num_outputs: int, num_goals: int, action_limit: float):
+ self.num_states = num_states
+ self.num_unwrapped_states = num_unwrapped_states
+ self.num_outputs = num_outputs
+ self.num_goals = num_goals
+ self.action_limit = action_limit
+ def integrate_dynamics(self, X: jax.typing.ArrayLike,
+ U: jax.typing.ArrayLike):
+ m = 2 # RK4 steps per interval
+ dt = 0.005 / m
-def unwrap_angles(X):
- return X
+ def iteration(i, X):
+ weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0],
+ [1.0, 2.0, 2.0, 1.0]])
+ def rk_iteration(i, val):
+ kx_previous, weighted_sum = val
+ kx = self.xdot(X + dt * weights[0, i] * kx_previous, U)
+ weighted_sum += dt * weights[1, i] * kx / 6.0
+ return (kx, weighted_sum)
-@jax.jit
-def swerve_cost(X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
- goal: jax.typing.ArrayLike):
-
- Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
- R = jax.numpy.array([[0.00694444]])
-
- return (X).T @ Q @ (X) + U.T @ R @ U
-
-
-@partial(jax.jit, static_argnums=[0])
-def integrate_dynamics(physics_constants: jax_dynamics.CoefficientsType, X, U):
- m = 2 # RK4 steps per interval
- dt = 0.005 / m
-
- def iteration(i, X):
- weights = jax.numpy.array([[0.0, 0.5, 0.5, 1.0], [1.0, 2.0, 2.0, 1.0]])
-
- def rk_iteration(i, val):
- kx_previous, weighted_sum = val
- kx = xdot_physics(physics_constants,
- X + dt * weights[0, i] * kx_previous, U)
- weighted_sum += dt * weights[1, i] * kx / 6.0
- return (kx, weighted_sum)
+ return jax.lax.fori_loop(lower=0,
+ upper=4,
+ body_fun=rk_iteration,
+ init_val=(X, X))[1]
return jax.lax.fori_loop(lower=0,
- upper=4,
- body_fun=rk_iteration,
- init_val=(X, X))[1]
+ upper=m,
+ body_fun=iteration,
+ init_val=X)
- return jax.lax.fori_loop(lower=0, upper=m, body_fun=iteration, init_val=X)
+ def unwrap_angles(self, X: jax.typing.ArrayLike):
+ return X
+
+ @partial(jax.jit, static_argnums=[0])
+ def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+        raise NotImplementedError("xdot not implemented")
+
+ @partial(jax.jit, static_argnums=[0])
+ def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike):
+        raise NotImplementedError("cost not implemented")
+
+ @partial(jax.jit, static_argnums=[0])
+ def random_states(self, rng: PRNGKey, dimensions=None):
+        raise NotImplementedError("random_states not implemented")
+
+ #@partial(jax.jit, static_argnums=[0])
+ def random_actions(self, rng: PRNGKey, dimensions=None):
+ """Produces a uniformly random action in the action space."""
+ return jax.random.uniform(
+ rng,
+ (dimensions or FLAGS.num_agents, self.num_outputs),
+ minval=-self.action_limit,
+ maxval=self.action_limit,
+ )
-state_cost = swerve_cost
-NUM_STATES = 2
-NUM_UNWRAPPED_STATES = 2
-NUM_OUTPUTS = 1
-ACTION_LIMIT = 10.0
+class TurretProblem(Problem):
+ def __init__(self):
+ super().__init__(num_states=2,
+ num_unwrapped_states=2,
+ num_outputs=1,
+ num_goals=2,
+ action_limit=30.0)
-def random_states(rng, dimensions=None):
- rng1, rng2 = jax.random.split(rng)
+ @partial(jax.jit, static_argnums=[0])
+ def xdot(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike):
+ A = jax.numpy.array([[0., 1.], [0., -36.85154548]])
+ B = jax.numpy.array([[0.], [56.08534375]])
- return jax.numpy.hstack(
- (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
- minval=-1.0,
- maxval=1.0),
- jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
- minval=-10.0,
- maxval=10.0)))
+ return A @ X + B @ U
+
+ @partial(jax.jit, static_argnums=[0])
+ def cost(self, X: jax.typing.ArrayLike, U: jax.typing.ArrayLike,
+ goal: jax.typing.ArrayLike):
+ Q = jax.numpy.array([[2.77777778, 0.], [0., 0.01]])
+ R = jax.numpy.array([[0.00694444]])
+
+ return X.T @ Q @ X + U.T @ R @ U
+
+ #@partial(jax.jit, static_argnums=[0])
+ def random_states(self, rng: PRNGKey, dimensions=None):
+ rng1, rng2 = jax.random.split(rng)
+
+ return jax.numpy.hstack(
+ (jax.random.uniform(rng1, (dimensions or FLAGS.num_agents, 1),
+ minval=-1.0,
+ maxval=1.0),
+ jax.random.uniform(rng2, (dimensions or FLAGS.num_agents, 1),
+ minval=-10.0,
+ maxval=10.0)))
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 5002eee..bbe2f5d 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -78,16 +78,6 @@
)
-def random_actions(rng, dimensions=None):
- """Produces a uniformly random action in the action space."""
- return jax.random.uniform(
- rng,
- (dimensions or FLAGS.num_agents, physics.NUM_OUTPUTS),
- minval=-physics.ACTION_LIMIT,
- maxval=physics.ACTION_LIMIT,
- )
-
-
def save_checkpoint(state: TrainState, workdir: str):
"""Saves a checkpoint in the workdir."""
# TODO(austin): use orbax directly.
@@ -138,8 +128,12 @@
observation=observations2)
# Compute target network Q values
- q1_pi_target = state.q1_apply(state.target_params, observations2, action2)
- q2_pi_target = state.q2_apply(state.target_params, observations2, action2)
+ q1_pi_target = state.q1_apply(state.target_params,
+ observation=observations2,
+ action=action2)
+ q2_pi_target = state.q2_apply(state.target_params,
+ observation=observations2,
+ action=action2)
q_pi_target = jax.numpy.minimum(q1_pi_target, q2_pi_target)
alpha = jax.numpy.exp(params['logalpha'])
@@ -149,8 +143,8 @@
(q_pi_target - alpha * logp_pi2))
# Compute the starting Q values from the Q network being optimized.
- q1 = state.q1_apply(params, observations1, actions)
- q2 = state.q2_apply(params, observations1, actions)
+ q1 = state.q1_apply(params, observation=observations1, action=actions)
+ q2 = state.q2_apply(params, observation=observations1, action=actions)
# Mean squared error loss against Bellman backup
q1_loss = ((q1 - bellman_backup)**2).mean()
@@ -181,8 +175,12 @@
pi, logp_pi, _, _ = state.pi_apply(rng=rng,
params=params,
observation=observations1)
- q1_pi = state.q1_apply(jax.lax.stop_gradient(params), observations1, pi)
- q2_pi = state.q2_apply(jax.lax.stop_gradient(params), observations1, pi)
+ q1_pi = state.q1_apply(jax.lax.stop_gradient(params),
+ observation=observations1,
+ action=pi)
+ q2_pi = state.q2_apply(jax.lax.stop_gradient(params),
+ observation=observations1,
+ action=pi)
# And compute the Q of that action.
q_pi = jax.numpy.minimum(q1_pi, q2_pi)
@@ -277,16 +275,24 @@
@jax.jit
-def collect_experience(state: TrainState, replay_buffer_state, R: ArrayLike,
+def collect_experience(state: TrainState, replay_buffer_state,
step_rng: PRNGKey, step):
"""Collects experience by simulating."""
pi_rng = jax.random.fold_in(step_rng, step)
pi_rng, initialization_rng = jax.random.split(pi_rng)
+ pi_rng, goal_rng = jax.random.split(pi_rng)
observation = jax.lax.with_sharding_constraint(
- physics.random_states(initialization_rng, FLAGS.num_agents),
+ state.problem.random_states(initialization_rng, FLAGS.num_agents),
state.sharding)
+ R = jax.numpy.hstack((
+ jax.random.uniform(goal_rng, (FLAGS.num_agents, 1),
+ minval=-0.1,
+ maxval=0.1),
+ jax.numpy.zeros((FLAGS.num_agents, 1)),
+ ))
+
def loop(i, val):
"""Runs 1 step of our simulation."""
observation, pi_rng, replay_buffer_state = val
@@ -295,7 +301,7 @@
def true_fn(i):
# We are at the beginning of the process, pick a random action.
- return random_actions(action_rng, FLAGS.num_agents)
+ return state.problem.random_actions(action_rng, FLAGS.num_agents)
def false_fn(i):
# We are past the beginning of the process, use the trained network.
@@ -314,26 +320,22 @@
)
# Compute the destination state.
- observation2 = jax.vmap(lambda o, pi: physics.integrate_dynamics(
- state.physics_constants, o, pi),
- in_axes=(0, 0))(observation, pi_action)
+ observation2 = jax.vmap(
+ lambda o, pi: state.problem.integrate_dynamics(o, pi),
+ in_axes=(0, 0))(observation, pi_action)
# Soft Actor-Critic is designed to maximize reward. LQR minimizes
# cost. There is nothing which assumes anything about the sign of
# the reward, so use the negative of the cost.
- reward = (-jax.vmap(state_cost)(observation2, pi_action, R)).reshape(
- (FLAGS.num_agents, 1))
+ reward = -jax.vmap(state.problem.cost)(
+ X=observation2, U=pi_action, goal=R)
- logging.info('Observation shape: %s', observation.shape)
- logging.info('Observation2 shape: %s', observation2.shape)
- logging.info('Action shape: %s', pi_action.shape)
- logging.info('Action shape: %s', pi_action.shape)
replay_buffer_state = state.replay_buffer.add(
replay_buffer_state, {
'observations1': observation,
'observations2': observation2,
'actions': pi_action,
- 'rewards': reward,
+ 'rewards': reward.reshape((FLAGS.num_agents, 1)),
})
return observation2, pi_rng, replay_buffer_state
@@ -379,9 +381,7 @@
return rng, state, q_loss, pi_loss, alpha_loss
-def train(
- workdir: str, physics_constants: jax_dynamics.CoefficientsType
-) -> train_state.TrainState:
+def train(workdir: str, problem: Problem) -> train_state.TrainState:
"""Trains a Soft Actor-Critic controller."""
rng = jax.random.key(0)
rng, r_rng = jax.random.split(rng)
@@ -394,6 +394,14 @@
'batch_size': FLAGS.batch_size,
'horizon': FLAGS.horizon,
'warmup_steps': FLAGS.warmup_steps,
+ 'steps': FLAGS.steps,
+ 'replay_size': FLAGS.replay_size,
+ 'num_agents': FLAGS.num_agents,
+ 'polyak': FLAGS.polyak,
+ 'gamma': FLAGS.gamma,
+ 'alpha': FLAGS.alpha,
+ 'final_q_learning_rate': FLAGS.final_q_learning_rate,
+ 'final_pi_learning_rate': FLAGS.final_pi_learning_rate,
}
# Setup TrainState
@@ -406,7 +414,7 @@
final_learning_rate=FLAGS.final_pi_learning_rate)
state = create_train_state(
init_rng,
- physics_constants,
+ problem,
q_learning_rate=q_learning_rate,
pi_learning_rate=pi_learning_rate,
)
@@ -415,21 +423,13 @@
state_sharding = nn.get_sharding(state, state.mesh)
logging.info(state_sharding)
- # TODO(austin): Let the goal change.
- R = jax.numpy.hstack((
- jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=0.9, maxval=1.1),
- jax.random.uniform(rng, (FLAGS.num_agents, 1), minval=-0.1,
- maxval=0.1),
- jax.numpy.zeros((FLAGS.num_agents, 1)),
- ))
-
replay_buffer_state = state.replay_buffer.init({
'observations1':
- jax.numpy.zeros((physics.NUM_STATES, )),
+ jax.numpy.zeros((problem.num_states, )),
'observations2':
- jax.numpy.zeros((physics.NUM_STATES, )),
+ jax.numpy.zeros((problem.num_states, )),
'actions':
- jax.numpy.zeros((physics.NUM_OUTPUTS, )),
+ jax.numpy.zeros((problem.num_outputs, )),
'rewards':
jax.numpy.zeros((1, )),
})
@@ -438,19 +438,6 @@
state.mesh)
logging.info(replay_buffer_state_sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['observations1'].shape,
- replay_buffer_state.experience['observations1'].sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['observations2'].shape,
- replay_buffer_state.experience['observations2'].sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['actions'].shape,
- replay_buffer_state.experience['actions'].sharding)
- logging.info('Shape: %s, Sharding %s',
- replay_buffer_state.experience['rewards'].shape,
- replay_buffer_state.experience['rewards'].sharding)
-
# Number of gradients to accumulate before doing decent.
update_after = FLAGS.batch_size // FLAGS.num_agents
@@ -462,7 +449,6 @@
state, replay_buffer_state = collect_experience(
state,
replay_buffer_state,
- R,
step_rng,
step,
)