How to properly insert sb3 `EvalCallback` for `AdversarialTrainer`?
Seems like there is no easy way to insert EvalCallback into AdversarialTrainer, no?
AdversarialTrainer.train_gen() populates the callback argument of sb3.BaseAlgorithm.learn() with self.gen_callback, which is either None or reward_wrapper.RewardVecEnvWrapper().
As an additional note, the learn_kwargs argument of AdversarialTrainer.train_gen() is always None because AdversarialTrainer.train() ignores that argument.
I think it is really crucial to monitor the progress of the imitation learning algorithm against the true reward function. Using EvalCallback instead of specifying eval_freq and create_eval_env 1) is the recommended way in sb3 and 2) would be a "seamless plug" (we just need to create an eval environment instead of doing fancy wrapping/unwrapping to infer the eval environment from the training environment).
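For reference, a minimal sketch of what building such a callback might look like, assuming stable-baselines3 is installed. The helper name `make_eval_callback` and its defaults are hypothetical; `eval_env` is assumed to be an evaluation environment created separately from the training one, so that it reports the true reward.

```python
def make_eval_callback(eval_env, eval_freq=10_000, n_eval_episodes=5):
    """Build an sb3 EvalCallback that scores the policy on `eval_env`.

    Hypothetical helper for illustration only. `eval_env` should not be
    wrapped with the learned reward, so the logged returns reflect the
    true reward function.
    """
    # Imported lazily so that defining the helper does not require sb3.
    from stable_baselines3.common.callbacks import EvalCallback

    return EvalCallback(
        eval_env,
        eval_freq=eval_freq,              # evaluate every `eval_freq` calls of the callback
        n_eval_episodes=n_eval_episodes,  # episodes averaged per evaluation
        deterministic=True,               # use deterministic actions when evaluating
    )
```

The resulting object is what one would like to hand to the generator's learn() call, which is exactly the part that currently has no entry point.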
At first glance this seems like a real issue -- I think we first wrote this code for SB2 before the callback API was finalized, and then it just got ported across. We do support our own callback method for train() but that's not serving quite the same purpose.
As for learn_kwargs, it seems like we could probably just delete it.
@taufeeque9 do you mind looking into this?
Sure!
Seems like adding a gen_callback argument to the train() method should resolve this. We can use the sb3.common.callbacks.CallbackList class to call both callbacks: reward_wrapper.RewardVecEnvWrapper and gen_callback.
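A rough sketch of the merging step, relying on the fact that sb3's learn() also accepts a plain list of callbacks and wraps it in a CallbackList internally. The helper name `merge_callbacks` is hypothetical and not part of imitation or sb3, and the usage shown in the trailing comment assumes the generator is stored as `self.gen_algo`.

```python
from typing import Any, List, Optional


def merge_callbacks(*callbacks: Optional[Any]) -> Optional[List[Any]]:
    """Collect the callbacks that are set, preserving their order.

    Entries that are None are skipped, and entries that are already lists
    are flattened. Returns None when nothing is left, so the result can be
    passed straight to the `callback` argument of learn().
    """
    merged: List[Any] = []
    for cb in callbacks:
        if cb is None:
            continue
        if isinstance(cb, (list, tuple)):
            merged.extend(cb)  # user passed several callbacks at once
        else:
            merged.append(cb)
    return merged or None


# Intended use inside train_gen() (assumption, not executed here):
#   self.gen_algo.learn(
#       total_timesteps=total_timesteps,
#       callback=merge_callbacks(self.gen_callback, gen_callback),
#   )
```

Keeping the internal logging callback first means it still runs even when the user supplies nothing, which matches the current behaviour.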
Deleting learn_kwargs seems okay to me too.
I'll create a PR with the above changes.
That change sounds good to me. We may want to consider deprecating our own callback feature, although it may still serve some purpose, so we should first check whether it's being used anywhere in the codebase in a way that can't be replaced by gen_callback.
I checked, and the callback feature is not used anywhere. It is specific to the AdversarialTrainer.train() method, and the children (i.e. GAIL and AIRL) do not override train(), hence they do not use it. Can I be the one to make the PR for the changes mentioned by @taufeeque9? It looks trivial to implement.
Deprecating our callback seems okay as it is not being used anywhere in the codebase. Any future use of callback can be handled by gen_callback, though the gen_callback code might become messy if it matters that the callback fires after training the discriminator.
Sure @gunnxx! You can open a PR and add me and @AdamGleave as reviewers.