How to properly insert sb3 `EvalCallback` into `AdversarialTraining`?

Open gunnxx opened this issue 3 years ago • 5 comments

Seems like there is no easy way to insert EvalCallback into AdversarialTraining, no?

AdversarialTraining.train_gen() populates the callback argument of sb3.BaseAlgorithm.learn() with self.gen_callback, which is either None or reward_wrapper.RewardVecEnvWrapper().

As an additional note, the learn_kwargs argument in AdversarialTraining.train_gen() is always None because AdversarialTraining.train() ignores that argument.

I think it is really crucial to monitor the progress of the imitation learning algorithm against the true reward function. Using EvalCallback instead of specifying eval_freq and create_eval_env 1) is the recommended way in sb3 and 2) will be a "seamless plug" (we just need to create an eval environment instead of doing fancy wrapping/unwrapping to infer the eval environment from the training environment).

gunnxx avatar Nov 07 '22 14:11 gunnxx

At first glance this seems like a real issue -- I think we first wrote this code for SB2 before the callback API was finalized, and then it just got ported across. We do support our own callback method for train() but that's not serving quite the same purpose.

It seems like we could probably just delete learn_kwargs.

@taufeeque9 do you mind looking into this?

AdamGleave avatar Nov 08 '22 04:11 AdamGleave

Sure!

Seems like adding a gen_callback argument to the train() method should resolve this. We can use the sb3.common.callbacks.CallbackList class to call both callbacks -- reward_wrapper.RewardVecEnvWrapper and gen_callback.

Deleting learn_kwargs seems okay to me too.

I'll create a PR with the above changes.

taufeeque9 avatar Nov 09 '22 22:11 taufeeque9

That change sounds good to me. We may want to consider deprecating our own callback feature, although it may still serve some purpose, so we should first check whether it's being used anywhere in the codebase in a way that can't be replaced by gen_callback.

AdamGleave avatar Nov 11 '22 23:11 AdamGleave

I checked that the callback feature is not used anywhere. It is specific to the AdversarialTrainer.train() method, and the children (i.e. GAIL and AIRL) do not override train(), hence they do not use it. Can I be the one who makes the PR for the changes mentioned by @taufeeque9? It looks trivial to implement.

gunnxx avatar Nov 12 '22 05:11 gunnxx

Deprecating our callback seems okay as it is not being used anywhere in the codebase. Any future use of callback can be handled by gen_callback. The only catch is that the gen_callback code might become messy if it matters that the callback fires after the discriminator has been trained.

Sure @gunnxx! You can open a PR and add me and @AdamGleave as reviewers.

taufeeque9 avatar Nov 14 '22 04:11 taufeeque9