
AdversarialTrainer: Silent incompatibility with SB3 learning rate schedules


h/t @qxcv

AdversarialTrainer.train() repeatedly calls PPO.learn(total_timesteps=gen_batch_size, reset_num_timesteps=False), where gen_batch_size is usually small compared to the total number of timesteps in conventional RL training.

Regardless of the value of reset_num_timesteps, PPO never learns the actual total number of training timesteps to expect, so in our use case it passes the wrong "% training progress remaining" value into the learning rate schedule function.

When reset_num_timesteps=False, we hit a particular failure mode: at the start of AdversarialTrainer's kth call to PPO.learn(total_timesteps=gen_batch_size, reset_num_timesteps=False), the "% training progress remaining" variable, self._current_progress_remaining, is initialized to 1 - k / (k+1) = 1 / (k+1), so it rapidly approaches 0 even if there is actually a lot of training left to be done.
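
To make the decay concrete, here's the arithmetic as a standalone sketch (gen_batch_size is arbitrary; the progress value doesn't actually depend on it):

    gen_batch_size = 2048

    for k in range(5):
        num_timesteps = k * gen_batch_size                # accumulated before the kth call
        total_timesteps = num_timesteps + gen_batch_size  # what _setup_learn computes
        progress_remaining = 1.0 - num_timesteps / total_timesteps
        print(f"call {k}: progress_remaining = {progress_remaining:.3f}")

    # call 0: 1.000, call 1: 0.500, call 2: 0.333, call 3: 0.250, call 4: 0.200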

PPO's clip-range scheduling might also be affected by this if it depends on self._current_progress_remaining.

At the very least, we should document this and warn that schedules depending on self._current_progress_remaining behave incorrectly in this setting. (Not sure how easy it is to detect non-constant learning rates, etc.)
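
One rough heuristic for detection would be to probe the schedule at a few progress values and warn if the outputs differ. A minimal sketch, assuming the schedule is exposed as a callable from progress remaining to learning rate (as SB3's lr_schedule is):

    import warnings

    def warn_if_nonconstant_schedule(schedule, n_points: int = 5) -> None:
        """Warn if `schedule` (progress_remaining -> lr) isn't constant."""
        samples = [schedule(1.0 - i / (n_points - 1)) for i in range(n_points)]
        if len(set(samples)) > 1:
            warnings.warn(
                "Non-constant learning rate schedule detected; incremental "
                ".learn() calls will feed it incorrect progress values.",
                RuntimeWarning,
            )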

Maybe we could update SB3 to use the correct denominator in _update_current_progress_remaining directly. I don't anticipate having the bandwidth to do something like this soon, though. For reference, the relevant SB3 code (abridged):

    def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
        self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)

    def _setup_learn(
        self,
        total_timesteps: int,
        …
    ) -> Tuple[int, BaseCallback]:
        …
        if reset_num_timesteps:
            self.num_timesteps = 0
            self._episode_num = 0
        else:
            # Make sure training timesteps are ahead of the internal counter
            total_timesteps += self.num_timesteps  # <-- only ever one gen_batch_size ahead in our case
        self._total_timesteps = total_timesteps
        …
        return total_timesteps, callback

    def learn(
        self,
        total_timesteps: int,
        …,
    ) -> "OnPolicyAlgorithm":
        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )
        while self.num_timesteps < total_timesteps:
            # Progress is computed against the per-call total, not the true budget.
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
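
One possible duct-tape shape for the fix, sketched as a subclass (true_total_timesteps is a hypothetical attribute, not SB3 API, that the caller would set once up front):

    from stable_baselines3 import PPO

    class PPOWithTrueTotal(PPO):
        # Hypothetical: the true training budget, set once by AdversarialTrainer.
        true_total_timesteps: int = 0

        def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
            # Ignore the per-call total and use the externally supplied denominator.
            denom = self.true_total_timesteps or total_timesteps
            self._current_progress_remaining = 1.0 - float(num_timesteps) / float(denom)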

shwang commented on Jan 12, 2021

Well spotted @qxcv!

Could we fix this by calling _setup_learn with the true total timesteps before calling learn? It's a private method, so that's not great, but it might be OK to change the API to expose it.
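
Roughly the pattern I mean (a sketch only; for this to actually work, learn() would also need to skip or tolerate its internal _setup_learn call, and the exact _setup_learn signature varies by SB3 version):

    true_total = 1_000_000  # the real budget, known to AdversarialTrainer
    gen_batch_size = 2048

    # Hypothetical usage: seed the progress bookkeeping once with the true total...
    ppo._setup_learn(total_timesteps=true_total, eval_env=None)

    # ...then train in small increments without resetting the counters.
    for _ in range(true_total // gen_batch_size):
        ppo.learn(total_timesteps=gen_batch_size, reset_num_timesteps=False)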

AdamGleave commented on Jan 12, 2021

If it's necessary to change the SB3 API, it's probably cleaner to either change learn() so that it properly supports this use case (training for only a few steps at a time), or to add another method that explicitly supports training in small increments.

There's actually a second bug here: the fps number in the logs increases monotonically because _start_time always gets reset, even when reset_num_timesteps=False. To me that feels like additional evidence that the current API is designed under incorrect assumptions about what users will do with it (or at least about what imitation will do with it), and would benefit from being refactored slightly.
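
To spell out the inflation (a pure-arithmetic sketch; assume each call trains gen_batch_size steps and takes seconds_per_call of wall-clock time):

    gen_batch_size = 2048
    seconds_per_call = 10.0

    num_timesteps = 0
    for call in range(1, 4):
        num_timesteps += gen_batch_size  # cumulative: the counter is never reset...
        elapsed = seconds_per_call       # ...but _start_time is, so elapsed stays per-call
        print(f"call {call}: fps = {num_timesteps / elapsed:.0f}")

    # call 1: fps = 205, call 2: fps = 410, call 3: fps = 614  <- grows without bound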

The Keras Model API seems to be trying to do something similar to SB3's algorithm classes, and it might be helpful to adopt a distinction like the one Keras makes between fit() and train_on_batch(). Note that in Keras models, the learning rate is tracked via an optimiser passed to the constructor, which the calling code can still manipulate, so no single method call (like SB3's learn()) has to do all of that work.
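
For concreteness, the Keras pattern looks something like this (an illustrative sketch, not imitation code; the toy model and linear schedule are made up):

    import numpy as np
    import tensorflow as tf

    # The optimiser owns the learning rate; the caller adjusts it between
    # incremental train_on_batch() calls. Keras has no "total timesteps" notion.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
    model.compile(optimizer=optimizer, loss="mse")

    x, y = np.random.randn(32, 4), np.random.randn(32, 1)
    for step in range(100):
        optimizer.learning_rate.assign(3e-4 * (1.0 - step / 100))  # caller-side schedule
        model.train_on_batch(x, y)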

(Also, FYI: my work is not blocked on this, because I added duct-tape fixes for both of these issues.)

qxcv commented on Jan 12, 2021