Add Selective Activation Checkpointing
Context
This PR updates activation checkpointing (AC) to support selective layer AC and selective op AC.
It preserves the previous options of full AC or no AC (None).
This is controlled in the YAML config via:
enable_activation_checkpointing: bool
ac_mode: ['full', 'selective']
ac_option: [int, 'op']
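The three fields combine into one AC strategy. The sketch below shows one way they could be interpreted. The ACConfig class and validate_ac_config function are hypothetical and not part of torchtune. Two behaviors are also assumptions: ac_mode defaulting to 'full' (to mirror the old boolean behavior), and ac_option being required to be a positive int.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ACConfig:
    # Hypothetical container mirroring the YAML keys above (not a torchtune class).
    enable_activation_checkpointing: bool = False
    ac_mode: str = "full"  # assumption: 'full' by default, matching the old on/off flag
    ac_option: Optional[Union[int, str]] = None  # int (every x-th layer) or 'op'


def validate_ac_config(cfg: ACConfig) -> str:
    """Return a short label for the AC strategy, or raise ValueError if invalid."""
    if not cfg.enable_activation_checkpointing:
        return "none"
    if cfg.ac_mode == "full":
        return "full"
    if cfg.ac_mode != "selective":
        raise ValueError(f"ac_mode must be 'full' or 'selective', got {cfg.ac_mode!r}")
    # ac_option only matters in selective mode. bool is a subclass of int, so exclude it.
    if isinstance(cfg.ac_option, int) and not isinstance(cfg.ac_option, bool):
        if cfg.ac_option < 1:
            raise ValueError("an integer ac_option must be >= 1")
        return f"selective-layer-every-{cfg.ac_option}"
    if cfg.ac_option == "op":
        return "selective-op"
    raise ValueError(f"ac_option must be a positive int or 'op', got {cfg.ac_option!r}")
```

For example, ACConfig(True, "selective", 2) maps to checkpointing every other layer. ACConfig(True, "selective", "op") maps to selective op AC.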
If ac_mode is 'selective', the type of selective AC is determined by ac_option.
An integer x means checkpointing every x-th layer (e.g., 2 = every other layer, 3 = every third layer, etc.).
'op' means running selective op AC, where the ops whose activations are saved versus recomputed are chosen by an op policy.
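The two selective modes can be illustrated in plain Python. layers_to_checkpoint shows which layers an integer ac_option selects, counting layers from 1, so 2 wraps the 2nd, 4th, and so on. op_policy is a toy model of an op policy that makes a save-or-recompute decision for each op in a forward trace.

Both functions and the op names ("mm", "sdpa_flash") are hypothetical. The save list and the "recompute every second matmul" rule are loosely modeled on the kind of policy used in torchtitan, to my understanding. They are assumptions, not necessarily the exact policy ported in this PR. A real implementation would operate on PyTorch ops inside the checkpoint machinery rather than on strings.

```python
from typing import List


def layers_to_checkpoint(num_layers: int, every: int) -> List[int]:
    """0-based indices of layers wrapped for AC when checkpointing every `every`-th layer.

    Layers are counted from 1, so every=2 picks the 2nd, 4th, ... layer.
    every=1 picks every layer, which is equivalent to full AC.
    """
    if every < 1:
        raise ValueError("every must be >= 1")
    return [i for i in range(num_layers) if (i + 1) % every == 0]


# Hypothetical save list: outputs of these (expensive) ops are kept in memory.
SAVE_LIST = {"mm", "sdpa_flash"}


def op_policy(op_trace: List[str]) -> List[bool]:
    """For each op in a forward trace, True = save its output, False = recompute it."""
    decisions = []
    mm_count = 0
    for op in op_trace:
        if op == "mm":
            mm_count += 1
        # Save ops on the save list, except every second matmul, which is recomputed
        # to trade some extra compute for lower activation memory.
        save = op in SAVE_LIST and not (op == "mm" and mm_count % 2 == 0)
        decisions.append(save)
    return decisions
```

For example, layers_to_checkpoint(8, 3) selects indices 2 and 5. In the trace ["mm", "relu", "mm", "sdpa_flash", "mm"], the cheap relu and the second matmul are recomputed and the rest are saved.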
On llama-13B, selective AC 2 (every other layer) improved throughput by 10% over no AC.
I updated the testing for llama3-8B, adjusting the batch size under each setting to reach roughly 91 GB of memory usage. I used 8 GPUs to reduce the impact of model parameters on memory, which allows finer-grained tuning of the batch size and thus of activation memory. A precise match is not always possible because activation memory grows in coarse steps. Even so, the net result was that selective AC 3 had the highest throughput, followed by no AC. Selective AC 3 delivered 9% higher throughput than the original full-only option.
On A100s, 4090s, etc., the best combination will vary. The point is that selective AC generally offers better throughput options than the simple binary choice of full AC on or off.
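Selective layer AC turns the on/off switch into a dial. The back-of-the-envelope sketch below makes this concrete for a 32-layer model such as llama3-8B. It counts how many layers are checkpointed for a given ac_option and the fraction of the forward pass that gets recomputed during backward.

The model behind it is a deliberate simplification and an assumption: every layer costs the same, and recomputing a checkpointed layer costs one extra forward of that layer. ac_tradeoff is a hypothetical helper, and the numbers are not measurements.

```python
from typing import Tuple


def ac_tradeoff(num_layers: int, every: int) -> Tuple[int, float]:
    """Return (layers checkpointed, fraction of the forward pass recomputed).

    Assumes uniform per-layer cost and that checkpointing every `every`-th
    layer wraps num_layers // every layers.
    """
    if every < 1:
        raise ValueError("every must be >= 1")
    checkpointed = num_layers // every
    return checkpointed, checkpointed / num_layers


# every=1 is equivalent to full AC; larger values recompute less but save less memory.
for every in (1, 2, 3, 4):
    n, frac = ac_tradeoff(32, every)
    print(f"every={every}: {n}/32 layers checkpointed, ~{frac:.0%} of forward recomputed")
```

Under these assumptions, AC 3 on 32 layers recomputes 10 layers instead of 32. That sits between no AC and full AC in both memory and compute.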
Changelog
- ...
Test plan
This code is largely a port of the original implementation in torchtitan, where it has already been tested. However, I ran all four styles (none, full, selective op AC, selective AC 2) as described above. I also verified that the new implementation of full AC matches the memory savings of the previous implementation.