
Prioritized experience replay

AlexPasqua opened this issue 1 year ago • 15 comments

Description

Implementation of prioritized replay buffer for DQN. Closes #1242

Motivation and Context

  • [x] I have raised an issue to propose this change (required for new features and bug fixes)

In accordance with #1242

Types of changes

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to change)
  • [x] Documentation (update in the documentation)

Checklist

  • [x] I've read the CONTRIBUTION guide (required)
  • [x] I have updated the changelog accordingly (required).
  • [x] My change requires a change to the documentation.
  • [x] I have updated the tests accordingly (required for a bug fix or a new feature).
  • [ ] I have updated the documentation accordingly.
  • [ ] I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • [ ] I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • [ ] I have reformatted the code using make format (required)
  • [ ] I have checked the codestyle using make check-codestyle and make lint (required)
  • [ ] I have ensured make pytest and make type both pass. (required)
  • [ ] I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line.

AlexPasqua avatar Jul 23 '23 17:07 AlexPasqua
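
For context, prioritized experience replay (Schaul et al., 2015) samples transitions with probability proportional to their TD error (raised to a power alpha) and corrects the resulting bias with importance-sampling weights annealed by beta. Below is a minimal numpy sketch of that idea; the class and method names are illustrative only, this is not the buffer implemented in this PR:

```python
import numpy as np


class TinyPERBuffer:
    """Illustrative proportional PER buffer (O(n) sampling, no sum tree)."""

    def __init__(self, capacity: int, alpha: float = 0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.data = [None] * capacity
        self.pos = 0
        self.size = 0

    def add(self, transition) -> None:
        # New transitions get the current max priority so they are sampled at least once
        max_prio = self.priorities[: self.size].max() if self.size > 0 else 1.0
        self.data[self.pos] = transition
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, beta: float = 0.4):
        probs = self.priorities[: self.size] ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(self.size, batch_size, p=probs)
        # Importance-sampling weights correct the bias introduced by non-uniform sampling
        weights = (self.size * probs[indices]) ** (-beta)
        weights /= weights.max()
        return [self.data[i] for i in indices], indices, weights

    def update_priorities(self, indices, td_errors, eps: float = 1e-6) -> None:
        # Called after the gradient step with the new absolute TD errors
        self.priorities[indices] = np.abs(td_errors) + eps
```

A real implementation replaces the O(n) sampling above with a sum tree (discussed later in this thread) and anneals beta from its initial value towards 1 over training.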

@araffin could you (or anyone) please have a look at the 2 pytype errors? I don't quite understand how to fix them.

AlexPasqua avatar Aug 06 '23 15:08 AlexPasqua

Thanks @araffin! Out of curiosity, may I ask why the switch from torch to numpy for the backend?

AlexPasqua avatar Sep 29 '23 10:09 AlexPasqua

Thanks @araffin! Out of curiosity, may I ask why the switch from torch to numpy for the backend?

To be consistent with the rest of the buffers, and because PyTorch is not needed here (no GPU computation is required).

araffin avatar Sep 29 '23 10:09 araffin

Hello @araffin, since you moved the code to "common", I suppose you plan to make it usable in algorithms other than DQN. At this point, wouldn't it be clearer to put the code into common/buffers.py? Let me know, and if so, I will move it there.

AlexPasqua avatar Sep 30 '23 16:09 AlexPasqua

At this point, wouldn't it be clearer to put the code into common/buffers.py?

Yes, probably, but the most important thing for now is to test the implementation (performance tests, checking that we can reproduce the results from the paper), document it, and add additional tests/docs (for the sum tree, for instance).

araffin avatar Oct 02 '23 16:10 araffin
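
For reference, the "sum tree" (a segment tree over priorities) is what makes proportional sampling O(log n) instead of O(n). A minimal sketch for illustration only; this is not the data structure implemented in this PR or in SB2:

```python
import numpy as np


class SumTree:
    """Binary tree stored in a flat array: leaves hold priorities, internal nodes hold sums."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        # Node 1 is the root, leaves live at indices capacity .. 2 * capacity - 1
        self.tree = np.zeros(2 * capacity)

    def total(self) -> float:
        return self.tree[1]

    def update(self, idx: int, priority: float) -> None:
        # Set the priority of leaf `idx` and propagate the change up to the root
        i = idx + self.capacity
        delta = priority - self.tree[i]
        while i >= 1:
            self.tree[i] += delta
            i //= 2

    def retrieve(self, prefix_sum: float) -> int:
        # Descend from the root, choosing left/right according to the cumulative sums,
        # and return the index of the leaf whose interval contains `prefix_sum`
        i = 1
        while i < self.capacity:
            left = 2 * i
            if prefix_sum <= self.tree[left]:
                i = left
            else:
                prefix_sum -= self.tree[left]
                i = left + 1
        return i - self.capacity
```

Sampling one transition then amounts to `tree.retrieve(np.random.uniform(0, tree.total()))`, which selects each leaf with probability proportional to its priority.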

performance tests, checking that we can reproduce the results from the paper

After some initial tests on Breakout following the hyperparameters from the paper, the runs have neither improved nor worsened DQN performance so far... I will try other envs (it would be nice if you could help).

araffin avatar Oct 04 '23 19:10 araffin

After some initial tests on Breakout following the hyperparameters from the paper, the runs have neither improved nor worsened DQN performance so far... I will try other envs (it would be nice if you could help).

Thanks for starting to test it! These days I'm travelling and also writing a paper after work, but I'll try to squeeze in some tests.

AlexPasqua avatar Oct 05 '23 06:10 AlexPasqua

@araffin I've also done some initial tests, and it looks like PER might lead to slightly faster convergence, for example on CartPole, but nothing conclusive unfortunately. Next I'd like to properly reproduce some of the paper's experiments, but compute could become a bit of an issue for me.

AlexPasqua avatar Nov 02 '23 15:11 AlexPasqua

Just a comment: I've tested this implementation with QR-DQN and a VecEnv with multiple environments, but it fails because multi-env support is still missing.

But good job starting the work on it! I hope it will be merged soon! 👍

richardjozsa avatar Nov 30 '23 22:11 richardjozsa

I've just tried validating the implementation on Blind Cliffwalk, and it seems much slower (~an order of magnitude) than the uniform replay buffer. The results below are for a single seed: [screenshot: single-seed comparison plot]

Not sure why this is. The details for Blind Cliffwalk in the paper are a bit vague (and there is no code available either), but I've tried to implement it as closely to the description as possible.

Code for the test is in this gist: https://gist.github.com/jbial/105299c00dc3bb7960f0f17f2fc4d6c9

jbial avatar May 28 '24 03:05 jbial
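
For readers unfamiliar with the benchmark: Blind Cliffwalk is a chain of n states with two actions, where only the "correct" action advances and any mistake terminates the episode; the single non-zero reward comes from completing the whole chain. A rough gymnasium sketch under those assumptions (this is one reading of the paper, not the code from the gist above; in particular, which action is correct in each state is an assumption):

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces


class BlindCliffwalk(gym.Env):
    """Toy chain environment with sparse reward, loosely following Schaul et al. (2015)."""

    def __init__(self, n_states: int = 16):
        super().__init__()
        self.n_states = n_states
        # One-hot observation of the current position (index n_states = goal)
        self.observation_space = spaces.Box(0.0, 1.0, shape=(n_states + 1,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)
        self.state = 0

    def _obs(self) -> np.ndarray:
        obs = np.zeros(self.n_states + 1, dtype=np.float32)
        obs[self.state] = 1.0
        return obs

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.state = 0
        return self._obs(), {}

    def step(self, action):
        # Assumption: the correct action alternates with the state index
        correct = self.state % 2
        if action != correct:
            # Wrong action: fall off the cliff, episode ends with no reward
            return self._obs(), 0.0, True, False, {}
        self.state += 1
        terminated = self.state == self.n_states
        reward = 1.0 if terminated else 0.0
        return self._obs(), reward, terminated, False, {}
```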

An update from my side: I just added CNN support to SBX (SB3 + Jax) DQN, and it is 10x faster than the PyTorch equivalent: https://github.com/araffin/sbx/pull/49

That should allow us to test and debug things more quickly on Atari (~1h40 for 10M steps instead of 15h =D)

Perf report: https://wandb.ai/openrlbenchmark/sbx?nw=nwuseraraffin (ongoing)

araffin avatar Jul 07 '24 18:07 araffin

An additional update: when trying to plug this PR's PER implementation into the Jax DQN implementation, the experience replay was the bottleneck (by a good margin, making things 40x slower...), so I investigated different ways to speed things up.

After playing with many different implementations (pure Python, numpy, jax, jitted jax, ...), I decided to re-use the vectorized SB2 "SegmentTree" implementation and also to implement proper multi-env support. My current progress is here: https://github.com/araffin/sbx/pull/50

(still debugging, but at least I've got the first signs of life, and this implementation is much faster)

araffin avatar Jul 17 '24 07:07 araffin

Hey @araffin, it is great to hear that. Does SBX/Jax really mean that much of a speed improvement?

If you think it is ready for testing, I can give it a try; just let me know when it is ready. :)

richardjozsa avatar Jul 17 '24 17:07 richardjozsa

Does SBX/Jax really mean that much of a speed improvement?

With the right parameters (see the exact command line arguments for the RL Zoo in the OpenRL Benchmark organization runs on W&B), yes, around 10x faster.

If you think it is ready for testing, I can give it a try; just let me know when it is ready. :)

The SBX version is ready to be tested, but so far I haven't managed to see any gain from PER. I also experienced some explosion in the Q-function values when using multiple envs (so there is probably a bug here). I'm also wondering if I need to implement double Q-learning (easy) too, to compare to the original paper.

araffin avatar Jul 17 '24 19:07 araffin
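
For reference, the double Q-learning change araffin mentions only affects how the bootstrap target is computed: the online network selects the greedy action and the target network evaluates it. A minimal PyTorch sketch (the function and argument names are placeholders, not SB3's actual API):

```python
import torch as th


def double_dqn_target(q_net, q_net_target, next_obs, rewards, dones, gamma=0.99):
    """Compute r + gamma * Q_target(s', argmax_a Q_online(s', a)) for non-terminal transitions."""
    with th.no_grad():
        # Action selection with the online network
        next_actions = q_net(next_obs).argmax(dim=1, keepdim=True)
        # Action evaluation with the target network
        next_q = th.gather(q_net_target(next_obs), dim=1, index=next_actions).squeeze(1)
        # `dones` is assumed to be a float tensor of 0/1 flags
        return rewards + (1.0 - dones) * gamma * next_q
```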

When I tested this PR I also noticed an explosion in the loss; at the time I thought it was because of my tweaking here and there. I also noticed that it doesn't give me any advantage over a normal buffer (and I used Double DQN, I even tried dueling), but I did try an N-step buffer, which had a strong effect on the learning. AFAIK N-step (multi-step) returns are also part of Rainbow and contribute a substantial part of its success.

As far as I understand the concept, the key parts are the distributional, PER, and N-step components. The others are rather task-specific and can even be detrimental to use.

richardjozsa avatar Jul 17 '24 20:07 richardjozsa
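
For reference, the N-step (multi-step) return richardjozsa mentions replaces the one-step bootstrap with a discounted sum over the next n rewards, bootstrapping only after the n-th step. A minimal sketch of the target computation (the function name and signature are illustrative):

```python
def n_step_target(rewards, dones, bootstrap_value, gamma=0.99):
    """Discounted sum of up to n rewards, stopping at a terminal transition,
    plus the discounted bootstrap value if no terminal was reached."""
    target, discount = 0.0, 1.0
    for reward, done in zip(rewards, dones):
        target += discount * reward
        if done:
            return target  # never bootstrap past the end of an episode
        discount *= gamma
    return target + discount * bootstrap_value
```

With n=1 this reduces to the usual one-step DQN target; larger n propagates the sparse reward signal faster, which matches the strong effect on learning described above.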