stable-baselines3
[Question] DictRolloutBuffer seems very slow for larger observations
❓ Question
I am using PPO with image observations. I profiled my training code today and noticed that a large part of the overall training time (second to the actual rollouts) is taken by the _get_samples method in DictRolloutBuffer, particularly the dict comprehension in buffers.py:773. I can see that a large part of that, but not all, is the to_torch method in buffers.py:124, where the data from the observation is copied from the input np.ndarray to a th.Tensor.
To be specific, my observations contain 3 images with one 8-bit channel at 168x168 resolution, and I have 16 parallel environments in subprocesses. I run this on a desktop with an AMD Ryzen 32-core CPU and an RTX 3090 GPU. The dict comprehension in `buffers.py:773` takes around 70 ms per call, of which 50 ms is the `to_torch` method. This seems quite long to me, especially considering how often this routine is called (after every rollout, for each minibatch in each epoch).
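For reference, the copy cost can be reproduced in isolation with a small benchmark. The shapes below match my observations (3 single-channel 168x168 images); the minibatch size of 64 is an assumption, as is using `th.tensor` to force a copy the way `to_torch` effectively does:

```python
import time

import numpy as np
import torch as th

batch_size = 64  # assumed minibatch size, not stated above
obs = {
    f"cam_{i}": np.random.randint(0, 256, (batch_size, 1, 168, 168), dtype=np.uint8)
    for i in range(3)
}

# Warm-up, then time the same dict-comprehension pattern as _get_samples,
# using th.tensor() to force a copy of each array.
tensors = {k: th.tensor(v) for k, v in obs.items()}
start = time.perf_counter()
tensors = {k: th.tensor(v) for k, v in obs.items()}
print(f"{(time.perf_counter() - start) * 1e3:.3f} ms per minibatch")
```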
In this context I have three questions:
- Are these runtimes normal from what others experienced, and/or is this a known issue?
- Why don't we convert/copy the observations into tensors already in the `add` method, when receiving them (instead of copying all samples `n_epochs` times)? Especially when using a GPU this would be advantageous, since we would also avoid repeatedly moving the data to the GPU.
- Given I am not the only one having this issue, are there other ideas or suggestions on how we can improve the performance of the `DictRolloutBuffer`, particularly the dict comprehension in `buffers.py:773`? I am happy to contribute an improvement if others agree this is an actual issue.
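To make the second question concrete, here is a minimal sketch of what I mean: pre-allocate a torch tensor and copy each observation once inside `add()`. All names here are illustrative, not SB3's actual API:

```python
import numpy as np
import torch as th


class TensorObsBuffer:
    """Toy buffer that stores observations as a torch tensor from the start."""

    def __init__(self, n_steps: int, n_envs: int, obs_shape: tuple, device: str = "cpu"):
        self.obs = th.zeros((n_steps, n_envs, *obs_shape), dtype=th.uint8, device=device)
        self.pos = 0

    def add(self, obs_np: np.ndarray) -> None:
        # One copy (and, on GPU, one host-to-device transfer) per step.
        self.obs[self.pos].copy_(th.from_numpy(obs_np))
        self.pos += 1

    def sample(self, indices: np.ndarray) -> th.Tensor:
        # Indexing a tensor that is already on the device replaces the
        # repeated numpy -> tensor conversion done per minibatch.
        flat = self.obs.reshape(-1, *self.obs.shape[2:])
        return flat[th.as_tensor(indices, dtype=th.long)]


buf = TensorObsBuffer(n_steps=4, n_envs=2, obs_shape=(1, 8, 8))
buf.add(np.ones((2, 1, 8, 8), dtype=np.uint8))
print(buf.sample(np.array([0, 1])).shape)  # -> torch.Size([2, 1, 8, 8])
```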
Thanks everyone for helping!
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
- [X] If code there is, it is minimal and working
- [X] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
> Are these runtimes normal from what others experienced, and/or is this a known issue?

I have already encountered this but never reported it, because I'm not sure it is specifically related to the dict structure. Do you have any evidence that this would work better with the same "large" observations encapsulated in another space, like a `Box`?
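One way to gather that evidence would be a micro-benchmark converting the same data once as a Dict of arrays and once as a single stacked Box-style array (the batch size of 64 is an assumption):

```python
import timeit

import numpy as np
import torch as th

batch_size = 64  # assumed minibatch size
dict_obs = {
    f"img_{i}": np.random.rand(batch_size, 1, 168, 168).astype(np.float32)
    for i in range(3)
}
# The same data flattened into one Box-style array of shape (64, 3, 168, 168).
box_obs = np.concatenate(list(dict_obs.values()), axis=1)

t_dict = timeit.timeit(lambda: {k: th.tensor(v) for k, v in dict_obs.items()}, number=20)
t_box = timeit.timeit(lambda: th.tensor(box_obs), number=20)
print(f"Dict: {t_dict / 20 * 1e3:.2f} ms  Box: {t_box / 20 * 1e3:.2f} ms")
```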
> Why don't we convert/copy the observations into tensors already in the `add` method, when receiving them (instead of copying all samples `n_epochs` times)? Especially when using a GPU this would be advantageous, since we would also avoid repeatedly moving the data to the GPU.
Do you mean storing transitions directly on the device? That would be possible, but it can be a big limitation considering that in most configurations, the GPU memory is much smaller than the CPU memory. I see that the VRAM of the RTX3090 is 24 GB. Would this be enough for your application? I wonder if this is an enhancement to be planned, and what the benefit would be.
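For the observation sizes given above, a back-of-the-envelope estimate suggests the buffer itself would fit in 24 GB (assuming PPO's default `n_steps=2048`, which the question does not state):

```python
# Rollout buffer size for 3 single-channel 168x168 images, 16 envs.
n_steps, n_envs = 2048, 16  # n_steps=2048 is an assumed PPO default
images, h, w = 3, 168, 168

uint8_bytes = n_steps * n_envs * images * h * w  # 1 byte per pixel
float32_bytes = uint8_bytes * 4                  # if stored as float32

print(f"uint8:   {uint8_bytes / 1e9:.2f} GB")    # -> 2.77 GB
print(f"float32: {float32_bytes / 1e9:.2f} GB")  # -> 11.10 GB
```

Of course the activations, gradients, and optimizer state also need VRAM, so the float32 case would be tight.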
> Given I am not the only one having this issue, are there other ideas or suggestions on how we can improve the performance of the `DictRolloutBuffer`, particularly the dict comprehension in `buffers.py:773`? I am happy to contribute an improvement if others agree this is an actual issue.
I agree that this is a point to improve. However, as explained above, I'm not sure it is specific to the dict structure; we would have to check, and the problem may be elsewhere. I'd be happy with a solution if you find one.