Parallel-envs-friendly ppo_continuous_action.py
Description
This PR modifies ppo_continuous_action.py to make it more parallel-envs-friendly. CC @kevinzakka.
The version of ppo_continuous_action.py in this PR differs from the one in the master branch in the following ways:
- use a different set of hyperparameters that leverages more parallel simulation environments (e.g., 64) https://github.com/vwxyzjn/cleanrl/blob/703cd3ba1214a15d2fc6ce9157f8c094d627c07b/cleanrl/ppo_continuous_action.py#L37-L71
- use gym.vector.AsyncVectorEnv in favor of gym.vector.SyncVectorEnv to speed things up further https://github.com/vwxyzjn/cleanrl/blob/703cd3ba1214a15d2fc6ce9157f8c094d627c07b/cleanrl/ppo_continuous_action.py#L163-L165
- apply the normalize wrappers at the parallel-envs level instead of the individual-env level, meaning the running mean and std for the observations and returns are calculated over the whole batch of observations and rewards. In my experience, this is usually preferable to maintaining the normalize wrappers in each sub-env; when N=1, it should not cause any performance difference (see the sketch after this list). https://github.com/vwxyzjn/cleanrl/blob/703cd3ba1214a15d2fc6ce9157f8c094d627c07b/cleanrl/ppo_continuous_action.py#L166-L170
- one thing that would be worth trying is removing the normalize wrappers entirely, which should improve SPS. In the case of JAX, rewriting and jitting the normalize wrappers might improve SPS as well (see the JAX sketch below).
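
For concreteness, here is a minimal sketch of the env setup and wrapper ordering described above (based on the lines linked in the bullets; the env id, seeding, and exact wrapper signatures are illustrative and may differ slightly across gym versions):

```python
import gym
import numpy as np


def make_env(env_id):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ClipAction(env)
        return env

    return thunk


# more parallel environments than the master-branch default
num_envs = 64
envs = gym.vector.AsyncVectorEnv([make_env("HalfCheetah-v4") for _ in range(num_envs)])

# Normalization is applied to the *vector* env, so the running mean/std of the
# observations and returns are estimated from the whole batch at once rather
# than per sub-env.
envs = gym.wrappers.NormalizeObservation(envs)
envs = gym.wrappers.TransformObservation(envs, lambda obs: np.clip(obs, -10, 10))
envs = gym.wrappers.NormalizeReward(envs, gamma=0.99)
envs = gym.wrappers.TransformReward(envs, lambda r: np.clip(r, -10, 10))
```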
I also added a JAX variant that reaches the same level of performance.
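
On the jitting idea from the last bullet, here is a minimal JAX sketch of what a jitted batch-level observation normalizer could look like; RunningStats and normalize are hypothetical names, not part of this PR:

```python
import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class RunningStats:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: jnp.ndarray


@jax.jit
def normalize(stats: RunningStats, obs: jnp.ndarray):
    # Chan et al. parallel update of the running mean/var over the whole
    # batch of observations (all sub-envs at once), then normalize.
    batch_mean = obs.mean(axis=0)
    batch_var = obs.var(axis=0)
    batch_count = obs.shape[0]
    delta = batch_mean - stats.mean
    total = stats.count + batch_count
    new_mean = stats.mean + delta * batch_count / total
    m2 = stats.var * stats.count + batch_var * batch_count
    m2 = m2 + delta**2 * stats.count * batch_count / total
    new_stats = RunningStats(mean=new_mean, var=m2 / total, count=total)
    return new_stats, (obs - new_mean) / jnp.sqrt(new_stats.var + 1e-8)
```

Initialized with, e.g., RunningStats(jnp.zeros(obs_dim), jnp.ones(obs_dim), jnp.array(1e-4)), this keeps the running statistics inside the jitted update instead of round-tripping through the numpy-based wrappers.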

Types of changes
- [ ] Bug fix
- [x] New feature
- [ ] New algorithm
- [ ] Documentation
Checklist:
- [ ] I've read the CONTRIBUTION guide (required).
- [ ] I have ensured pre-commit run --all-files passes (required).
- [ ] I have updated the documentation and previewed the changes via mkdocs serve.
- [ ] I have updated the tests accordingly (if applicable).
If you are adding new algorithm variants or your change could result in a performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.
- [ ] I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
- [ ] I have tracked applicable experiments in openrlbenchmark/cleanrl with the --capture-video flag toggled on (required).
- [ ] I have added additional documentation and previewed the changes via mkdocs serve.
- [ ] I have explained note-worthy implementation details.
- [ ] I have explained the logged metrics.
- [ ] I have added links to the original paper and related papers (if applicable).
- [ ] I have added links to the PR related to the algorithm variant.
- [ ] I have created a table comparing my results against those from reputable sources (e.g., the original paper or other reference implementations).
- [ ] I have added the learning curves (in PNG format).
- [ ] I have added links to the tracked experiments.
- [ ] I have updated the overview sections in the docs and the repo.
- [ ] I have updated the tests accordingly (if applicable).
The latest updates on your projects. Learn more about Vercel for Git ↗︎
| Name | Status | Preview | Updated |
| --- | --- | --- | --- |
| cleanrl | ✅ Ready (Inspect) | Visit Preview | Jan 13, 2023 at 2:25PM (UTC) |
Thank you @vwxyzjn! I'll give this a spin.
Hello. Thanks a lot for implementing PPO in JAX in such a clean fashion. However, while reproducing the results, I am facing the following issue:
```
Traceback (most recent call last):
  File "/scratch/vaidya/mujoco_sims/gym_mujoco_drones/gym_mujoco_drones/cleanrl_jax_ppo.py", line 199, in <module>
    agent_state = TrainState.create(
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/flax/training/train_state.py", line 127, in create
    params['params'] if OVERWRITE_WITH_GRADIENT in params else params
TypeError: argument of type 'AgentParams' is not iterable
Exception ignored in: <function AsyncVectorEnv.__del__ at 0x7f6aa6d89630>
Traceback (most recent call last):
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/async_vector_env.py", line 549, in __del__
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/vector_env.py", line 272, in close
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/async_vector_env.py", line 465, in close_extras
AttributeError: 'NoneType' object has no attribute 'TimeoutError'
```
Since I am currently new to JAX, I am unable to debug the issue of AgentParams not being iterable on my own. I understand that this is a work in progress, but I would appreciate any pointers to solve this.
Thanks!
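
For anyone hitting the same traceback: the TypeError comes from Flax's TrainState.create probing OVERWRITE_WITH_GRADIENT in params, a check added in newer Flax releases that assumes params is dict-like, while the script stores its parameters in a custom AgentParams flax.struct dataclass, which does not support the in operator. Below is a hedged sketch of one possible workaround, assuming the rest of the script is adapted to index a plain dict instead of the dataclass; the field names are illustrative, not from the PR. Pinning Flax to the version the PR was written against should also avoid the check entirely.

```python
import optax
from flax.training.train_state import TrainState

# network_params / actor_params / critic_params come from the script's usual
# module.init(...) calls. A plain dict is dict-like, so the
# `OVERWRITE_WITH_GRADIENT in params` membership test no longer raises.
agent_state = TrainState.create(
    apply_fn=None,
    params={"network": network_params, "actor": actor_params, "critic": critic_params},
    tx=optax.adam(learning_rate=3e-4),
)
```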