stable-baselines3-contrib

[Question] What is the difference between old_distribution and distribution in the train function of TRPO?

Open • 0Addicted0 opened this issue 2 months ago

❓ Question

🙏 Thanks for making sb3-contrib so easy to extend.

I am using MaskablePPO as a reference to add action masking to TRPO. In the `train` function of TRPO, I found the following code:

    with th.no_grad():
        # Note: is copy enough, no need for deepcopy?
        # If using gSDE and deepcopy, we need to use `old_distribution.distribution`
        # directly to avoid PyTorch errors.
        old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations))

    distribution = self.policy.get_distribution(rollout_data.observations)
    log_prob = distribution.log_prob(actions)

    advantages = rollout_data.advantages
    if self.normalize_advantage:
        advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8)

    # ratio between old and new policy, should be one at the first iteration
    ratio = th.exp(log_prob - rollout_data.old_log_prob)

    # surrogate policy objective
    policy_objective = (advantages * ratio).mean()

    # KL divergence
    kl_div = kl_divergence(distribution, old_distribution).mean()

❓ It looks like old_distribution and distribution are exactly the same (so kl_div here equals 0). Or did I misread something?

By the way, may I also ask: when adding action_masks to TRPO, do I need to provide the corresponding masks before computing the distributions used for kl_div?

🙂Thanks a lot


0Addicted0 • Apr 30 '24 15:04

https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/trpo/trpo.py#L282
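A small illustration of why this KL term is not pointless even though its value starts at zero. This is a stdlib-only sketch, not the sb3-contrib code: the 3-action categorical policy, its logits, and the helper names (`softmax`, `kl_categorical`, `kl_to_old`, `second_derivative`) are all hypothetical. At the current parameters KL(new || old) is exactly 0, but its curvature with respect to the parameters is not (TRPO's conjugate-gradient step relies on Hessian-vector products of this KL), and once the parameters move, as they do during a line search, the KL becomes positive.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_categorical(p, q):
    # KL(p || q) for two categorical distributions given as probability lists
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical "policy parameters": logits of a 3-action categorical policy
theta_old = [0.5, -0.2, 1.0]
# Frozen copy, playing the role of `old_distribution` computed under no_grad
old_probs = softmax(theta_old)

def kl_to_old(theta):
    # KL(new || old), same argument order as kl_divergence(distribution, old_distribution)
    return kl_categorical(softmax(theta), old_probs)

def second_derivative(f, theta, direction, eps=1e-4):
    # Central finite difference approximating d^T H d, the curvature of f along `direction`
    plus = [t + eps * d for t, d in zip(theta, direction)]
    minus = [t - eps * d for t, d in zip(theta, direction)]
    return (f(plus) - 2 * f(theta) + f(minus)) / eps ** 2

# 1) At the current parameters the KL value is exactly zero...
kl_at_start = kl_to_old(theta_old)

# 2) ...but its curvature is not zero
direction = [1.0, 0.0, -1.0]
curvature = second_derivative(kl_to_old, theta_old, direction)

# 3) After a (hypothetical) parameter step the KL is positive,
#    which is the quantity a line search compares to the trust-region bound
theta_new = [t + 0.1 * d for t, d in zip(theta_old, direction)]
kl_after_step = kl_to_old(theta_new)

print(f"KL at start:     {kl_at_start:.6f}")
print(f"curvature:       {curvature:.6f}")
print(f"KL after a step: {kl_after_step:.6f}")
```

With these logits the curvature comes out at about 0.8: for a softmax policy it equals the variance of `direction` under the old probabilities, i.e. the Fisher information along that direction.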

Not sure about the second question, but probably yes.

araffin • Apr 30 '24 16:04

Thank you for your timely response.

I think I misunderstood the meaning of line 246 in stable-baselines3-contrib/sb3_contrib/trpo/trpo.py.

For the second question, I just provided the action masks as you suggested, and at least it runs well in my custom environment.
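A rough sketch of what "providing the masks" means for the KL term. It is stdlib-only and does not reproduce sb3-contrib's masking code; the logits, the mask, and the helpers `masked_softmax` and `kl_categorical` are hypothetical. Invalid actions get probability exactly 0, similar in spirit to pushing their logits to a very large negative value. If both the old and new distributions use the mask stored with the rollout, the KL only compares valid actions. Masking only one side produces a large, spurious KL.

```python
import math

def masked_softmax(logits, mask):
    # Softmax restricted to valid actions; invalid actions (mask False) get probability 0
    m = max(x for x, ok in zip(logits, mask) if ok)
    exps = [math.exp(x - m) if ok else 0.0 for x, ok in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]

def kl_categorical(p, q):
    # KL(p || q); actions where p has zero mass contribute nothing
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0:
            if qi == 0:
                return math.inf  # p puts mass where q has none
            total += pi * math.log(pi / qi)
    return total

# Hypothetical logits before/after an update, and the mask stored with the rollout
# (action 1 is invalid in this state)
old_logits = [0.5, 2.0, 1.0]
new_logits = [0.7, 1.5, 0.8]
mask = [True, False, True]

new_masked = masked_softmax(new_logits, mask)

# Consistent: the same mask is applied to both distributions
kl_masked = kl_categorical(new_masked, masked_softmax(old_logits, mask))

# Inconsistent: the mask is applied to the new distribution only
kl_mixed = kl_categorical(new_masked, masked_softmax(old_logits, [True, True, True]))

print(f"KL with consistent masks:   {kl_masked:.4f}")
print(f"KL with mask on one side:   {kl_mixed:.4f}")
```

Here the consistent KL is small (about 0.02), while masking only the new distribution inflates it to about 1.0 because the old distribution still puts most of its mass on the invalid action.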

🙂Thank you very much

0Addicted0 • May 01 '24 07:05