AdversarialTrainer: Silent incompatibility with SB3 learning rate schedules
h/t @qxcv
`AdversarialTrainer.train()` will repeatedly call `PPO.learn(total_timesteps=gen_batch_size, reset_num_timesteps=False)`, where `gen_batch_size` is usually small compared to conventional RL training. Regardless of the value of `reset_num_timesteps`, `PPO` doesn't know the actual total number of training timesteps to expect, so for our use case it passes the wrong "% training progress remaining" number into the learning rate schedule function.
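For reference, an SB3 learning-rate schedule is a callable of the "progress remaining" fraction, which SB3 intends to go from 1.0 at the start of training down to 0.0 at the end; the environment and values below are just illustrative:

```python
from stable_baselines3 import PPO

ppo = PPO(
    "MlpPolicy",
    "CartPole-v1",  # illustrative environment
    # linear decay from 3e-4 to 0 across the whole run (illustrative value)
    learning_rate=lambda progress_remaining: 3e-4 * progress_remaining,
)
```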
When `reset_num_timesteps=False`, we hit a particular failure mode: the "% training progress remaining" variable, `self._current_progress_remaining`, is initialized to `1 - k / (k + 1)` at the start of `AdversarialTrainer`'s `k`-th call to `PPO.learn(total_timesteps=gen_batch_size, reset_num_timesteps=False)`, so `self._current_progress_remaining` rapidly approaches `0` even if there is actually a lot of training left to do. PPO's clip-range scheduling might also be affected if it depends on `self._current_progress_remaining`.
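A standalone sketch of the counter arithmetic (it mirrors the `_setup_learn` / `_update_current_progress_remaining` logic excerpted below; it does not call SB3 itself):

```python
gen_batch_size = 2048
num_timesteps = 0  # persists across calls because reset_num_timesteps=False
for k in range(5):
    # _setup_learn: total_timesteps += self.num_timesteps
    total_timesteps = gen_batch_size + num_timesteps
    progress_remaining = 1.0 - num_timesteps / total_timesteps  # = 1 - k / (k + 1)
    print(f"call {k}: progress_remaining starts at {progress_remaining:.3f}")
    num_timesteps += gen_batch_size  # learn() runs until num_timesteps reaches total_timesteps
# -> 1.000, 0.500, 0.333, 0.250, 0.200: a linear schedule collapses after a handful of calls.
```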
We should at the very least document and warn that schedules based on `self._current_progress_remaining` behave oddly in this setting. (Not sure how easy it is to detect non-constant learning rates, etc.) Maybe we could update SB3 so that `_update_current_progress_remaining` uses the correct denominator directly. I don't anticipate having the bandwidth to do something like this soon, though.
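In the meantime, a duct-tape workaround on the caller's side might look roughly like this. This is purely illustrative: `patch_progress_for_true_total` is an invented helper, not imitation or SB3 API; it overrides the private method excerpted below on a single algorithm instance.

```python
import types

def patch_progress_for_true_total(algo, true_total_timesteps: int) -> None:
    """Illustrative duct-tape fix: make the progress fraction use the true overall
    run length as its denominator, ignoring the small per-call total that
    AdversarialTrainer passes into learn()."""
    def _update(self, num_timesteps: int, total_timesteps: int) -> None:
        # Mirror of the SB3 method excerpted below, with the denominator replaced.
        self._current_progress_remaining = 1.0 - float(num_timesteps) / float(true_total_timesteps)

    algo._update_current_progress_remaining = types.MethodType(_update, algo)
```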
```python
def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
    self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)


def _setup_learn(
    self,
    total_timesteps: int,
    …
) -> Tuple[int, BaseCallback]:
    …
    if reset_num_timesteps:
        self.num_timesteps = 0
        self._episode_num = 0
    else:
        # Make sure training timesteps are ahead of the internal counter
        total_timesteps += self.num_timesteps
    self._total_timesteps = total_timesteps
    …
    return total_timesteps, callback


def learn(
    self,
    total_timesteps: int,
    …,
) -> "OnPolicyAlgorithm":
    total_timesteps, callback = self._setup_learn(
        total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
    )
    while self.num_timesteps < total_timesteps:
        self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
```
Well spotted @qxcv!

Could we fix this by calling `_setup_learn` with the true total timesteps before calling `learn()`? It's a private method, so that's not great, but it might be OK to change the API to expose it.
If it's necessary to change the SB3 API, it's probably cleaner to either change `learn()` so that it properly supports this use case (of only training for a few steps at a time), or to add another method that explicitly supports training in small increments.
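To make the second option concrete, here is a rough sketch of the bookkeeping such a method would need; the class and method names are invented for illustration and are not part of SB3:

```python
class IncrementalRunTracker:
    """Sketch: the overall run length is declared once, so the progress fraction
    always has a fixed, correct denominator regardless of increment size."""

    def __init__(self, total_timesteps: int):
        self.total_timesteps = total_timesteps  # true length of the whole run
        self.num_timesteps = 0

    @property
    def progress_remaining(self) -> float:
        return 1.0 - self.num_timesteps / self.total_timesteps

    def advance(self, n_timesteps: int) -> None:
        # stand-in for "collect rollouts / run gradient updates for n_timesteps"
        self.num_timesteps += n_timesteps


tracker = IncrementalRunTracker(total_timesteps=1_000_000)
tracker.advance(2048)
print(tracker.progress_remaining)  # ~0.998, instead of collapsing towards 0 per call
```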
There's actually a second bug here where the `fps` number in the logs increases monotonically, because `_start_time` always gets reset even when `reset_num_timesteps=False`. To me that feels like additional evidence that the current API is designed under incorrect assumptions about what users will do with it (or at least what `imitation` will do with it), and would benefit from being refactored slightly.
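For illustration, assuming the logged value is computed roughly as `num_timesteps / (now - _start_time)`, with `_start_time` reset on every `learn()` call while `num_timesteps` keeps accumulating:

```python
gen_batch_size = 2048
seconds_per_call = 10.0  # assumed wall-clock time of one small learn() call
for k in range(1, 6):
    num_timesteps = k * gen_batch_size  # cumulative counter (reset_num_timesteps=False)
    elapsed = seconds_per_call          # timer only covers the latest call
    print(f"call {k}: fps ≈ {num_timesteps / elapsed:.0f}")
# -> 205, 410, 614, 819, 1024: fps grows with k even though true throughput is constant.
```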
The Keras `Model` API seems to be trying to do something similar to SB3's algorithm classes, and it might be helpful to adopt a distinction like the one they make between `fit()` and `train_on_batch()`. Note that in Keras models, the learning rate is tracked by an optimiser passed to the constructor, which the calling code can still manipulate, rather than having `learn()` try to do all of that work inside a single method call.
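A minimal example of that Keras pattern (layer sizes, schedule values, and data are illustrative):

```python
import numpy as np
import tensorflow as tf

# The learning-rate schedule lives in the optimizer and is driven by the
# optimizer's own step counter, not by a per-call total passed to training calls.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=10_000, decay_rate=0.9
)
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=schedule), loss="mse")

x = np.random.randn(32, 4).astype("float32")
y = np.random.randn(32, 1).astype("float32")
for _ in range(100):
    model.train_on_batch(x, y)  # incremental training; the schedule stays consistent
```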
(also, FYI: my work is not blocked on this because I added duct-tape fixes for both of these issues)