torchtune
Integrate flex attention
TLDR: Flex attention improves sample packing throughput by 71% for Llama3 8B at a max sequence length of 8192, and the effect becomes more pronounced as sequence length increases.
Context
This is a proposal for what integrating the prototype `flex_attention` API, available in PyTorch core's nightlies, would look like in torchtune. It requires replacing the `nn.functional.scaled_dot_product_attention` call in `CausalSelfAttention` with `flex_attention`, and using a `BlockMask` instead of a standard tensor mask. The primary challenge is handling versioning, since `flex_attention` is not available in the latest stable release of torch. I've tried to minimize changes to the attention/transformer modules and the recipes as much as possible besides the version-gating logic. Most of the core updates are in `PackedDataset` and a new collate function for sample packing that constructs the `BlockMask`.
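As a rough illustration of the kind of version gate involved (the flag name below is illustrative, not the actual torchtune helper):

```python
# Illustrative version gate: flex_attention only exists in recent nightlies,
# so fall back to SDPA-only behavior on stable torch releases.
try:
    from torch.nn.attention.flex_attention import BlockMask, flex_attention

    _SUPPORTS_FLEX_ATTENTION = True
except ImportError:
    _SUPPORTS_FLEX_ATTENTION = False
```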
What is FlexAttention?
FlexAttention is an alternative to SDPA that can enable better performance for users who use different attention variants and are comfortable with torch.compile. It allows users to specify custom modifications to attention scores within the fused scaled dot product attention kernel. This enables various attention patterns and biases to be implemented efficiently, with potential runtime and memory savings. It includes support for arbitrary masks with flash attention, which enables performant attention kernels for sample packing (which requires a block causal mask) and vision cross attention masks.
```python
# Call signature mirrors SDPA, but the mask is a BlockMask rather than a dense tensor.
torch.nn.attention.flex_attention.flex_attention(
    q,
    k,
    v,
    block_mask=mask,
)
```
The signature is largely similar to SDPA, except the expected mask is now a BlockMask (see here). A BlockMask takes in a mask_mod function that materializes a kernel-level mask on the fly, so we avoid holding large 2D attention masks in memory that scale quadratically with sequence length.
Thus, we expect FlexAttention with sample packing to be (1) more memory-efficient and (2) higher-throughput, due to faster mask construction and the ability to use flash attention.
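To make the mask_mod idea concrete, here is a minimal sketch of a block-causal mask for packed samples. The helper name, the `document_ids` layout, and the exact use of `create_block_mask` are assumptions for illustration and may differ from the final torchtune implementation:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Sketch only: document_ids[b, i] is assumed to be the index of the packed
# sample that token i of batch element b belongs to.
def packed_block_causal_mask(document_ids: torch.Tensor, seq_len: int):
    def mask_mod(b, h, q_idx, kv_idx):
        # Attend only to earlier tokens within the same packed sample.
        causal = q_idx >= kv_idx
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        return causal & same_doc

    # create_block_mask stores block-level metadata instead of a dense
    # [seq_len x seq_len] boolean mask per sequence.
    return create_block_mask(
        mask_mod, B=document_ids.shape[0], H=None, Q_LEN=seq_len, KV_LEN=seq_len
    )
```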
How does sample packing with SDPA compare to Flex?
This gist covers the details of this experiment: https://gist.github.com/RdoubleA/012409f7919973d6ba7e9ca3efd5c237. Losses were observed to be equivalent before and after this change.
TLDR: sample packing with FlexAttention scales significantly better with sequence length:
- For SDPA, throughput actually decreased as we increased max sequence length from 2048 -> 8192, likely due to the quadratic increase in compute for assembling a [b x 8192 x 8192] block causal mask
- At a sequence length of 2048, we observed an 11% increase in WPS when switching to FlexAttention.
- At a sequence length of 8192, WPS for SDPA dropped by 30% compared to 2048
- At a sequence length of 8192, FlexAttention gave a 71% increase in WPS over SDPA
From these observations, we can extrapolate that flex attention will be a dramatic improvement over SDPA with sample packing at context lengths of 100k - 1M. For these reasons, it is worthwhile integrating this into our core modules even if it adds some slight complexity.
Changelog
- Instead of creating a 2D mask on the fly in `PackedDataset`'s `__getitem__`, we create a list of `document_ids` that indicates which tokens belong to which samples in a packed sequence
- We use `document_ids` to create a `BlockMask` in the batch collator. A new collate function for packing is added, and mask utilities were added to `utils/attention_bias.py`
- `CausalSelfAttention` swaps which attention implementation to use depending on whether a normal tensor mask or a `BlockMask` is passed (see the sketch after this list)
- All flex attention logic is gated by version. I was on `torch==2.5.0.dev20240717+cu124` when testing.
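For illustration, a rough sketch of the dispatch mentioned in the third bullet above; argument names are illustrative and this is not the exact `CausalSelfAttention` code:

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, flex_attention

# Illustrative sketch only: the real module keeps its existing projections and
# shapes; this shows just the mask-type dispatch.
def _attention_dispatch(q, k, v, mask=None):
    if isinstance(mask, BlockMask):
        # FlexAttention path; intended to be wrapped in torch.compile for speed.
        return flex_attention(q, k, v, block_mask=mask)
    # Fall back to SDPA with a standard tensor mask (or causal attention if no mask).
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, is_causal=mask is None
    )
```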
Test plan
- [x] Ensure loss curves are identical
- [ ] Add test for packed collate (a rough sketch of one possible check follows this list)
- [ ] Update attention tests for flex attention
- [ ] Update `PackedDataset` tests
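Relating to the packed-collate test item, a hypothetical sanity check (the test name, shapes, and tolerances are assumptions, not the actual torchtune test) could compare `flex_attention` with a block-causal `BlockMask` against SDPA with the equivalent dense mask:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def test_flex_matches_sdpa_on_packed_mask():
    torch.manual_seed(0)
    b, h, s, d = 1, 2, 128, 16
    q, k, v = (torch.randn(b, h, s, d, device="cuda") for _ in range(3))
    document_ids = torch.zeros(b, s, dtype=torch.long, device="cuda")
    document_ids[:, s // 2 :] = 1  # two packed samples per sequence

    def mask_mod(bi, hi, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (document_ids[bi, q_idx] == document_ids[bi, kv_idx])

    block_mask = create_block_mask(mask_mod, B=b, H=None, Q_LEN=s, KV_LEN=s)
    # Equivalent dense block-causal mask for the SDPA reference.
    dense = torch.tril(torch.ones(s, s, dtype=torch.bool, device="cuda")) & (
        document_ids[0, :, None] == document_ids[0, None, :]
    )
    out_flex = flex_attention(q, k, v, block_mask=block_mask)
    out_sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=dense)
    torch.testing.assert_close(out_flex, out_sdpa, rtol=1e-4, atol=1e-4)
```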
cc @drisspg @Chillee @kartikayk
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1193
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:white_check_mark: No Failures
As of commit 4f1eaa472d1f4621d1e07e1fa575fb7877dfa96e with merge base eb92658a360d7a7d4ce1c93bbcf99c99a2e0943b:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
One nit: Flex is not intended to be a replacement for SDPA; rather, it is designed to enable better performance for users who are utilizing different attention variants and are comfortable with using torch.compile.
Codecov Report
Attention: Patch coverage is 75.34247% with 54 lines in your changes missing coverage. Please review.
Project coverage is 71.21%. Comparing base (eb92658) to head (4f1eaa4). Report is 555 commits behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #1193 +/- ##
===========================================
+ Coverage 27.10% 71.21% +44.10%
===========================================
Files 285 287 +2
Lines 13881 14058 +177
===========================================
+ Hits 3763 10011 +6248
+ Misses 10118 4047 -6071
:umbrella: View full report in Codecov by Sentry.