Integrate flex attention

RdoubleA opened this issue 1 year ago • 2 comments

TLDR: Flex attention improves sample packing throughput by 71% for Llama3 8B at a max sequence length of 8192, and this effect becomes more pronounced as sequence length increases.

Context

This is a proposal for what integrating the prototype flex_attention API, available in PyTorch Core's nightlies, would look like in torchtune. It requires replacing the nn.scaled_dot_product_attention call in CausalSelfAttention with flex_attention and using a BlockMask instead of a standard tensor mask. The primary challenge is handling versioning, since flex_attention is not available in the latest stable release of torch. I've tried to minimize changes to the attention/transformer modules and the recipes beyond the version gating logic. Most of the core updates are in PackedDataset and a new collate function for sample packing that constructs the BlockMask.
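
As a rough sketch of that version gating (the flag name _SUPPORTS_FLEX_ATTENTION is illustrative for this sketch, not necessarily what the PR uses), the import itself can serve as the gate:

try:
    from torch.nn.attention.flex_attention import (
        BlockMask,
        create_block_mask,
        flex_attention,
    )
    # Only recent nightlies ship this module, so a successful import is the gate.
    _SUPPORTS_FLEX_ATTENTION = True
except ImportError:
    BlockMask, create_block_mask, flex_attention = None, None, None
    _SUPPORTS_FLEX_ATTENTION = False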

What is FlexAttention?

FlexAttention is an alternative to SDPA that can enable better performance for users who rely on different attention variants and are comfortable using torch.compile. It allows users to specify custom modifications to attention scores inside the fused scaled dot product attention kernel. This lets various attention patterns and biases be implemented efficiently, with potential runtime and memory savings. It also supports arbitrary masks with flash attention, which enables performant attention kernels for sample packing (which requires a block causal mask) and for vision cross-attention masks.

torch.nn.attention.flex_attention.flex_attention(
    q,
    k,
    v,
    block_mask=mask,
)

The signature is largely similar to SDPA's, except the expected mask is now a BlockMask. A BlockMask takes in a mask_mod function that materializes the mask at the kernel level on the fly, so we avoid holding large 2D attention masks in memory that scale quadratically with sequence length.
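
A minimal sketch of that flow against the prototype API in recent nightlies (the simple causal mask_mod, shapes, and dtypes here are illustrative):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# mask_mod answers "may query position q_idx attend to key position kv_idx?"
# and is evaluated inside the kernel rather than materialized as a dense
# [seq_len, seq_len] tensor.
def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bsz, n_heads, seq_len, head_dim = 2, 8, 8192, 64
# B and H can be None to broadcast the same mask over batch and heads.
block_mask = create_block_mask(causal_mask_mod, None, None, seq_len, seq_len)
q, k, v = (
    torch.randn(bsz, n_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)
# flex_attention is meant to be wrapped in torch.compile for performance.
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)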

Thus, we expect FlexAttention with sample packing to 1) be more memory-efficient and 2) increase throughput, thanks to faster mask construction and the ability to use flash attention.

How does sample packing with SDPA compare to Flex?

This gist covers the details of this experiment: https://gist.github.com/RdoubleA/012409f7919973d6ba7e9ca3efd5c237. Losses were observed to be equivalent before and after this change.

TLDR: sample packing with FlexAttention scales significantly better with sequence length:

  • For SDPA, throughput actually decreases as max sequence length increases from 2048 -> 8192, likely due to the quadratic increase in compute for assembling a [b x 8192 x 8192] block causal mask.
  • At a sequence length of 2048, we observed a +11% increase in WPS when switching to FlexAttention.
  • At a sequence length of 8192, WPS dropped by 30% for SDPA compared to 2048.
  • At a sequence length of 8192, FlexAttention yielded a +71% increase in WPS compared to SDPA.

From these observations, we can extrapolate that flex attention will be a dramatic improvement over SDPA with sample packing at context lengths of 100k - 1M. For these reasons, it is worthwhile integrating it into our core modules even if it adds some slight complexity.

Changelog

  • Instead of creating a 2D mask on the fly in PackedDataset's __getitem__, we create a list of document_ids that indicate which tokens belong to which samples in a packed sequence
  • We use document_ids to create a BlockMask in the batch collator. A new collate function for packing is added, and mask utilities are added to utils/attention_bias.py (see the sketch after this list)
  • CausalSelfAttention swaps which attention implementation to use depending on whether a normal tensor mask or a BlockMask is passed
  • All flex attention logic is gated by version. I was on torch==2.5.0.dev20240717+cu124 when testing.
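
A sketch of those two pieces (the helper name packed_block_causal_mask and the dispatch shown here are illustrative, not the final torchtune API):

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

def packed_block_causal_mask(document_ids: torch.Tensor) -> BlockMask:
    # document_ids: [batch_size, seq_len], mapping every token in a packed
    # sequence to the sample it originally came from.
    batch_size, seq_len = document_ids.shape

    def mask_mod(b, h, q_idx, kv_idx):
        # A token may only attend to earlier tokens from the same sample.
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)

    return create_block_mask(
        mask_mod, batch_size, None, seq_len, seq_len, device=document_ids.device
    )

def _attention(q, k, v, mask):
    # Dispatch: use flex_attention when the collated mask is a BlockMask,
    # otherwise fall back to SDPA with a dense tensor mask.
    if isinstance(mask, BlockMask):
        return flex_attention(q, k, v, block_mask=mask)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

In the packed collate function, the per-sample document_ids are stacked into a [batch_size, seq_len] tensor and converted to a BlockMask with a helper like the one above, so the attention module only ever sees either a BlockMask or a dense tensor mask.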

Test plan

  • [x] Ensure loss curves are identical
  • [ ] Add test for packed collate
  • [ ] Update attention tests for flex attention
  • [ ] Update PackedDataset tests

cc @drisspg @Chillee @kartikayk

RdoubleA, Jul 17 '24 19:07

One nit: Flex is not intended to be a replacement for SDPA; rather, it is designed to enable better performance for users who are utilizing different attention variants and are comfortable with using torch.compile.

drisspg, Jul 22 '24 22:07

Codecov Report

Attention: Patch coverage is 75.34247% with 54 lines in your changes missing coverage. Please review.

Project coverage is 71.21%. Comparing base (eb92658) to head (4f1eaa4). Report is 555 commits behind head on main.

Files with missing lines                            Patch %   Lines
torchtune/modules/attention_utils.py                 62.50%   21 Missing :warning:
tests/torchtune/modules/test_attention_utils.py      69.69%   20 Missing :warning:
tests/torchtune/data/test_collate.py                 70.00%    6 Missing :warning:
recipes/full_finetune_distributed.py                  0.00%    2 Missing :warning:
recipes/full_finetune_single_device.py                0.00%    1 Missing :warning:
recipes/lora_finetune_distributed.py                  0.00%    1 Missing :warning:
recipes/lora_finetune_single_device.py                0.00%    1 Missing :warning:
recipes/qat_distributed.py                            0.00%    1 Missing :warning:
torchtune/utils/logging.py                           88.88%    1 Missing :warning:
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1193       +/-   ##
===========================================
+ Coverage   27.10%   71.21%   +44.10%     
===========================================
  Files         285      287        +2     
  Lines       13881    14058      +177     
===========================================
+ Hits         3763    10011     +6248     
+ Misses      10118     4047     -6071     

codecov-commenter, Sep 07 '24 00:09