stable-baselines3-contrib
Speed up when using MaskablePPO
❓ Question
Hi, I'm using MaskablePPO on a powerful computer, but training is no faster than on a normal computer. Is there an option or line of code that speeds up training? Thank you!
```python
import gymnasium as gym
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker

class customenv(gym.Env):
    ...

env = customenv()
env = ActionMasker(env, mask_fn)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=0)
model.learn(1000000)
```
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
- [X] If code there is, it is minimal and working
- [X] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
Hello,

> I'm using MaskablePPO on a powerful computer but the speed of the training doesn't change compared to a normal computer. Is there any option or line of code that increases the speed of training?
Related: https://github.com/DLR-RM/stable-baselines3/issues/1245 and https://github.com/DLR-RM/stable-baselines3/issues/90#issuecomment-659607948 and https://github.com/DLR-RM/stable-baselines3/issues/682
You should probably use multiple envs too; in that case, you should define the action mask function directly in the env, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49#issuecomment-1422869188
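For reference, the key point is that the mask function lives on the environment class itself (an `action_masks` method), not in an external wrapper, so that subprocess workers can find it by attribute lookup. A minimal structural sketch using plain Python stand-ins (no gymnasium/sb3_contrib imports; `n_actions` and `forbidden` are illustrative names, and a real env would subclass `gym.Env`):

```python
class CustomEnv:
    """Stand-in for a gym.Env subclass; only the masking hook is shown."""

    def __init__(self, n_actions=4):
        self.n_actions = n_actions
        self.forbidden = {0}  # actions currently invalid (illustrative)

    def action_masks(self):
        # MaskablePPO looks this method up by name ("action_masks"),
        # so defining it on the env class itself makes it visible even
        # when the env lives inside a SubprocVecEnv worker process.
        return [a not in self.forbidden for a in range(self.n_actions)]


env = CustomEnv()
# Mirrors the duck-typing check sb3_contrib performs before training.
assert hasattr(env, "action_masks")
print(env.action_masks())  # -> [False, True, True, True]
```

With the method defined on the class, the `ActionMasker` wrapper is no longer needed, and the env can be passed to `make_vec_env`/`SubprocVecEnv` directly.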
Sorry for reopening this, but I get an error when using it with a custom env. All my environment methods are the same as in `InvalidActionEnvDiscrete`:
```
---------------------------------------------------------------------------
EOFError                                  Traceback (most recent call last)
<ipython-input-44-67cc0d019c11> in <cell line: 2>()
      1 model = MaskablePPO("MlpPolicy", env, verbose=1, tensorboard_log="/content/drive/MyDrive/Colab Notebooks/JOM/test")
----> 2 model.learn(100000)

7 frames
/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ppo_mask.py in learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, use_masking, progress_bar)
    524
    525         while self.num_timesteps < total_timesteps:
--> 526             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
    527
    528             if continue_training is False:

/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ppo_mask.py in collect_rollouts(self, env, callback, rollout_buffer, n_rollout_steps, use_masking)
    287         rollout_buffer.reset()
    288
--> 289         if use_masking and not is_masking_supported(env):
    290             raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")
    291

/usr/local/lib/python3.10/dist-packages/sb3_contrib/common/maskable/utils.py in is_masking_supported(env)
     31     try:
     32         # TODO: add VecEnv.has_attr()
---> 33         env.get_attr(EXPECTED_METHOD_NAME)
     34         return True
     35     except AttributeError:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/subproc_vec_env.py in get_attr(self, attr_name, indices)
    171         for remote in target_remotes:
    172             remote.send(("get_attr", attr_name))
--> 173         return [remote.recv() for remote in target_remotes]
    174
    175     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/subproc_vec_env.py in <listcomp>(.0)
    171         for remote in target_remotes:
    172             remote.send(("get_attr", attr_name))
--> 173         return [remote.recv() for remote in target_remotes]
    174
    175     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/lib/python3.10/multiprocessing/connection.py in recv(self)
    248         self._check_closed()
    249         self._check_readable()
--> 250         buf = self._recv_bytes()
    251         return _ForkingPickler.loads(buf.getbuffer())
    252

/usr/lib/python3.10/multiprocessing/connection.py in _recv_bytes(self, maxsize)
    412
    413     def _recv_bytes(self, maxsize=None):
--> 414         buf = self._recv(4)
    415         size, = struct.unpack("!i", buf.getvalue())
    416         if size == -1:

/usr/lib/python3.10/multiprocessing/connection.py in _recv(self, size, read)
    381             if n == 0:
    382                 if remaining == size:
--> 383                     raise EOFError
    384                 else:
    385                     raise OSError("got end of file during message")

EOFError:
```