torchtune
PPO Performance Improvements
Closes #1425
This PR provides various performance improvements to our PPO single device recipe.
| Branch | Total training time (hours)* | Peak memory allocated (GB) |
|---|---|---|
| Main | 13.1 | 69.6 |
| This branch | 5.4 | 69.5 |
| This branch + compile | 4.6 | 68.6 |
*The models were trained over approx. 37M tokens (~65k samples with `max_seq_len=512`) on a single A100 GPU.
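For a rough sense of throughput, the table works out as follows (a back-of-the-envelope calculation from the numbers above; actual per-step token counts vary):

```python
# Rough throughput implied by the benchmark table: ~37M tokens per run.
TOKENS = 37e6

def tokens_per_sec(hours: float) -> float:
    """Average training throughput given total wall-clock hours."""
    return TOKENS / (hours * 3600)

for branch, hours in [
    ("main", 13.1),
    ("this branch", 5.4),
    ("this branch + compile", 4.6),
]:
    # Speedup is measured relative to main's 13.1-hour baseline.
    print(f"{branch}: ~{tokens_per_sec(hours):.0f} tok/s ({13.1 / hours:.2f}x vs main)")
```

So the changes here amount to roughly a 2.4x speedup, and ~2.8x with compile enabled.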
Changelog:
- KV-caching is now supported during trajectory generation, which significantly speeds up training.
- `generation.generate` now only returns logits over the generated tokens rather than the whole sequence, which significantly reduces peak memory usage. Tests have been updated.
- Added profiler support to the recipe.
- Various changes to trajectory/reward estimation that improve performance.
- Added `parents=True` to `output_dir.mkdir` in our checkpointers. We use nested checkpoint folders for PPO, e.g. `output_dir/policy/`, `output_dir/value/`.
- The addition of various performance improvements in main since the original baseline means we can bump the default batch size in the configs.
- Compile support. We have two options here:
- Compile the trajectory estimation functions separately - minimizes recompiles but results in a small warmup overhead.
- Compile each model using `training.compile_model` - this results in ~10 recompile warnings, which means we need to increase the compile cache size limit. I've added `torch._dynamo.config.cache_size_limit = 16` at the top of the recipe.
I landed on option 2 - it's similar to how we integrate compile with the rest of our recipes, and it eliminates the small warmup overhead. To fully realize compile speedups, it's recommended to do a small warm-up run of the recipe with compile enabled.