
[Do not review] Activation offloading

awgu opened this issue · 2 comments

Stack from ghstack (oldest at bottom):

  • -> #467

Current UX

  • We use a saved_tensors_hooks context manager, which should be wrapped around module.forward. The context lets us override pack and unpack hooks that are called when saving an activation for backward and using an activation in backward, respectively. See the tutorial for more info.
  • We expose two main methods for the user from the context: wait_for_d2h and copy_h2d_async.
    • By default, the D2H copies for offloading are async and use pinned memory. The user must call wait_for_d2h to wait on the D2H copies and free the device memory. This should be called after the compute that it should overlap with has been issued.
    • By default, the H2D copies are sync. The user must call copy_h2d_async to prefetch the H2D copies asynchronously. This should be called before the compute that it should overlap with has been issued.
    • We show an example of this in apply_ac in parallelize_llama.py using module hooks.
  • Together, this means that by default, no GPU memory is saved and the H2D copies are sync. Only by calling wait_for_d2h can we save GPU memory, and only by calling copy_h2d_async can we overlap the H2D copies with backward compute.
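
For reference, here is a minimal sketch of this UX (not the actual implementation in this PR; everything other than the wait_for_d2h / copy_h2d_async names is illustrative): a saved_tensors_hooks subclass whose pack hook issues an async D2H copy into pinned memory on a side stream, and whose unpack hook either does a sync H2D copy (the default) or consumes a tensor prefetched by copy_h2d_async.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


class OffloadActivations(saved_tensors_hooks):
    """Sketch: offload saved activations to pinned CPU memory with async D2H copies."""

    def __init__(self) -> None:
        self.offload_stream = torch.cuda.Stream()
        self.d2h_events: list[torch.cuda.Event] = []
        self.cpu_tensors: dict[int, torch.Tensor] = {}
        self.gpu_tensors: dict[int, torch.Tensor] = {}  # kept alive until wait_for_d2h
        self.prefetched: dict[int, torch.Tensor] = {}
        self.next_id = 0
        super().__init__(self._pack, self._unpack)

    def _pack(self, tensor: torch.Tensor) -> int:
        # Issue an async D2H copy into pinned memory on the side stream
        tensor_id = self.next_id
        self.next_id += 1
        cpu_tensor = torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
        self.offload_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.offload_stream):
            cpu_tensor.copy_(tensor, non_blocking=True)
        event = torch.cuda.Event()
        event.record(self.offload_stream)
        self.d2h_events.append(event)
        self.cpu_tensors[tensor_id] = cpu_tensor
        # Hold the GPU tensor so its memory is not reused before the D2H copy
        # finishes; it is only released when the user calls wait_for_d2h
        self.gpu_tensors[tensor_id] = tensor
        return tensor_id

    def wait_for_d2h(self) -> None:
        # Wait on the D2H copies and release the GPU copies (this is what
        # actually saves GPU memory)
        for event in self.d2h_events:
            event.synchronize()
        self.d2h_events.clear()
        self.gpu_tensors.clear()

    def copy_h2d_async(self) -> None:
        # Prefetch the H2D copies on the side stream so that backward compute
        # can overlap them (in practice this would be done per layer)
        with torch.cuda.stream(self.offload_stream):
            for tensor_id, cpu_tensor in self.cpu_tensors.items():
                self.prefetched[tensor_id] = cpu_tensor.to("cuda", non_blocking=True)

    def _unpack(self, tensor_id: int) -> torch.Tensor:
        if tensor_id in self.prefetched:
            # Make the backward compute stream wait for the async H2D copy
            torch.cuda.current_stream().wait_stream(self.offload_stream)
            gpu_tensor = self.prefetched.pop(tensor_id)
            # Tell the caching allocator the current stream now uses this memory
            gpu_tensor.record_stream(torch.cuda.current_stream())
            return gpu_tensor
        # Default path: sync H2D copy on the current stream
        return self.cpu_tensors[tensor_id].to("cuda")
```

In this sketch the context would be entered around module.forward; wait_for_d2h would be called after the compute to overlap with has been issued, and copy_h2d_async would be called from a module full pre-backward hook (e.g. register_full_backward_pre_hook), similar in spirit to the apply_ac example mentioned above.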

Known Problems

  • ! Conflict with split_with_sizes_copy's H2D copy (specific to FSDP2):
    • FSDP2's all-gather copy-out uses split_with_sizes_copy, which first issues a mini-H2D copy to send metadata needed for the main copy.
    • When the CPU issue order is copy_h2d_async for layer i-1 -> split_with_sizes_copy for layer i -> layer i backward compute, the mini-H2D copy for layer i's split_with_sizes_copy can get serialized to run after layer i-1's copy_h2d_async H2D copies even though they run in different streams. This prevents layer i-1's copy_h2d_async from overlapping with layer i's backward compute.
    • For now, this can be worked around with reshard_after_forward=False.
    • Trick/hack from @yifuwang: sleep 1 ms in the offload_stream before un-offloading (https://fburl.com/perfdoctor/ms47gqvp), which allows the split_with_sizes_copy H2D copy to be prioritized (see the sketch after this list)
    • The CPU issue order of copy_h2d_async for layer i-1 -> split_with_sizes_copy for layer i -> layer i backward compute comes from running the copy_h2d_async for layer i-1 using a module full pre-backward hook.
  • ! If the user offloads too many activations, the program can become slow and/or freeze. Furthermore, the first few iterations are slow due to cudaHostAlloc calls warming up the CPU caching allocator; this might be brittle if other parts of the program (e.g. checkpointing) also use pinned memory. If we do not call gc.collect() every iteration, the pinned memory does not seem to be freed, so the allocator does not reuse it in subsequent iterations.
  • ! We do not have a good way to apply a predicate to decide which activation tensors to offload. With the pack hook API, we only see the tensor, not any other information like which op constructed the tensor.
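
For illustration, here is a sketch of @yifuwang's sleep trick applied to the copy_h2d_async path of the sketch above (torch.cuda._sleep is an internal API that enqueues a spin kernel for a given number of GPU cycles; the cycle count here is only a ballpark for ~1 ms, not a tuned value):

```python
import torch

# Ballpark cycle count for roughly 1 ms at a ~1 GHz SM clock (illustrative only)
SLEEP_CYCLES = 1_000_000


def delayed_copy_h2d_async(ctx) -> None:
    # `ctx` is the OffloadActivations sketch above. Enqueue a short spin kernel
    # on the offload stream before the un-offload H2D copies so that the mini
    # metadata H2D copy from FSDP2's split_with_sizes_copy, which is issued
    # later on the CPU but on a different stream, is not serialized behind the
    # bulk activation H2D copies.
    with torch.cuda.stream(ctx.offload_stream):
        torch.cuda._sleep(SLEEP_CYCLES)  # internal API: spins for N GPU cycles
    ctx.copy_h2d_async()  # activation H2D copies now start after the spin kernel
```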

Examples

  • Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, reshard_after_forward=False:
    • Trace: https://fburl.com/perfdoctor/r1yf0lqf
    • Reserved memory: 65.67GiB(69.10%)
    • WPS: 5,294 MFU: 31.00%
  • Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, reshard_after_forward=True:
    • Trace: https://fburl.com/perfdoctor/qbhr98az
    • Reserved memory: 54.01GiB(56.83%)
    • WPS: 4,085 MFU: 23.92% (mainly because H2Ds in backward are not overlapped)
    • If we use @yifuwang's trick, we can get WPS: 5,073, MFU: 29.70% without changing reserved memory
  • Baseline Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, reshard_after_forward=True:
    • Reserved memory: 78.38GiB(82.48%)
    • WPS: 6,341 MFU: 37.13%

awgu commented on Jul 18 '24 19:07

For fun:

  • Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, reshard_after_forward=True, using @yifuwang's hack (0.5 ms sleep), offloading 2 large FFN activations per transformer block
    • Trace: https://fburl.com/perfdoctor/4882hlne
    • Reserved memory: 67.13GiB(70.64%)
    • WPS: 6,281 MFU: 36.78%
  • Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, reshard_after_forward=False, offloading 2 large FFN activations per transformer block
    • Trace: https://fburl.com/perfdoctor/ejc4mq1r
    • Reserved memory: 77.86GiB(81.92%)
    • WPS: 6,507 MFU: 38.10%

Note how reshard_after_forward=False with this 2-FFN-activation offloading dominates reshard_after_forward=True without offloading/AC, since it achieves both higher WPS and lower reserved memory.

awgu commented on Jul 18 '24 22:07

@yifuwang made the good point that there may be interference with inter-node collectives, since they also use PCIe to send/receive data to/from the NIC, competing with the D2H/H2D activation copies. The testing so far has been intra-node only due to lack of compute resources.

awgu commented on Jul 19 '24 14:07