
PPO Performance Improvements

Open SalmanMohammadi opened this issue 3 months ago • 3 comments

Closes #1425

This PR provides several performance improvements to our single-device PPO recipe.

Branch                | Total training time (hours)* | Peak memory allocated (GB)
Main                  | 13.1                         | 69.6
This branch           | 5.4                          | 69.5
This branch + compile | 4.6                          | 68.6

*The models were trained over approx. 37M tokens (~65k samples w/max_seq_len=512) on a single A100 GPU.
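
As a rough reading of the table above, the speedups and approximate throughput can be worked out directly from the reported numbers. This is only a back-of-the-envelope sketch: it assumes the ~37M token figure applies equally to all three runs, and the variable names are illustrative.

```python
# Reported total training times in hours, copied from the table above.
hours = {
    "Main": 13.1,
    "This branch": 5.4,
    "This branch + compile": 4.6,
}

# Approximate token count from the footnote (assumed identical across runs).
total_tokens = 37e6

speedups = {}
tokens_per_second = {}
for branch, h in hours.items():
    # Speedup relative to the Main baseline.
    speedups[branch] = hours["Main"] / h
    # Average throughput over the whole run.
    tokens_per_second[branch] = total_tokens / (h * 3600)

for branch in hours:
    print(f"{branch}: {speedups[branch]:.2f}x, ~{tokens_per_second[branch]:.0f} tokens/s")
```

By this estimate the branch is roughly 2.4x faster than main, and roughly 2.8x faster with compile enabled.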

[image] Due to the non-determinism of the training process, the curves may look slightly different.

Changelog:

  • KV-caching is now supported during trajectory generation - this significantly speeds up training.
  • generation.generate now returns logits only for the generated tokens rather than for the whole sequence, which significantly reduces peak memory usage. Tests have been updated.
  • Added profiler support to the recipe.
  • Various changes in trajectory estimation/reward estimation which improve performance.
  • Added parents=True to output_dir.mkdir in our checkpointers. We use nested checkpoint folders for PPO, e.g. output_dir/policy/, output_dir/value/.
  • The addition of various performance improvements in main since the original baseline means we can bump the default batch size in the configs.
  • Compile support. We have two options here:
    1. Compile the trajectory estimation functions separately - minimizes recompiles but results in a small warmup overhead.
    2. Compile each model using training.compile_model - this results in ~10 recompile warnings, which means we need to increase the compile cache size limit - I've added torch._dynamo.config.cache_size_limit = 16 at the top of the recipe.
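
To illustrate why the nested checkpoint folders mentioned in the changelog need parents=True, here is a minimal standalone sketch using pathlib. The folder names mirror the output_dir/policy/ and output_dir/value/ layout above; the temporary directory and the use of exist_ok=True are assumptions for the example, not necessarily what the checkpointers do.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # output_dir itself does not exist yet, only its parent (tmp) does.
    output_dir = Path(tmp) / "output_dir"
    policy_dir = output_dir / "policy"
    value_dir = output_dir / "value"

    # Without parents=True, mkdir fails because output_dir is missing.
    try:
        policy_dir.mkdir(exist_ok=True)
        failed_without_parents = False
    except FileNotFoundError:
        failed_without_parents = True

    # With parents=True, the missing intermediate directory is created too.
    for d in (policy_dir, value_dir):
        d.mkdir(parents=True, exist_ok=True)

    created = sorted(p.name for p in output_dir.iterdir())

print(failed_without_parents, created)
```

The first mkdir call raises FileNotFoundError because the intermediate directory is missing, while the parents=True calls create both the parent and the nested folder.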

I landed on option 2 - it's similar to how we integrate compile with the rest of our recipes, and it eliminates the small warmup overhead. To fully realize the compile speedups, it's recommended to do a short warm-up run of the recipe with compile enabled.
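
The cache-limit setting and the warm-up idea can be sketched in isolation as follows. This is not the recipe code: the Linear module is a hypothetical stand-in for the policy or value model, plain torch.compile is used instead of training.compile_model, and backend="eager" is chosen only so the sketch runs without a compiler toolchain.

```python
import torch
import torch._dynamo

# Raise the recompile cache limit before any model is compiled.
# 16 is the value quoted in this PR.
torch._dynamo.config.cache_size_limit = 16

# Hypothetical stand-in for one of the PPO models.
model = torch.nn.Linear(8, 4)

# Assumption: backend="eager" skips code generation so the example stays
# lightweight; a real run would use the default backend.
compiled_model = torch.compile(model, backend="eager")

# Warm-up call: the first invocation pays the tracing cost, so later
# calls with the same input shapes reuse the cached graph.
with torch.no_grad():
    warmup_out = compiled_model(torch.randn(2, 8))

print(warmup_out.shape)
```

Setting the limit at the top of the script, before the first compiled call, means it is already in effect when the recompiles start accumulating.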

SalmanMohammadi avatar Nov 25 '24 15:11 SalmanMohammadi