stable-baselines3-contrib
[Question] What is the difference between old_distribution and distribution in the train function of TRPO?
❓ Question
🙏 Thanks for the great extensibility provided by sb3-contrib!
I am referring to the MaskablePPO implementation to add action masking to TRPO, and in the train function of TRPO I found the following code:
```python
with th.no_grad():
    # Note: is copy enough, no need for deepcopy?
    # If using gSDE and deepcopy, we need to use `old_distribution.distribution`
    # directly to avoid PyTorch errors.
    old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations))

distribution = self.policy.get_distribution(rollout_data.observations)
log_prob = distribution.log_prob(actions)

advantages = rollout_data.advantages
if self.normalize_advantage:
    advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8)

# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)

# surrogate policy objective
policy_objective = (advantages * ratio).mean()

# KL divergence
kl_div = kl_divergence(distribution, old_distribution).mean()
```
❓ Does it look like old_distribution and distribution are exactly the same (so kl_div here equals 0), or did I misread something?
By the way, may I also ask whether adding action_masks for TRPO requires providing the corresponding masks before computing the distribution used for kl_div?
🙂 Thanks a lot
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
- [X] If code there is, it is minimal and working
- [X] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/trpo/trpo.py#L282
Not sure about the second question, but probably yes.
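To expand a bit: at the point you quoted, kl_div is indeed zero, but old_distribution is a frozen copy. The policy parameters are then updated during the line search, and kl_divergence(distribution, old_distribution) is re-evaluated there with a freshly computed distribution (that is what the linked line does), where it is no longer zero. (The kl_div from the quoted block also still matters even though its value is zero, since its second derivative is used for the Hessian-vector product.) A minimal standalone sketch of the mechanism, using plain torch.distributions rather than the actual TRPO code:

```python
import torch as th
from torch.distributions import Categorical
from torch.distributions.kl import kl_divergence

# Toy "policy parameters"
logits = th.tensor([1.0, 2.0, 3.0])

# Frozen snapshot, playing the role of `old_distribution`
old_distribution = Categorical(logits=logits.clone())

# Fresh distribution from the same parameters: the KL is zero here
distribution = Categorical(logits=logits)
print(kl_divergence(distribution, old_distribution))  # tensor(0.)

# Simulate a parameter update (e.g. one line-search step)
logits = logits + th.tensor([0.5, -0.5, 0.0])

# Recomputing the distribution now gives a strictly positive KL
# against the frozen snapshot
distribution = Categorical(logits=logits)
print(kl_divergence(distribution, old_distribution))  # > 0
```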
Thank you for your timely response.
I think I misunderstood the meaning of line 246 in sb3_contrib/trpo/trpo.py.
For the second question, providing the action masks as you said works; at least it runs well in my custom environment.
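Concretely, the change is just to pass the masks into both get_distribution calls in the snippet above. A sketch only, under the assumption that the policy is a MaskableActorCriticPolicy (whose get_distribution accepts action_masks, as in MaskablePPO) and that the rollout buffer stores the collected masks:

```python
# Sketch of a masked variant of the TRPO snippet above.
# Assumes self.policy is a MaskableActorCriticPolicy (get_distribution
# accepts action_masks, as in MaskablePPO) and that rollout_data
# carries the masks collected during rollouts.
action_masks = rollout_data.action_masks

with th.no_grad():
    # Freeze the old (masked) distribution before the policy is updated
    old_distribution = copy.copy(
        self.policy.get_distribution(rollout_data.observations, action_masks=action_masks)
    )

# The same masks must be applied to the new distribution,
# otherwise the KL between the two would be meaningless
distribution = self.policy.get_distribution(rollout_data.observations, action_masks=action_masks)
log_prob = distribution.log_prob(actions)
kl_div = kl_divergence(distribution, old_distribution).mean()
```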
🙂Thank you very much