
Jax c51 contrib

Open kinalmehta opened this issue 2 years ago • 2 comments

Description

JAX implementation of C51. Implements #221.

Types of changes

  • [ ] Bug fix
  • [ ] New feature
  • [x] New algorithm
  • [ ] Documentation

Checklist:

  • [x] I've read the CONTRIBUTION guide (required).
  • [x] I have ensured pre-commit run --all-files passes (required).
  • [ ] I have updated the documentation and previewed the changes via mkdocs serve.
  • [ ] I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

  • [x] I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • [ ] I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • [ ] I have added additional documentation and previewed the changes via mkdocs serve.
    • [ ] I have explained note-worthy implementation details.
    • [ ] I have explained the logged metrics.
    • [ ] I have added links to the original paper and related papers (if applicable).
    • [ ] I have added links to the PR related to the algorithm.
    • [ ] I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • [ ] I have added the learning curves (in PNG format with width=500 and height=300).
    • [ ] I have added links to the tracked experiments.
  • [ ] I have updated the tests accordingly (if applicable).

kinalmehta avatar Jun 29 '22 16:06 kinalmehta


Results on the classical gym environments can be checked here: https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Classical-Gym-Environments--VmlldzoyNDQ3OTk5

We see a speed-up of about 30% in the JAX version compared to PyTorch.

kinalmehta avatar Sep 01 '22 17:09 kinalmehta

Here is the benchmark report on Atari environments: https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Atari-Environments--VmlldzoyNjkyNDY0

Important observations:

  • BeamRider performance is worse than the PyTorch version.
  • Breakout performance almost matches the PyTorch variant but is still slightly lower.
  • For Pong, the performance matches for 2 seeds, but the reward remains zero for one of the seeds.

I need to look in more detail at the differences between the PyTorch and JAX implementations to fix the issues mentioned above.

kinalmehta avatar Sep 26 '22 03:09 kinalmehta

How does it compare to Dopamine's version?

joaogui1 avatar Sep 26 '22 22:09 joaogui1

> How does it compare to Dopamine's version?

I haven't checked Dopamine yet. I will have a look and update here, though it might take some time.

kinalmehta avatar Sep 27 '22 03:09 kinalmehta

FYI Dopamine has a benchmark, but its x-axis is not the environment steps... Any clue on how we can compare those results? @joaogui1

[image]

vwxyzjn avatar Sep 27 '22 19:09 vwxyzjn

Atari Fixed

After months of procrastination and debugging various aspects, I finally stumbled upon the cause of the performance degradation: an incorrect optimizer epsilon value. I missed this detail and used the default value of $10^{-8}$ from optax, whereas the C51 PyTorch version uses $0.01 / \text{batch\_size}$. However, I couldn't find any stated motivation for using this value.

Reading up more on this led to the conclusion that sensitivity to this hyperparameter is a common issue in NLP and CV as well. More about this hyperparameter can be read here.
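
To make the effect of epsilon concrete, here is a minimal sketch in plain Python. It is not the code from this PR. It assumes the optimizer is Adam with the usual default betas, and the batch size of 32 and the gradient magnitudes are made-up illustrative values. It computes the size of the first bias-corrected Adam update for a scalar gradient under both epsilon settings.

```python
import math


def adam_first_step(grad, lr, eps, b1=0.9, b2=0.999):
    """Size of the first bias-corrected Adam update for a scalar gradient.

    Illustrative helper (hypothetical name), following the standard Adam rule:
    update = lr * m_hat / (sqrt(v_hat) + eps).
    """
    m = (1 - b1) * grad          # first-moment estimate after one step
    v = (1 - b2) * grad ** 2     # second-moment estimate after one step
    m_hat = m / (1 - b1)         # bias correction at step t = 1
    v_hat = v / (1 - b2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)


batch_size = 32                  # assumed value, for illustration only
eps_default = 1e-8               # optax default mentioned above
eps_pytorch = 0.01 / batch_size  # value used by the C51 PyTorch version

# With lr = 1.0 the result is the fraction of the learning rate actually applied.
for grad in (1e-1, 1e-3, 1e-4):
    default_step = adam_first_step(grad, lr=1.0, eps=eps_default)
    pytorch_step = adam_first_step(grad, lr=1.0, eps=eps_pytorch)
    print(f"grad={grad:g}  eps=1e-8: {default_step:.4f}  eps=0.01/bs: {pytorch_step:.4f}")
```

With the tiny default epsilon, the update stays close to the full learning rate regardless of gradient magnitude, while the larger epsilon shrinks updates once gradients become comparable to it. In optax, the same value would presumably be passed through the `eps` argument, e.g. `optax.adam(learning_rate, eps=0.01 / batch_size)`.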

Benchmarking classical envs on CPU

I have updated the plots for the classical gym environments (CartPole, Acrobot, MountainCar) by benchmarking on CPU. We see a significant speed-up compared to the PyTorch version on CPU.

Comparison with dopamine

Based on the BeamRider plot shared above, the table below summarizes the final scores:

| implementation | score |
| --- | --- |
| dopamine | 5000-7000 |
| cleanrl-pytorch | ~9500 |
| cleanrl-jax-old | ~2500 |
| cleanrl-jax-fixed | ~9500 |

Reports link

  • https://wandb.ai/kinalmehta/cleanrl/reports/Regression-Report-c51_atari_jax--VmlldzozMjM1Mzc1
  • https://wandb.ai/kinalmehta/cleanrl/reports/Regression-Report-c51_jax--VmlldzozMjM1MzM0

Conclusion

The updated plots are available at the links above. The PR looks good to be merged once the documentation is updated. Is there anything else I am missing here, @vwxyzjn?

kinalmehta avatar Dec 29 '22 20:12 kinalmehta

The results look incredible. Great job @kinalmehta. Thanks for chasing down the cause of the issue. The code also looks great to me. Feel free to start adding documentation. You should also move the experiments to the openrlbenchmark/cleanrl namespace.

vwxyzjn avatar Dec 29 '22 22:12 vwxyzjn

I've added the documentation, and now I believe this PR is ready for the final review.

kinalmehta avatar Dec 30 '22 11:12 kinalmehta