
[Bug] An error in MaskPPO training

Open Yangxiaojun1230 opened this issue 1 year ago • 19 comments

System Info
Describe the characteristic of your environment:

Describe how the library was installed: pip sb3-contrib=='1.5.1a9'
Python: 3.8.13
Stable-Baselines3: 1.5.1a9
PyTorch: 1.11.0+cu102
GPU Enabled: False
Numpy: 1.22.3
Gym: 0.21.0

My training code is as below:

model = MaskablePPO("MultiInputPolicy", env, gamma=0.4, seed=32, verbose=0)
model.learn(300000)

My action space is spaces.Discrete(). It seems to be a problem in the torch distribution's __init__(): the input logits had invalid values. The error happened at an unpredictable training step.

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:579, in MaskablePPO.learn(self, total_timesteps, callback, log_interval, eval_env, eval_freq, n_eval_episodes, tb_log_name, eval_log_path, reset_num_timesteps, use_masking)
    576 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
    577 self.logger.dump(step=self.num_timesteps)
--> 579 self.train()
    581 callback.on_training_end()
    583 return self

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:439, in MaskablePPO.train(self)
    435 if isinstance(self.action_space, spaces.Discrete):
    436     # Convert discrete action from float to long
    437     actions = rollout_data.actions.long().flatten()
--> 439 values, log_prob, entropy = self.policy.evaluate_actions(
    440     rollout_data.observations,
    441     actions,
    442     action_masks=rollout_data.action_masks,
    443 )
    445 values = values.flatten()
    446 # Normalize advantage

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/policies.py:280, in MaskableActorCriticPolicy.evaluate_actions(self, obs, actions, action_masks)
    278 distribution = self._get_action_dist_from_latent(latent_pi)
    279 if action_masks is not None:
--> 280     distribution.apply_masking(action_masks)
    281 log_prob = distribution.log_prob(actions)
    282 values = self.value_net(latent_vf)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py:152, in MaskableCategoricalDistribution.apply_masking(self, masks)
    150 def apply_masking(self, masks: Optional[np.ndarray]) -> None:
    151     assert self.distribution is not None, "Must set distribution parameters"
--> 152     self.distribution.apply_masking(masks)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py:62, in MaskableCategorical.apply_masking(self, masks)
     59 logits = self._original_logits
     61 # Reinitialize with updated logits
---> 62 super().__init__(logits=logits)
     64 # self.probs may already be cached, so we must force an update
     65 self.probs = logits_to_probs(self.logits)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/torch/distributions/categorical.py:64, in Categorical.__init__(self, probs, logits, validate_args)
     62 self._num_events = self._param.size()[-1]
     63 batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
---> 64 super(Categorical, self).__init__(batch_shape, validate_args=validate_args)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     53 valid = constraint.check(value)
     54 if not valid.all():
---> 55     raise ValueError(
     56         f"Expected parameter {param} "
     57         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58         f"of distribution {repr(self)} "
     59         f"to satisfy the constraint {repr(constraint)}, "
     60         f"but found invalid values:\n{value}"
     61     )
     62 super(Distribution, self).__init__()

ValueError: Expected parameter probs (Tensor of shape (64, 400)) of distribution MaskableCategorical(probs: torch.Size([64, 400]), logits: torch.Size([64, 400])) to satisfy the constraint Simplex(), but found invalid values:

Yangxiaojun1230 avatar Jul 05 '22 13:07 Yangxiaojun1230

I checked my env using check_env(env), which failed with the error message below. However, this env was used successfully before with SB3's PPO.

AssertionError: Error while checking key=bin_size_h: The observation returned by the reset() method does not match the given observation space.

My observation space and obs are declared as below; I couldn't find any problem:

self.observation_space = spaces.Dict(spaces={
    "state_grid": spaces.MultiBinary(self.max_num),
    "node_placed": spaces.MultiBinary(self.max_inst_num),
    "cur_node_w": spaces.Box(low=0.0, high=1, shape=(1,), dtype=np.float32),
    "cur_node_h": spaces.Box(low=0.0, high=1, shape=(1,), dtype=np.float32),
    "bin_size_w": spaces.Box(low=0.0, high=1, shape=(self.max_num,), dtype=np.float32),
    "bin_size_h": spaces.Box(low=0.0, high=1, shape=(self.max_num,), dtype=np.float32),
    "node_size_w": spaces.Box(low=0.0, high=1, shape=(self.max_inst_num,), dtype=np.float32),
    "node_size_h": spaces.Box(low=0.0, high=1, shape=(self.max_inst_num,), dtype=np.float32),
})

def get_obs(self):
    # Keys must match the keys of observation_space above
    return collections.OrderedDict([
        ("state_grid", self.state_bin),
        ("node_placed", self.node_placed),
        ("cur_node_w", self.cur_node_w),
        ("cur_node_h", self.cur_node_h),
        ("bin_size_w", self.bin_size_w),
        ("bin_size_h", self.bin_size_h),
        ("node_size_w", self.node_size_w),
        ("node_size_h", self.node_size_h),
    ])
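One possible cause of this kind of check_env failure is a dtype mismatch between the returned array and the declared Box space. The sketch below is only an illustration under an assumption the posted code does not confirm: that arrays such as self.bin_size_h are created with NumPy's default float64 dtype. The names obs_default, obs_float32 and dtype_matches are hypothetical, and the safe-cast rule shown is meant to resemble the kind of dtype test a Box space may apply, which is worth verifying against the installed Gym version.

```python
import numpy as np

# Hypothetical stand-ins for an observation entry such as self.bin_size_h
obs_default = np.zeros(4)                    # NumPy defaults to float64
obs_float32 = np.zeros(4, dtype=np.float32)  # explicitly matches the space

declared_dtype = np.float32                  # dtype declared in the Box space


def dtype_matches(obs, dtype):
    """Return True if obs can be safely cast to the declared dtype."""
    # Under NumPy's default "safe" casting, float64 -> float32 is rejected
    return bool(np.can_cast(obs.dtype, dtype))


print("float64 array accepted:", dtype_matches(obs_default, declared_dtype))
print("float32 array accepted:", dtype_matches(obs_float32, declared_dtype))
```

If this is the cause, either casting the observation with astype(np.float32) or declaring the space with the dtype the arrays actually have should make the two agree.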

Yangxiaojun1230 avatar Jul 05 '22 14:07 Yangxiaojun1230

Hey. We do not offer tech support and it is hard to give guidance without further code. If you can replicate this issue with a minimal code and you believe it should be right, please include minimal code to replicate the issue. Meanwhile, I recommend double-checking what comes out of your reset function.

Miffyli avatar Jul 05 '22 17:07 Miffyli

Hi guys, I solved the problem by changing dtype=np.float32 -> dtype=np.float64.

Yangxiaojun1230 avatar Jul 06 '22 01:07 Yangxiaojun1230

Hi guys, the error happened again, and I found that the root cause is in the torch Categorical class, which performs a constraint check. The failed check is |sum(probs) - 1| < 1e-6. The value of (sum(probs) - 1) in this case is -1.6e-6. Is this caused by the apply_masking() function? Maybe the function could set the dtype to float64, or change -1e8 to -1e6 in the code below? Any advice will be appreciated.

HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device)

Yangxiaojun1230 avatar Jul 08 '22 03:07 Yangxiaojun1230

I've run into the same problem when using the maskable PPO implementation (this is only relevant in debug mode, where arg validation is enabled by default). Here is a repro for the problem.

The issue seems to be that, for some combination of logits, logits_to_probs will return a value for probs that does not sum to 1 within the tolerance limit (due to precision issues), causing the arg validation constraint for the Categorical class to fail. Normally (without action masking), the Categorical constructor does not run arg validation for probs since it is not present when the constructor is invoked. However, with action masking enabled, the Categorical constructor ends up being called multiple times here, once with probs set, so it runs arg validation on probs and therefore will fail on rare occasion.

One way I've found to fix this is to change how the apply_masking method deals with the cached probs here: instead of force-updating probs, just remove it if it is present before calling the constructor. So basically introduce the following code before the constructor call:

# remove cached probs if present
if 'probs' in self.__dict__:
    delattr(self, 'probs')

# Reinitialize with updated logits
super().__init__(logits=logits)

If this looks reasonable I can make a PR. Thanks!
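The caching behaviour described in this comment can be observed in isolation. The sketch below is a contrived illustration rather than the linked repro: it plants a deliberately invalid cached probs value on a plain Categorical and calls __init__ a second time, mimicking what a re-initializing subclass does. The helper name reinit_raises is hypothetical.

```python
import torch
from torch.distributions import Categorical

dist = Categorical(logits=torch.zeros(3), validate_args=True)

# probs is computed lazily, so it is not cached until first access
cached_before = "probs" in dist.__dict__
_ = dist.probs
cached_after = "probs" in dist.__dict__

# Deliberately invalid cached value (rows must sum to 1)
dist.probs = torch.tensor([0.5, 0.5, 0.5])


def reinit_raises(d):
    """Re-run the constructor and report whether validation failed."""
    try:
        d.__init__(logits=torch.zeros(3), validate_args=True)
        return False
    except ValueError:
        return True


# The cached probs attribute is validated on re-initialization
raised_with_cache = reinit_raises(dist)

# Removing the cached attribute means only the logits are validated
delattr(dist, "probs")
raised_after_delete = reinit_raises(dist)

print(cached_before, cached_after, raised_with_cache, raised_after_delete)
```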

svolokh avatar Jul 11 '22 02:07 svolokh

@svolokh Thanks for your information. In my case, I overrode the softmax call in torch with F.softmax(logits, dim=-1, dtype=torch.double) to make it work.
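To see why computing the softmax in double precision can help, the row sums under both precisions can be compared directly. This is a rough sketch with randomly generated logits (an assumption, not data from the issue); the shape (64, 400) is taken from the error message, and the exact float32 deviation will vary by platform and PyTorch version.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical logits with the same shape as in the reported error
logits = torch.randn(64, 400) * 10

probs32 = F.softmax(logits, dim=-1)                      # stays in float32
probs64 = F.softmax(logits, dim=-1, dtype=torch.double)  # casts to float64 first

# Largest deviation of any row sum from 1 under each precision
dev32 = (probs32.sum(-1) - 1).abs().max().item()
dev64 = (probs64.sum(-1) - 1).abs().max().item()

print(f"max |sum - 1| float32: {dev32:.2e}")
print(f"max |sum - 1| float64: {dev64:.2e}")
```

The float64 deviation should sit far below the 1e-6 tolerance mentioned earlier in the thread, while the float32 deviation is typically much closer to it.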

Yangxiaojun1230 avatar Jul 11 '22 05:07 Yangxiaojun1230

@svolokh thanks for the info, would validate_args=False also solve the issue? (probably cleaner than deleting the cached probs)

araffin avatar Jul 18 '22 09:07 araffin

@araffin That does indeed fix the issue as well!
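For readers who want to see what this change looks like mechanically, here is a simplified sketch. SimpleMaskedCategorical is a hypothetical class written for illustration only; it is not the sb3-contrib MaskableCategorical, and it handles only a boolean mask with the same shape as the logits.

```python
import torch
from torch.distributions import Categorical


class SimpleMaskedCategorical(Categorical):
    """Hypothetical, simplified masked categorical (not the sb3-contrib class)."""

    def __init__(self, logits, validate_args=None):
        self._original_logits = logits
        super().__init__(logits=logits, validate_args=validate_args)

    def apply_masking(self, mask):
        # Push the logits of invalid actions to a very large negative value
        huge_neg = torch.tensor(-1e8, dtype=self._original_logits.dtype)
        masked = torch.where(mask, self._original_logits, huge_neg)
        # Re-run the parent constructor with argument validation disabled,
        # so a cached probs attribute is not checked against the simplex
        super().__init__(logits=masked, validate_args=False)
        # Refresh the cached probabilities from the new normalized logits
        self.probs = torch.softmax(self.logits, dim=-1)


dist = SimpleMaskedCategorical(logits=torch.zeros(3))
_ = dist.probs  # populate the cache before masking
dist.apply_masking(torch.tensor([True, False, True]))
samples = dist.sample((1000,))
print(dist.probs)
```

Note that later comments in this thread report agents failing to learn after disabling validation, so this sketch only shows the mechanics and says nothing about training behaviour.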

svolokh avatar Jul 19 '22 22:07 svolokh

Good to hear =) Then I would be happy to receive a PR that solves this issue ;)

araffin avatar Jul 20 '22 07:07 araffin

Hi, I happened to have the same issue and I applied the very same fix as @svolokh in their first post. I was quite surprised that @araffin decided that ignoring validation is the cleaner solution: I think it definitely isn't! The validation of the logits itself should be done, and calling the __init__ method of the Categorical class with already (incorrectly!) filled probs is at least a suspicious practice. Could you elaborate on why you think that removing the probs is not a good idea?

dervan avatar Jul 22 '22 23:07 dervan

Hello,

Could you elaborate on why you think that removing the probs is not a good idea?

The idea behind it is to use a feature that is in the interface of PyTorch, to avoid manual deleting of attribute (which may have side effects).

EDIT: another reason is that deleting the attribute has the same effect: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/81#issuecomment-1179910351

the __init__ method of the Categorical class with already (incorrectly!) filled probs is at least a suspicious practice.

Of course, the best would be to solve the root cause of the problem. I haven't looked much into that problem (I didn't write that code either), but from what I understand, it is an error due to numerical imprecision. So if you have a better solution that avoids deleting attributes and keeps argument validation, I'm of course up for it ;)

araffin avatar Jul 25 '22 09:07 araffin

Hi, I got the same error and found this issue. Is there any reason that the validate_args=False is not released?

hsjung02 avatar May 13 '23 21:05 hsjung02

Hi, I got the same error and found this issue. Is there any reason that the validate_args=False is not released?

Does it work in your case after changing that line of code?

koliber31 avatar Jul 12 '23 19:07 koliber31

Yes, at least it doesn't produce the same error as before. However, I don't know whether it affects the learning performance.

hsjung02 avatar Jul 13 '23 02:07 hsjung02

Yes, at least it doesn't produce the same error as before. However, I don't know whether it affects the learning performance,

Did you check whether your agent still learns after this change? I mean, does it learn at all? In my case the error stopped occurring after this change, but the agent wasn't able to learn anything.

koliber31 avatar Jul 13 '23 14:07 koliber31

Yes, it did, and I think you should be able to check it for yourself.



hsjung02 avatar Jul 13 '23 14:07 hsjung02

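I did check it and, as I said, it stopped learning at all. Below are screenshots of the learning curves: the ones with 150k timesteps and 260k timesteps show training that ended with the error (without validate_args=False), and the one with 4M timesteps shows training with validate_args=False.

[260kSteps] https://user-images.githubusercontent.com/42122590/253310655-b541e269-9bf3-4e73-ae1c-e17f5dcbec82.png
[150kSteps] https://user-images.githubusercontent.com/42122590/253310661-ac39c19f-02f0-4862-94f2-2541fa30ac63.png
[wykresy (plots)] https://user-images.githubusercontent.com/42122590/253310679-80b0afec-3855-4f54-ab08-e1ca5a7ff928.png

As you can tell, after this change it doesn't learn at all. Would you do me a favor and run your training (if you still have it, of course) just to see if something is happening, and tell me the results?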

koliber31 avatar Jul 13 '23 14:07 koliber31

Actually, I am done with this code (RL, MaskablePPO), so there is not much I can do for you. I remember that the error happened when the training was almost at the end (e.g., total timesteps = 50000 and the error happened at 45000). Maybe early stopping could be helpful for you. I didn't try training for so many timesteps, so my experience may not help you.



hsjung02 avatar Jul 13 '23 14:07 hsjung02

As You can tell after this change it doesn't learn at all. Would you do me a favor and run your learning (if you still have it ofc) just to see if something is happening and tell me the results?

Same problem: after changing validate_args to False, my agent cannot learn anymore. Have you solved this problem?

yiptsangkin avatar Jul 22 '23 00:07 yiptsangkin