torchtitan
[Do not review] Activation offloading
Stack from ghstack (oldest at bottom):
- -> #467
Current UX
- We use a `saved_tensors_hooks` context manager, which should be wrapped around `module.forward`. The context lets us override the pack and unpack hooks that are called when saving an activation for backward and when using an activation in backward, respectively. See the tutorial for more info.
- We expose two main methods for the user from the context: `wait_for_d2h` and `copy_h2d_async` (a minimal sketch of this interface follows this list).
  - By default, the D2H copies for offloading are async and use pinned memory. The user must call `wait_for_d2h` to wait on the D2H copies and free the device memory. This should be done after the compute to overlap with has been issued.
  - By default, the H2D copies are sync. The user must call `copy_h2d_async` to prefetch the H2D copies as async. This should be done before the compute to overlap with has been issued.
  - We show an example of this in `apply_ac` in `parallelize_llama.py` using module hooks.
- Together, this means that by default, no GPU memory is saved and the H2D copies are sync. Only by calling the `wait_for_d2h` method can we save GPU memory, and only by calling the `copy_h2d_async` method can we overlap the H2D copies in backward.
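To make the interface above concrete, here is a minimal sketch of what such a context could look like. Only the method names `wait_for_d2h` and `copy_h2d_async` come from the description above; the class name `ActivationOffloadContext` and the stream/bookkeeping details are assumptions for illustration, not this PR's actual implementation.

```python
# Minimal sketch only -- not the PR's implementation. Assumes all saved
# activations are CUDA tensors; stream handling is simplified.
import torch

class ActivationOffloadContext(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, device: str = "cuda"):
        self.device = device
        self.offload_stream = torch.cuda.Stream()
        self.gpu_refs = []    # keep GPU activations alive until D2H copies finish
        self.cpu_copies = []  # pinned-memory CPU copies awaiting H2D prefetch
        self.prefetched = {}  # id(cpu copy) -> prefetched GPU tensor
        super().__init__(self._pack, self._unpack)

    def _pack(self, tensor: torch.Tensor) -> torch.Tensor:
        # Called when autograd saves an activation: issue an async D2H copy
        # into pinned memory on a side stream.
        cpu = torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
        self.offload_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.offload_stream):
            cpu.copy_(tensor, non_blocking=True)
        self.gpu_refs.append(tensor)  # device memory is not freed until wait_for_d2h()
        self.cpu_copies.append(cpu)
        return cpu

    def _unpack(self, cpu: torch.Tensor) -> torch.Tensor:
        # Called when backward needs the activation again.
        gpu = self.prefetched.pop(id(cpu), None)
        if gpu is None:
            # No prefetch was issued: fall back to a sync H2D copy.
            return cpu.to(self.device, non_blocking=False)
        # Make the compute stream wait for the async prefetch copy.
        torch.cuda.current_stream().wait_stream(self.offload_stream)
        return gpu

    def wait_for_d2h(self) -> None:
        # Order subsequent compute after the D2H copies, then drop the GPU
        # references so the caching allocator can reuse that device memory.
        torch.cuda.current_stream().wait_stream(self.offload_stream)
        self.gpu_refs.clear()

    def copy_h2d_async(self) -> None:
        # Prefetch the H2D copies on the side stream before backward needs them.
        with torch.cuda.stream(self.offload_stream):
            for cpu in self.cpu_copies:
                self.prefetched[id(cpu)] = cpu.to(self.device, non_blocking=True)
        self.cpu_copies.clear()
```

In this sketch, usage would be `with ctx: out = block(x)` around the forward, then `ctx.wait_for_d2h()` after the compute to overlap with has been issued, and `ctx.copy_h2d_async()` before the backward compute it should overlap with.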
Known Problems
- ! Conflict with `split_with_sizes_copy`'s H2D copy (specific to FSDP2):
  - FSDP2's all-gather copy-out uses `split_with_sizes_copy`, which first issues a mini-H2D copy to send metadata needed for the main copy.
  - When the CPU issue order is `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute, the mini-H2D copy for `split_with_sizes_copy` for layer `i` can get serialized to run after the `copy_h2d_async` H2D copies for layer `i-1` even though they run in different streams. This prevents the `copy_h2d_async` for layer `i-1` from overlapping with layer `i` backward compute.
  - For now, this can be worked around with `reshard_after_forward=False`.
  - Trick/hack from @yifuwang: sleep 1 ms in the `offload_stream` before un-offloading (https://fburl.com/perfdoctor/ms47gqvp) --> allows prioritizing the `split_with_sizes_copy` H2D copy (see the sleep sketch after this list).
  - The CPU issue order of `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute comes from running the `copy_h2d_async` for layer `i-1` in a module full pre-backward hook (see the hook-wiring sketch after this list).
- ! If the user offloads too many activations, the program can become slow and/or freeze. Further, the first few iterations are slow due to `cudaHostAlloc` calls warming up the CPU caching allocator. This might be brittle if other parts of the program (e.g. checkpointing) also use pinned memory. If we do not `gc.collect()` every iteration, the pinned memory does not seem to be freed, so the allocator does not reuse it in subsequent iterations.
- ! We do not have a good way to apply a predicate to decide which activation tensors to offload. With the pack hook API, we only see the tensor, not any other information like which op constructed the tensor.
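For the module-hook wiring mentioned above, a rough sketch could look like the following. It builds on the hypothetical `ActivationOffloadContext` sketched earlier; the exact hook placement used by `apply_ac` in `parallelize_llama.py` may differ.

```python
# Sketch only: hook placement is an assumption, not the PR's exact apply_ac logic.
import torch.nn as nn

def apply_activation_offloading(blocks: list[nn.Module]) -> None:
    ctxs = [ActivationOffloadContext() for _ in blocks]  # one context per block
    for i, (block, ctx) in enumerate(zip(blocks, ctxs)):
        # Wrap this block's forward in the saved_tensors_hooks context so only
        # its activations are offloaded. The helpers return None so the hooks
        # do not modify inputs/outputs.
        def enter_ctx(mod, args, ctx=ctx):
            ctx.__enter__()

        def exit_ctx(mod, args, out, ctx=ctx):
            ctx.__exit__(None, None, None)

        block.register_forward_pre_hook(enter_ctx)
        block.register_forward_hook(exit_ctx)

        if i + 1 < len(blocks):
            nxt = blocks[i + 1]

            # Once the next block's forward (the compute to overlap with) has
            # been issued, wait on this block's D2H copies and free device memory.
            def free_after_next_forward(mod, args, out, ctx=ctx):
                ctx.wait_for_d2h()

            # Before the next block's backward compute is issued, prefetch this
            # block's activations back to the GPU (module full pre-backward hook).
            def prefetch_before_next_backward(mod, grad_out, ctx=ctx):
                ctx.copy_h2d_async()

            nxt.register_forward_hook(free_after_next_forward)
            nxt.register_full_backward_pre_hook(prefetch_before_next_backward)
```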
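The 1 ms sleep trick could be approximated with PyTorch's private `torch.cuda._sleep` helper, which enqueues a kernel that spins the GPU for a given number of cycles. This is a sketch under assumptions: the helper is a private API, the cycles-per-ms constant is GPU-clock dependent, and `ctx` refers to the hypothetical context sketched above.

```python
# Sketch only. torch.cuda._sleep is a private helper; the cycles-per-ms constant
# below is an assumption and should be calibrated for the actual GPU clock rate.
import torch

APPROX_CYCLES_PER_MS = 1_000_000  # assumption, GPU-dependent

def copy_h2d_async_with_sleep(ctx, sleep_ms: float = 1.0) -> None:
    # Enqueue a short spin on the offload stream before the un-offloading H2D
    # copies so that FSDP2's split_with_sizes_copy mini-H2D copy gets issued
    # (and prioritized) ahead of them.
    with torch.cuda.stream(ctx.offload_stream):
        torch.cuda._sleep(int(sleep_ms * APPROX_CYCLES_PER_MS))
    ctx.copy_h2d_async()
```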
Examples
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=False`:
  - Trace: https://fburl.com/perfdoctor/r1yf0lqf
  - Reserved memory: 65.67 GiB (69.10%)
  - WPS: 5,294, MFU: 31.00%
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
  - Trace: https://fburl.com/perfdoctor/qbhr98az
  - Reserved memory: 54.01 GiB (56.83%)
  - WPS: 4,085, MFU: 23.92% (mainly because H2Ds in backward are not overlapped)
  - If we use @yifuwang's trick, we can get WPS: 5,073, MFU: 29.70% without changing reserved memory
- Baseline Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
  - Reserved memory: 78.38 GiB (82.48%)
  - WPS: 6,341, MFU: 37.13%

For fun:
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`, using @yifuwang's hack (0.5 ms sleep), offloading 2 large FFN activations per transformer block:
  - Trace: https://fburl.com/perfdoctor/4882hlne
  - Reserved memory: 67.13 GiB (70.64%)
  - WPS: 6,281, MFU: 36.78%
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=False`, offloading 2 large FFN activations per transformer block:
  - Trace: https://fburl.com/perfdoctor/ejc4mq1r
  - Reserved memory: 77.86 GiB (81.92%)
  - WPS: 6,507, MFU: 38.10%
Note how `reshard_after_forward=False` with this 2-FFN-activation offloading dominates `reshard_after_forward=True` without offloading/AC: it achieves higher WPS at lower reserved memory.
@yifuwang made the good point that there may be interference with inter-node collectives, since they also use PCIe to send/recv data to/from the NIC and thus compete with the D2H/H2D activation copies. The testing so far was intra-node only due to lack of compute resources.