Confused about how "Trains the generator to maximize the discriminator loss" actually works
Problem
Hi, imitation is a great project!
Currently, I am training the GAIL algorithm with a PPO learner from SB3, and I have a question about the training process in imitation\GAIL\train_gen:
```python
def train(self, total_timesteps):
    ...
    for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
        self.train_gen(self.gen_train_timesteps)  ######## Confused!!!
        for _ in range(self.n_disc_updates_per_round):
            with networks.training(self.reward_train):
                # switch to training mode (affects dropout, normalization)
                self.train_disc()
```
In the source code above, self.train_gen calls the learn function of SB3's on_policy_algorithm, and that learn function updates PPO's actor and critic networks.
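For reference, here is roughly what train_gen does under the hood, as far as I can tell (a simplified sketch, not the exact source; argument names and logging details are approximations):

```python
# Simplified sketch of AdversarialTrainer.train_gen (approximation, NOT the exact code).
def train_gen(self, total_timesteps=None):
    if total_timesteps is None:
        total_timesteps = self.gen_train_timesteps
    # self.gen_algo is the SB3 learner (PPO here): this call collects rollouts in
    # the trainer's venv and runs PPO's usual actor/critic optimization.
    self.gen_algo.learn(
        total_timesteps=total_timesteps,
        reset_num_timesteps=False,
    )
```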
So I am confused: the training of the generator looks no different from ordinary PPO training, and this step seems to have nothing to do with the discriminator.
Can you explain how "Trains the generator to maximize the discriminator loss" actually works in imitation.algorithms.adversarial.gail.GAIL\train_gen in the imitation code?
Looking forward to your reply!
Sincerely
Generally, the generator's loss is the negative of the discriminator's score for a state-action pair (i.e. G_loss = -D(s, a)), but I find that the generator (SB3's PPO actor net) is still optimized via SB3's PPO critic net in imitation\gail.
https://github.com/HumanCompatibleAI/imitation/blob/df2627446b7457758a0b09fa74a1fcb19403a236/src/imitation/algorithms/adversarial/common.py#L447-L452
In other words, I cannot find where the generator's loss G_loss = -D(s, a) is computed and then G_loss.backward() is called.
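For comparison, this is the kind of direct, GAN-style generator update I expected to find somewhere (purely hypothetical code, not from imitation or SB3; policy, discriminator, and gen_optimizer are placeholders):

```python
# Hypothetical GAN-style generator step (illustration only, NOT imitation code).
def naive_generator_step(policy, discriminator, gen_optimizer, obs):
    # Assume policy(obs) returns a differentiable (e.g. reparameterized) action.
    actions = policy(obs)
    g_loss = -discriminator(obs, actions).mean()  # G_loss = -D(s, a)
    gen_optimizer.zero_grad()
    g_loss.backward()  # gradients reach the policy through the discriminator
    gen_optimizer.step()
    return g_loss.item()
```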
I found the description in https://github.com/HumanCompatibleAI/imitation/issues/635#issuecomment-1329630844 unclear. Can you explain exactly how the generator in imitation\gail is updated by the discriminator?
This is important to me. Thank you for your time. @AdamGleave @ernestum @shwang @dfilan
@Liuzy0908, here is how I understand it:
The environment is wrapped when you instantiate GAIL so that the reward used in the .learn method corresponds to the GAIL generator objective. See here:
https://github.com/HumanCompatibleAI/imitation/blob/df2627446b7457758a0b09fa74a1fcb19403a236/src/imitation/algorithms/adversarial/common.py#L225C6-L233C14
So even though we are still using the PPO .learn method, we are doing a policy-gradient step on the GAIL objective function. The wrapped environment returns a reward derived from the discriminator instead of the original environment reward, so there is no explicit G_loss = -D(s, a) followed by G_loss.backward(); the generator update happens through PPO's ordinary actor/critic optimization on that discriminator-derived reward.
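Here is a minimal, conceptual sketch of that mechanism, assuming a made-up wrapper class and reward_fn (imitation does this wrapping internally in the lines linked above; the real reward is computed from the discriminator's logits):

```python
import numpy as np
from stable_baselines3.common.vec_env import VecEnvWrapper


class DiscriminatorRewardWrapper(VecEnvWrapper):
    """Conceptual sketch: replace the env reward with a discriminator-based reward."""

    def __init__(self, venv, reward_fn):
        super().__init__(venv)
        self.reward_fn = reward_fn  # maps (obs, acts, next_obs, dones) -> rewards
        self._last_obs = None
        self._last_acts = None

    def reset(self):
        self._last_obs = self.venv.reset()
        return self._last_obs

    def step_async(self, actions):
        self._last_acts = actions
        self.venv.step_async(actions)

    def step_wait(self):
        next_obs, _env_rews, dones, infos = self.venv.step_wait()
        # The environment's own reward is discarded: PPO only ever sees a reward
        # derived from the discriminator, so maximizing return with PPO's normal
        # actor/critic update is exactly the adversarial generator update.
        rews = np.asarray(
            self.reward_fn(self._last_obs, self._last_acts, next_obs, dones),
            dtype=np.float32,
        )
        self._last_obs = next_obs
        return next_obs, rews, dones, infos


# Hypothetical usage: the learner never sees the true environment reward.
# learner = PPO("MlpPolicy", DiscriminatorRewardWrapper(venv, reward_fn))
# learner.learn(total_timesteps=100_000)
```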