
Add Selective Activation Checkpointing

Open · lessw2020 opened this issue 4 months ago · 1 comment

Context

This PR updates activation checkpointing (AC) to support selective layer and selective op activation checkpointing, while preserving the previously available options of full or none.
This is controlled in the YAML config via:

```yaml
enable_activation_checkpointing: bool
ac_mode: ['full', 'selective']
ac_option: [int, 'op']
```

If `ac_mode` is `selective`, the type of selectivity is determined by `ac_option`.
An integer n means checkpoint every nth layer (i.e. 2 = checkpoint every other layer, 3 = every third, etc.). `'op'` means run selective op AC, where checkpointing is filtered by the op policy. A sketch of the layer-selection logic is shown below.
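
For illustration, here is a minimal sketch of the every-nth-layer selection logic described above. This is not the exact torchtune implementation; it assumes the model exposes its transformer blocks as `model.layers` (an `nn.ModuleList`), as in typical Llama-style models:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def apply_activation_checkpointing(model: nn.Module, ac_mode: str, ac_option) -> None:
    """Hypothetical helper wrapping transformer blocks with AC."""
    for idx, block in enumerate(model.layers):
        if ac_mode == "full":
            # Previous behavior: checkpoint every layer.
            model.layers[idx] = checkpoint_wrapper(block)
        elif ac_mode == "selective" and isinstance(ac_option, int):
            # Checkpoint every ac_option-th layer,
            # e.g. ac_option=2 checkpoints every other layer.
            if idx % ac_option == 0:
                model.layers[idx] = checkpoint_wrapper(block)
        elif ac_mode == "selective" and ac_option == "op":
            # Selective op AC would instead pass a policy/context_fn to
            # torch.utils.checkpoint so only ops matching the policy
            # (e.g. matmuls) are saved vs. recomputed; omitted here.
            raise NotImplementedError("op policy not shown in this sketch")
```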

As a general data point, on Llama-13B, selective AC 2 (checkpoint every other layer) improved throughput by about 10% over no AC.

I updated the testing for Llama3-8B, adjusting the batch size under each setting to hit around 91 GB of GPU memory. I used 8 GPUs so that model params would have less impact, allowing finer-grained tuning of the batch size and thus of the activations. This is not always perfectly possible since activations are chunky, but the net result was that selective AC 3 had the highest throughput, followed by no AC. Selective AC 3 gave about 9% better throughput than the original full-only option.
For A100s, 4090s, etc., the actual best combination will vary, but the point here is that selective AC generally provides better throughput options than the simple binary of full (True/False). An example config is shown below.
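
For reference, the best-performing setting from the 8B runs above would look something like this in a recipe config (field names as described in this PR; exact placement in a given recipe YAML may differ):

```yaml
enable_activation_checkpointing: True
ac_mode: selective
ac_option: 3  # checkpoint every third layer
```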

[Screenshot (2024-04-19): throughput/memory comparison across the AC settings discussed above]

Changelog

  • ...

Test plan

This code is largely a port from the original source in torchtitan, where it has already been tested. Still, I ran all four styles (none, full, selective AC op, selective AC 2) as shown above. I also verified that the new implementation of full AC matches the memory savings of the previous full implementation.
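
One simple way to compare peak memory between the old and new full implementations (my own sketch, not necessarily the harness used here) is PyTorch's peak-memory counters:

```python
import torch
import torch.nn as nn

def peak_memory_gib(model: nn.Module, batch: torch.Tensor) -> float:
    """Run one forward/backward pass and report peak CUDA memory in GiB."""
    torch.cuda.reset_peak_memory_stats()
    loss = model(batch).sum()  # placeholder loss for illustration
    loss.backward()
    return torch.cuda.max_memory_allocated() / 2**30
```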

lessw2020 · Apr 17 '24

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/785

Note: Links to docs will display an error until the docs builds have been completed.

No Failures

As of commit 1412378b50c8b84da626bc06e22b7af40acc3466 with merge base a79554e12ec49100ca3d821846d180d46adb4c2b: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Apr 17 '24