stable-baselines3-contrib
Speed up when using MaskablePPO
❓ Question
Hi, I'm using MaskablePPO on a powerful computer, but training runs no faster than on a normal computer. Is there an option or a line of code that would speed up training? Thank you!
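```python
import gymnasium as gym  # "import gym" on SB3/sb3-contrib < 2.0
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
class customenv(gym.Env): ...  # custom environment, details omitted
env = customenv()
env = ActionMasker(env, mask_fn)  # mask_fn(env) returns the boolean action mask
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=0)
```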
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
- [X] If code there is, it is minimal and working
- [X] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
> I'm using MaskablePPO on a powerful computer but the speed of the training doesn't change compared to a normal computer. Is there any option or line of code that increases the speed of training?
Related: DLR-RM/stable-baselines3#1245, DLR-RM/stable-baselines3#90 (comment), and DLR-RM/stable-baselines3#682
You should probably also use multiple envs. In that case, define the action mask function directly in the env (as an `action_masks()` method) instead of passing it through a wrapper; see #49 (comment).
Sorry for reopening, but I get an error when using it with my custom env. All my environment methods are the same as in `InvalidActionEnvDiscrete`:
```
EOFError                                  Traceback (most recent call last)
<ipython-input-44-67cc0d019c11> in <cell line: 2>()
      1 model = MaskablePPO("MlpPolicy", env, verbose=1, tensorboard_log="/content/drive/MyDrive/Colab Notebooks/JOM/test")
----> 2 model.learn(100000)

/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ in learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, use_masking, progress_bar)
    525 while self.num_timesteps < total_timesteps:
--> 526 continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
    528 if continue_training is False:

/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ in collect_rollouts(self, env, callback, rollout_buffer, n_rollout_steps, use_masking)
    287 rollout_buffer.reset()
--> 289 if use_masking and not is_masking_supported(env):
    290 raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")

/usr/local/lib/python3.10/dist-packages/sb3_contrib/common/maskable/ in is_masking_supported(env)
     31 try:
     32 # TODO: add VecEnv.has_attr()
---> 33 env.get_attr(EXPECTED_METHOD_NAME)
     34 return True
     35 except AttributeError:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/ in get_attr(self, attr_name, indices)
    171 for remote in target_remotes:
    172 remote.send(("get_attr", attr_name))
--> 173 return [remote.recv() for remote in target_remotes]
    175 def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/ in <listcomp>(.0)
    171 for remote in target_remotes:
    172 remote.send(("get_attr", attr_name))
--> 173 return [remote.recv() for remote in target_remotes]
    175 def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/lib/python3.10/multiprocessing/ in recv(self)
    248 self._check_closed()
    249 self._check_readable()
--> 250 buf = self._recv_bytes()
    251 return _ForkingPickler.loads(buf.getbuffer())

/usr/lib/python3.10/multiprocessing/ in _recv_bytes(self, maxsize)
    413 def _recv_bytes(self, maxsize=None):
--> 414 buf = self._recv(4)
    415 size, = struct.unpack("!i", buf.getvalue())
    416 if size == -1:

/usr/lib/python3.10/multiprocessing/ in _recv(self, size, read)
    381 if n == 0:
    382 if remaining == size:
--> 383 raise EOFError
    384 else:
    385 raise OSError("got end of file during message")
```