stable-baselines3
Multiprocessing support for HerReplayBuffer
Description
Make HerReplayBuffer compatible with multiprocessing.
Motivation and Context
- [ ] I have raised an issue to propose this change (required for new features and bug fixes)
Alternative to PR #654 with a different implementation of Multiprocessing compatibility for HerReplayBuffer
(Motivation: https://github.com/DLR-RM/stable-baselines3/pull/654#issuecomment-999470139)
Types of changes
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
Checklist:
- [ ] I've read the CONTRIBUTION guide (required)
- [ ] I have updated the changelog accordingly (required).
- [x] My change requires a change to the documentation.
- [x] I have updated the tests accordingly (required for a bug fix or a new feature).
- [x] I have updated the documentation accordingly.
- [x] I have reformatted the code using make format (required)
- [x] I have checked the codestyle using make check-codestyle and make lint (required)
- [x] I have ensured make pytest and make type both pass. (required)
- [x] I have checked that the documentation builds using make doc (required)
Note: You can run most of the checks using make commit-checks.
Note: we are using a maximum length of 127 characters per line
Thanks for creating the PR =) This is in fact the cleaner way to do it that I had in mind but no time to invest in... What is missing? More testing? Where can I help?
Hello, thank you for the PR!
I wrote a simple config to test this on this branch of my fork of rl-baselines3-zoo.
To reproduce, simply run:
```bash
git clone https://github.com/buoyancy99/rl-baselines3-zoo
cd rl-baselines3-zoo
git checkout vecher
# assume you installed your fork of sb3
python train.py --algo sac --env FetchPush-v1
```
One minor issue I discovered: if you change the learning_starts parameter to 100 in this line, or to anything too small, you get the following error:
```
Traceback (most recent call last):
  File "train.py", line 195, in <module>
    exp_manager.learn(model)
  File "/home/boyuan/Projects/buoyancy99/rl-baselines3-zoo/utils/exp_manager.py", line 202, in learn
    model.learn(self.n_timesteps, **kwargs)
  File "/home/boyuan/anaconda3/envs/sbp/lib/python3.7/site-packages/stable_baselines3/sac/sac.py", line 301, in learn
    reset_num_timesteps=reset_num_timesteps,
  File "/home/boyuan/anaconda3/envs/sbp/lib/python3.7/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 374, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/boyuan/anaconda3/envs/sbp/lib/python3.7/site-packages/stable_baselines3/sac/sac.py", line 199, in train
    replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
  File "/home/boyuan/anaconda3/envs/sbp/lib/python3.7/site-packages/stable_baselines3/her/her_replay_buffer.py", line 217, in sample
    return self._sample_transitions(batch_size, maybe_vec_env=env, online_sampling=True)  # pytype: disable=bad-return-type
  File "/home/boyuan/anaconda3/envs/sbp/lib/python3.7/site-packages/stable_baselines3/her/her_replay_buffer.py", line 304, in _sample_transitions
    episode_indices = np.random.choice(all_episode_indices, batch_size)
  File "mtrand.pyx", line 915, in numpy.random.mtrand.RandomState.choice
ValueError: 'a' cannot be empty unless no samples are taken
```
The same parameter works fine for a non-vectorized environment. Setting learning_starts to something bigger (512 with 8 envs in my case) avoids this issue. This margin depends on the number of environments. I believe we should have a better error message or a fix here.
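A hedged sketch of why the margin depends on the number of environments. It assumes (as I understand SB3's off-policy loop) that num_timesteps grows by n_envs per vectorized step, that HER sampling only draws from completed episodes, and that FetchPush episodes last 50 steps (default time limit). The guard function is hypothetical, one possible shape for a clearer error message.

```python
import numpy as np


def steps_per_env_before_training(learning_starts: int, n_envs: int) -> float:
    # Assumption: num_timesteps advances by n_envs per vectorized step, so each
    # env has collected only about learning_starts / n_envs steps when training starts.
    return learning_starts / n_envs


def sample_episode_indices(n_complete_episodes: int, batch_size: int, rng: np.random.Generator) -> np.ndarray:
    # Hypothetical guard: fail with an explicit message instead of
    # "'a' cannot be empty unless no samples are taken" from np.random.choice.
    if n_complete_episodes == 0:
        raise ValueError(
            "No complete episode stored yet: increase learning_starts "
            "(roughly n_envs * max_episode_steps)."
        )
    return rng.integers(0, n_complete_episodes, size=batch_size)


# 8 envs, episodes assumed to last 50 steps:
print(steps_per_env_before_training(100, 8))  # 12.5 -> no env has finished an episode
print(steps_per_env_before_training(512, 8))  # 64.0 -> every env has finished one
```

With a single env, learning_starts=100 already covers two 50-step episodes, which is consistent with the non-vectorized case working.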
I will test the training curve of vectorized her on openai environments to check if the performance is fine.
Thanks for creating the PR =) This is in fact the cleaner way to do it that I had in mind but no time to invest in... What is missing? More testing? Where can I help?
There are two major issues:
- the offline sampling causes an error when the number of workers is greater than or equal to 2, because of a design problem (not just an implementation problem). I have to rethink this, I'm working on it.
- the results obtained are different depending on the number of workers, which is not desirable (see curve).
![Screenshot 2022-01-03 at 10 01 32](https://user-images.githubusercontent.com/45557362/147914151-b0d85529-17f3-4a82-b752-e6eee431e295.png)
Environment: BitFlipping 8x8; algorithm: DQN; 10 random seeds; online sampling; goal selection strategy: future; evaluation frequency: 1000; evaluation episodes: 50
the results obtained are different depending on the number of workers, which is not desirable (see curve).
you need to update gradient_steps (cf. doc, same as https://github.com/DLR-RM/stable-baselines3/issues/699)
I have to rethink this, I'm working on it.
tell me if you have any questions.
you need to update gradient_steps (cf. doc, same as #699)
Thanks, it works much better when gradient_steps is proportional to n_envs:
![Screenshot 2022-01-03 at 11 02 08](https://user-images.githubusercontent.com/45557362/147918501-0ed2b4fd-da1f-479d-8a55-a852cf914f73.png)
(here, gradient_steps = n_envs)
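A short sketch of why gradient_steps should scale with n_envs. It assumes train_freq counts vectorized steps (my understanding of SB3's off-policy loop), so each training phase follows train_freq * n_envs new transitions. The helper names are hypothetical.

```python
def updates_per_transition(gradient_steps: int, train_freq: int, n_envs: int) -> float:
    # Gradient updates performed per newly stored transition.
    return gradient_steps / (train_freq * n_envs)


def gradient_steps_for_ratio(ratio: float, train_freq: int, n_envs: int) -> int:
    # Number of gradient steps that keeps a target updates-per-transition ratio.
    return max(1, round(ratio * train_freq * n_envs))


print(updates_per_transition(gradient_steps=1, train_freq=1, n_envs=1))  # 1.0
print(updates_per_transition(gradient_steps=1, train_freq=1, n_envs=4))  # 0.25
print(gradient_steps_for_ratio(1.0, train_freq=1, n_envs=4))  # 4
```

With gradient_steps left at 1, four workers perform a quarter as many updates per collected sample as a single env, which would explain curves that depend on the number of workers.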
Currently working on a new implementation of the HerReplayBuffer.
- The new implementation is more consistent with the other buffers. The goal is to no longer have to treat this buffer as a special case when constructing it.
- I no longer use an internal buffer (self._buffer): observations, actions, etc. are stored in self.observations, self.actions, etc.
- It is no longer necessary to specify a maximum episode size.
There is still some work to do:
- Implement offline sampling (I don't know how to do it yet).
- Check that the new implementation gives the same results as the previous one for n_envs == 1.
- Check that the vectorization does not change the results (n_envs > 1).
- Check that the new implementation is not slower.
And other checks I haven't thought of yet.
Currently working on a new implementation of the HerReplayBuffer.
You mean for both offline and online sampling? The main reason for this implementation is efficiency.
btw, in the worst case, we could just say that offline sampling is not supported for now too.
I no longer use an internal buffer (self._buffer) : observations, actions etc. are stored in self.observations, self.actions etc.
I also agree with that change ;) (but would need to implement backward compatible loading)
You mean for both offline and online sampling? The main reason for this implementation is efficiency.
btw, in the worst case, we could just say that offline sampling is not supported for now too.
That's good to know! If it's just a matter of efficiency for the old (master branch) implementation, I won't spend time on it. The argument will still exist but will raise an error.
I also agree with that change ;) (but would need to implement backward compatible loading)
You mean some kind of method that allows loading transitions from the old implementation of HerReplayBuffer? It would be implemented to allow loading models trained with an old version of SB3?
It would be implemented to allow to load models trained with an old version of SB3?
Yes, but probably for a separate PR.
The argument will still exist but will raise an error.
but we need to keep it working with a single env, don't we?
but we need to keep it working with single env, does it?
Making it work for a single environment should not be a problem. I'll do that.
Why do you need to remove the termination signal due to timeout in HerReplayBuffer
https://github.com/DLR-RM/stable-baselines3/blob/e9a8979022d7005560d43b7a9c1dc1ba85f7989a/stable_baselines3/her/her_replay_buffer.py#L409
and not in the other buffers?
https://github.com/DLR-RM/stable-baselines3/blob/e9a8979022d7005560d43b7a9c1dc1ba85f7989a/stable_baselines3/common/buffers.py#L586
Things are going well. The code seems to work in all configurations. I have one major concern though: the code is too slow!
```python
import cProfile

from stable_baselines3.common.envs import BitFlippingEnv
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import DQN, HerReplayBuffer

with cProfile.Profile() as pr:
    env = DummyVecEnv(4 * [lambda: BitFlippingEnv(n_bits=10)])
    model = DQN("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, learning_starts=100)
    model.learn(20000)

pr.print_stats(sort="cumtime")
```
```
8102837 function calls (7431115 primitive calls) in 23.612 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 22.943 22.943 dqn.py:245(learn)
1 0.032 0.032 22.943 22.943 off_policy_algorithm.py:306(learn)
1244 0.109 0.000 13.899 0.011 dqn.py:171(train)
1244 0.286 0.000 11.509 0.009 her_replay_buffer.py:133(sample)
1244 0.075 0.000 10.845 0.009 her_replay_buffer.py:262(_create_virual_trans)
1244 0.223 0.000 10.393 0.008 her_replay_buffer.py:188(sample_goals)
1244 0.091 0.000 9.851 0.008 her_replay_buffer.py:197(<listcomp>)
31100 7.226 0.000 9.760 0.000 her_replay_buffer.py:219(_get_episode_from_uid)
1250 0.082 0.000 9.004 0.007 off_policy_algorithm.py:507(collect_rollouts)
5000 0.047 0.000 5.229 0.001 off_policy_algorithm.py:442(_store_transition)
5000 4.824 0.001 5.042 0.001 her_replay_buffer.py:92(add)
32344 1.208 0.000 2.623 0.000 index_tricks.py:148(__getitem__)
5000 0.013 0.000 1.903 0.000 off_policy_algorithm.py:365(_sample_action)
4975 0.026 0.000 1.888 0.000 dqn.py:215(predict)
4505 0.041 0.000 1.821 0.000 policies.py:307(predict)
5000 0.011 0.000 1.491 0.000 base_vec_env.py:154(step)
```
Almost half of the learning time (9.85 s out of 22.94 s) is spent executing this list comprehension (her_replay_buffer.py:197). I am working on a solution.
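In the profile, _get_episode_from_uid is called 25 times per sample() call (31100 / 1244), each time from Python. One common way to remove such per-item lookups is to vectorize them. The sketch below assumes, hypothetically, that transitions are stored contiguously and that episode start indices are kept sorted; it maps a whole batch of transition indices to episode indices with a single np.searchsorted call. This is not the PR's code.

```python
import numpy as np

# Hypothetical layout: four episodes stored back to back, no wrap-around.
episode_starts = np.array([0, 5, 12, 20])
n_stored = 26  # the last episode covers indices 20..25
episode_lengths = np.diff(np.append(episode_starts, n_stored))  # [5, 7, 8, 6]


def episode_of_loop(transition_indices: np.ndarray) -> np.ndarray:
    # One Python-level lookup per sampled transition (list-comprehension style).
    return np.array([int(np.nonzero(episode_starts <= i)[0][-1]) for i in transition_indices])


def episode_of_vectorized(transition_indices: np.ndarray) -> np.ndarray:
    # One NumPy call for the whole batch: index of the last start <= i.
    return np.searchsorted(episode_starts, transition_indices, side="right") - 1


batch = np.random.default_rng(0).integers(0, n_stored, size=256)
assert (episode_of_loop(batch) == episode_of_vectorized(batch)).all()
```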
Because the timeout is handled at sampling time for the classic replay buffer (normally). But we can probably do the same for the online sampling.
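A minimal sketch of "handling the timeout at sampling time". To my understanding, SB3's ReplayBuffer (with handle_timeout_termination) stores a separate timeouts array, filled from info["TimeLimit.truncated"], and masks done with it when sampling. The function below mirrors that idea with hypothetical names; it is not the library code.

```python
import numpy as np


def effective_dones(dones: np.ndarray, timeouts: np.ndarray) -> np.ndarray:
    # An episode cut only by the time limit is not a true terminal state,
    # so its done flag is masked out before bootstrapping.
    return dones * (1 - timeouts)


def td_target(rewards: np.ndarray, next_q: np.ndarray, dones: np.ndarray,
              timeouts: np.ndarray, gamma: float = 0.99) -> np.ndarray:
    # Bootstrap from next_q unless the transition is a true terminal state.
    return rewards + gamma * (1 - effective_dones(dones, timeouts)) * next_q


dones = np.array([0.0, 1.0, 1.0])
timeouts = np.array([0.0, 0.0, 1.0])  # the last transition was truncated
print(effective_dones(dones, timeouts))  # [0. 1. 0.]
```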
But we can probably do the same for the online sampling.
Done
I have one major concern though: the code is too slow!
That's because you removed the max episode length, right?
Btw, was that needed? Now the implementation seems really complex...
Now the implementation seems really complex...
Since my last post, there have been many changes, and the code is considerably simpler and faster.
That's because you removed the max episode length, right?
No, that's not the main reason.
In my opinion, if this new implementation is slower, it is because the interactions are no longer stored according to the structure (episode_idx, trans_idx, env_idx), which is quite efficient. My motivation to use a different structure was not to remove max_episode_length, even if it is convenient that I no longer need it. Here are my thoughts:
When interacting with multiple environments, there is no guarantee that all environments will reach a terminal state at the same time. So:
- If you want to keep the structure (episode_idx, trans_idx, env_idx), then you would have to manage a list of transition and episode indexes for each environment. It seems pretty complicated. I didn't even try.
- Alternatively, you could use a structure like (episode_idx, trans_idx). When an environment finishes an episode, it starts writing the next one at the next available episode_idx. You just need to keep track of the index of the episode being written for each environment. This is what I did in my first commits of this PR:
https://github.com/qgallouedec/stable-baselines3/blob/7942417881ed46a739565c6617676d466e3287db/stable_baselines3/her/her_replay_buffer.py#L87
This solution works quite well but does not seem completely satisfactory, in the sense that it is not consistent with the other replay buffers: _buffer dictionary, internal buffer for offline sampling, different constructor arguments, different storage for done, etc. These inconsistencies often require handling the HerReplayBuffer case separately outside the buffer, which can make it difficult to use.
- So I thought of a third solution: make an effort on the consistency of HerReplayBuffer to make it easier to use. Of course, this required abandoning the above-mentioned efficient storage structure. As in the other buffers, transitions are no longer stored per episode. In the current state of the code, I only keep track of the episode start indexes and the episode sizes.
Drawback:
- it's a bit slower (2x empirically).
Advantages:
- use is much easier, because it is consistent with the other buffers
- the code is more concise (545 lines before versus 340 now, for what it's worth)
- it is no longer necessary to specify a maximum episode size
- it is memory efficient (no need to store zeros for episodes that don't reach the size limit)
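A minimal sketch of the bookkeeping described above, under stated assumptions: transitions live in (buffer_size, n_envs) arrays like the other SB3 buffers, each transition records its episode start, and its episode length is filled in once that episode ends. Class and attribute names are hypothetical, buffer wrap-around is ignored, and the "future" convention (read next_achieved_goal at an index in [t, episode end)) is an assumption rather than the PR's code.

```python
import numpy as np


class EpisodeBookkeeping:
    # Hypothetical: only episode starts/lengths are tracked, not per-episode storage.
    def __init__(self, buffer_size: int, n_envs: int):
        self.pos = 0
        self.ep_start = np.zeros((buffer_size, n_envs), dtype=np.int64)
        self.ep_length = np.zeros((buffer_size, n_envs), dtype=np.int64)  # 0 = unfinished
        self.current_start = np.zeros(n_envs, dtype=np.int64)

    def add(self, dones: np.ndarray) -> None:
        # Called once per vectorized step (storing obs/actions is omitted here).
        self.ep_start[self.pos] = self.current_start
        self.pos += 1  # no wrap-around in this sketch
        for env_idx in np.nonzero(dones)[0]:
            start = self.current_start[env_idx]
            self.ep_length[start:self.pos, env_idx] = self.pos - start
            self.current_start[env_idx] = self.pos

    def sample_future(self, batch_size: int, rng: np.random.Generator):
        # Draw only among transitions whose episode is finished.
        valid_pos, valid_env = np.nonzero(self.ep_length[: self.pos] > 0)
        choice = rng.integers(0, len(valid_pos), size=batch_size)
        pos, env = valid_pos[choice], valid_env[choice]
        end = self.ep_start[pos, env] + self.ep_length[pos, env]
        # Goal index in [pos, end): next_achieved_goal there is reached after obs[pos].
        goal_pos = rng.integers(pos, end)
        return pos, env, goal_pos
```

Storing only a start and a length per transition is what removes the need for a maximum episode size and for zero padding.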
(episode_idx, trans_idx, env_idx), then you would have to manage a list of indexes of transitions and episode for each environment. It seems pretty complicated. I didn't even try.
Actually, this structure was never used... I agree we should move to (episode_idx, trans_idx)
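For reference, a hedged sketch of the (episode_idx, trans_idx) layout described earlier, where each env writes into its own episode slot and takes the next free slot when its episode ends. The names are hypothetical (not the code of the earlier commits), only observations are stored, and round-robin slot reuse has no protection against overwriting a slot another env is still writing.

```python
import numpy as np


class EpisodeMajorStorage:
    def __init__(self, max_episodes: int, max_episode_length: int, obs_dim: int, n_envs: int):
        self.max_episodes = max_episodes
        self.obs = np.zeros((max_episodes, max_episode_length, obs_dim), dtype=np.float32)
        self.episode_lengths = np.zeros(max_episodes, dtype=np.int64)
        self.current_episode = np.arange(n_envs)  # slot written by each env
        self.current_trans = np.zeros(n_envs, dtype=np.int64)  # position inside that slot
        self.next_free = n_envs  # slots 0..n_envs-1 are already taken

    def add(self, obs: np.ndarray, dones: np.ndarray) -> None:
        for env_idx in range(len(obs)):
            ep, t = self.current_episode[env_idx], self.current_trans[env_idx]
            self.obs[ep, t] = obs[env_idx]
            self.current_trans[env_idx] += 1
            if dones[env_idx]:
                # Close this episode and move the env to the next available slot.
                self.episode_lengths[ep] = t + 1
                self.current_episode[env_idx] = self.next_free % self.max_episodes
                self.next_free += 1
                self.current_trans[env_idx] = 0
```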
This is what I did in my first commits of this PR:
Your first implementation was working, right?
These inconsistencies often require to distinguish the HerReplayBuffer case outside the buffer, which can make its use difficult.
I agree that inconsistencies should be avoided when possible.
it's a bit slower (2x empirically).
2x is not a bit...
use is much easier,
use is easier but it is harder to read, right?
Let me check the current code. We might still reference the more consistent implementation for interested people, but I'm not sure we will keep it, mainly for the two reasons above (computation time and readability).
Your first implementation was working, right?
Yes. It is still possible to build on these old commits.
2x is not a bit...
Agree; I will give a more precise value.
use is easier but it is harder to read, right?
From my point of view, the code is easier to read.
Let me check the current code, we might still reference the more consistent implementation for people interested but I'm not sure we will keep it mainly for the two reasons above (computation time and readability).
👍
Let me check the current code, we might still reference the more consistent implementation for people interested but I'm not sure we will keep it mainly for the two reasons above (computation time and readability).
So in the end, readability is fine, but I think the current implementation is not right (see my note on the "future" sampling strategy and my comment on the subprocess env). The budget in the performance check should not be changed: having to change it is usually an indicator that something is wrong.
Agree; I will give a more precise value.
Please do =)
Looking at the test, the +1 for the future strategy does actually make a big difference? (the performance test was failing before even with almost twice more budget)
Looking at the test, the +1 for the future strategy does actually make a big difference? (the performance test was failing before even with almost twice more budget)
Indeed! I didn't expect it.
![Screenshot 2022-01-13 at 18 41 38](https://user-images.githubusercontent.com/45557362/149381434-ad390f03-e05c-40c1-a35f-f53f43d1c88a.png)
Btw, there is no reason why online sampling should give better results. I only used 10 seeds, which might explain the gap.
Btw, there is no reason that online sampling gives better results.
There is. In fact, in my experience, the two are not equivalent (and the online sampling usually but not always yields better results):
- with online sampling, you choose random episodes and then transitions inside those episodes
- with offline sampling, you choose random transitions in the entire buffer, so you are more likely to sample transitions from longer episodes, I would say
The storage of virtual transitions is also different and I'm not sure if it is equivalent.
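A small worked example of the difference between the two schemes described above, assuming online sampling picks an episode uniformly and then a transition inside it, while offline sampling picks uniformly among all stored transitions. Exact fractions are used so the probabilities can be read directly.

```python
from fractions import Fraction


def per_transition_probs(episode_lengths):
    n_episodes = len(episode_lengths)
    total = sum(episode_lengths)
    # Episode first: P = 1 / n_episodes * 1 / episode_length.
    episode_first = [Fraction(1, n_episodes * length) for length in episode_lengths]
    # Uniform over the buffer: P = 1 / total number of transitions.
    uniform = [Fraction(1, total) for _ in episode_lengths]
    return episode_first, uniform


episode_first, uniform = per_transition_probs([2, 8])
print(episode_first)  # [Fraction(1, 4), Fraction(1, 16)]
print(uniform)  # [Fraction(1, 10), Fraction(1, 10)]
```

With one 2-step and one 8-step episode, the long episode receives half of the samples in the episode-first scheme but 80% of them with uniform sampling, so the two are not equivalent.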
In the old implementation, the environment is an attribute of HerReplayBuffer only to compute the reward. I don't find this ideal, because some people might want a reward function other than the environment's.
So I replaced env by compute_reward in the constructor. It works perfectly when the environment is a DummyVecEnv. But, as @araffin suspected, it does not work when it is a SubprocVecEnv. See the code below, which raises a TypeError that I cannot explain.
```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3 import HerReplayBuffer, DQN


def test_her(n_envs, image_obs_space, vec_env_cls):
    env_fn = lambda: BitFlippingEnv(n_bits=4, image_obs_space=image_obs_space)
    env = make_vec_env(env_fn, n_envs, vec_env_cls=vec_env_cls)
    model = DQN("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, learning_starts=100)
    model.learn(total_timesteps=1000)


if __name__ == "__main__":
    test_her(n_envs=1, image_obs_space=True, vec_env_cls=SubprocVecEnv)
```
As I have very little experience with the multiprocessing library, I am asking for help.
Do we still want to pass compute_reward as an argument?
If so, is it reasonable to handle this in this PR?
But, as @araffin suspected, it does not work when it is a SubprocVecEnv.
yes, because the compute_reward function lives in another process and is not directly accessible from the main one.
Btw, please add a separate test instead of multiplying the number of tests by 2 (as currently done), to save time.
because some people might want a reward function other than the environment one.
Sounds like an uncommon case, I would fork SB3 in that case if needed. Or do you have additional examples where it is needed?
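A hedged sketch of why the env's compute_reward cannot simply be called from the main process with a SubprocVecEnv: the env object only exists in the worker, so the main process has to send the method name and arguments over a pipe and wait for the result. To my understanding, this is the mechanism behind VecEnv.env_method, which the old buffer used for compute_reward. The sketch uses plain multiprocessing with a toy env; ToyGoalEnv, worker and remote_compute_reward are hypothetical, and the "fork" start method assumes a POSIX system.

```python
import multiprocessing as mp


class ToyGoalEnv:
    # Stand-in for a goal env that lives in a worker process (hypothetical).
    def compute_reward(self, achieved_goal, desired_goal, info):
        return 0.0 if achieved_goal == desired_goal else -1.0


def worker(remote) -> None:
    env = ToyGoalEnv()  # only this process holds the env object
    while True:
        cmd, data = remote.recv()
        if cmd == "env_method":
            name, args = data
            remote.send(getattr(env, name)(*args))  # call the method by name
        elif cmd == "close":
            remote.close()
            break


def remote_compute_reward(achieved_goal, desired_goal, info):
    ctx = mp.get_context("fork")  # assumption: POSIX; "spawn" needs an importable module
    parent, child = ctx.Pipe()
    proc = ctx.Process(target=worker, args=(child,), daemon=True)
    proc.start()
    parent.send(("env_method", ("compute_reward", (achieved_goal, desired_goal, info))))
    reward = parent.recv()
    parent.send(("close", None))
    proc.join()
    return reward


if __name__ == "__main__":
    print(remote_compute_reward([1, 0], [1, 0], {}))
```

Only picklable data (method name, goals, info) crosses the process boundary, never the bound method itself.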
yes, because the compute_reward function is in another process, not accessible directly from the main one. Btw, please add a separate test, not multiply the number of test by 2 (as done currently) to save time.
You are right, I will reduce the number of tests. However, if you look at the test logs, you'll see that the failures occur only for image_obs_space=True and vec_env_cls=SubprocVecEnv, which is rather unexpected. So I will at least keep these combinations in the tests.
Sounds like an uncommon case, I would fork SB3 in that case if needed. Or do you have additional examples where it is needed?
Indeed. I find it more intuitive to pass only compute_reward as an argument. But I'll stick to the core of the PR.
Here we compare the old implementation (current master branch) and the new one (this fork).
![Screenshot 2022-01-18 at 13 33 15](https://user-images.githubusercontent.com/45557362/149938246-71e062ad-bccd-43de-af56-ac6f36b5fecf.png)
Results are obtained using 10 random seeds on the BitFlipping environment with 6 bits. The following hyperparameters were used:
| Hyperparameter | Value |
|---|---|
| target_update_interval | 500 (DQN and SAC only) |
| exploration_final_eps | 0.02 (DQN only) |
| n_sampled_goal | 2 |
| learning_rate | 5e-4 |
| train_freq | 1 |
| gradient_steps | 1 |
| learning_starts | 100 |
| batch_size | 32 |
| buffer_size | 1e5 |
| Evaluation | Value |
|---|---|
| total_timesteps | 8000 |
| n_eval_episodes | 50 |
| eval_freq | 400 |
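As a reading aid, the DQN values from the tables above grouped by the constructor argument they would presumably go to. The split (n_sampled_goal as a HerReplayBuffer keyword, the rest as DQN keywords) is an assumption based on SB3's DQN and HerReplayBuffer signatures, and the number of evaluation points assumes a single env, so that eval_freq counts timesteps.

```python
# HerReplayBuffer keyword arguments (assumed split).
her_kwargs = dict(n_sampled_goal=2)

# DQN keyword arguments from the hyperparameter table.
dqn_kwargs = dict(
    learning_rate=5e-4,
    train_freq=1,
    gradient_steps=1,
    learning_starts=100,
    batch_size=32,
    buffer_size=int(1e5),  # passed as an int
    target_update_interval=500,
    exploration_final_eps=0.02,
)

total_timesteps, eval_freq, n_eval_episodes = 8000, 400, 50
n_evaluations = total_timesteps // eval_freq  # 20 points per curve, assuming a single env

# Hypothetical usage (requires stable-baselines3, not executed here):
# model = DQN("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer,
#             replay_buffer_kwargs=her_kwargs, **dqn_kwargs)
# model.learn(total_timesteps)
```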
The curves do not overlap exactly. The new implementation behaves sometimes better, sometimes worse. Now we have to determine the reason. Maybe it is just a statistical effect. I will try to increase the number of tests for the most different configurations.
Maybe it is just a statistical effect.
Most probable yes. I would also do some testing with harder env (cf. the RL Zoo with highway env and then other like pick and place).
But most important right now is to quantify how much slower it is.
Most probable yes.
I am currently running the same experiments with more seeds to definitively answer this question.
But most important right now is to quantify how much slower it is.
Then I'll do that. If it's too slow and you don't want to merge, I'll continue to maintain this branch until a faster version is available.
I would also do some testing with harder env (cf. the RL Zoo with highway env and then other like pick and place).
Finally I'll do that.
With 30 seeds, all settings seem correct except for online sampling with the future goal selection strategy. I will investigate.
![Screenshot 2022-01-19 at 20 59 34](https://user-images.githubusercontent.com/45557362/150204825-38015f52-7887-4097-912d-d0410ba28674.png)
The new implementation is slower. Running trainings with the RL Zoo for 20000 timesteps:

| Algorithm | Environment | Old | New | Relative duration |
|---|---|---|---|---|
| TQC | parking-v0 | 250 | 253 | + 1.2% |
| TQC | FetchReach-v1 | 274 | 378 | + 37% |
| TQC | FetchPush-v1 | 411 | 879 | + 113% |
| TQC | FetchPickAndPlace-v1 | 364 | 876 | + 140% |
| TQC | FetchSlide-v1 | 366 | 869 | + 137% |
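The relative-duration column can be recomputed from the Old and New columns (units as reported in the table) with the short script below, which also prints the slowdown factor.

```python
# (old, new) durations from the table above.
durations = {
    "parking-v0": (250, 253),
    "FetchReach-v1": (274, 378),
    "FetchPush-v1": (411, 879),
    "FetchPickAndPlace-v1": (364, 876),
    "FetchSlide-v1": (366, 869),
}


def relative_increase(old: float, new: float) -> float:
    # Percentage by which the new run is longer than the old one.
    return 100.0 * (new - old) / old


for env_id, (old, new) in durations.items():
    print(f"{env_id}: {relative_increase(old, new):+.1f}% ({new / old:.2f}x)")
```

For the three manipulation tasks this gives a 2.1x to 2.4x slowdown, consistent with the "2x empirically" estimate. The table's percentages look truncated rather than rounded (for example, 104 / 274 ≈ 38.0%).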