Quiet down printing while training
All the debugging makes it hard to see what is happening.
Change-Id: I67bba4e86f36a2dab6b780d9a85c652df141619f
Signed-off-by: Austin Schuh <austin.linux@gmail.com>
diff --git a/frc971/control_loops/swerve/velocity_controller/train.py b/frc971/control_loops/swerve/velocity_controller/train.py
index 618ab28..721f18c 100644
--- a/frc971/control_loops/swerve/velocity_controller/train.py
+++ b/frc971/control_loops/swerve/velocity_controller/train.py
@@ -424,8 +424,7 @@
)
state = restore_checkpoint(state, workdir)
- state_sharding = nn.get_sharding(state, state.mesh)
- logging.info(state_sharding)
+ logging.debug(nn.get_sharding(state, state.mesh))
replay_buffer_state = state.replay_buffer.init({
'observations1':
@@ -440,9 +439,7 @@
jax.numpy.zeros((problem.num_states, )),
})
- replay_buffer_state_sharding = nn.get_sharding(replay_buffer_state,
- state.mesh)
- logging.info(replay_buffer_state_sharding)
+ logging.debug(nn.get_sharding(replay_buffer_state, state.mesh))
# Number of gradients to accumulate before doing decent.
update_after = FLAGS.batch_size // FLAGS.num_agents