llm-foundry

Add support for Flex Attention

Open • ShashankMosaicML opened this issue 11 months ago • 0 comments

There are 4 TODOs regarding compiled flex attention that need to be investigated before this is checked in. See the tests for more details. TL;DR:

  • I think sequence lengths that are not multiples of 128 are still not properly supported (https://pytorch.org/blog/flexattention/#limitations-and-future-work).
  • Left padding has issues, which can cause problems during generation and inference (possibly related to https://github.com/pytorch/pytorch/issues/139064, since I was seeing the same error).
  • Changing the number of heads between tests for ALiBi causes errors. This is potentially a minor issue, since the number of heads does not change during actual training or inference.
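The three items above all involve the FlexAttention callbacks, `mask_mod(b, h, q_idx, kv_idx)` and `score_mod(score, b, h, q_idx, kv_idx)`. Below is a minimal plain-Python sketch of what those callbacks might look like for the cases mentioned. It assumes nothing about the actual llm-foundry code. Every helper name (`pad_to_multiple`, `make_left_pad_causal_mask`, `alibi_slopes`, `make_alibi_score_mod`, `materialize`) is hypothetical. Padding up to a multiple of 128 is only a possible workaround, not something proposed in this issue. The lookup tables are plain lists so the sketch runs without torch. In real use they would be tensors on the attention device.

```python
# Plain-Python stand-ins for FlexAttention callbacks (hypothetical helpers).
# In real use, pad_lens / slopes would be torch tensors and the callbacks
# would be handed to flex_attention / create_block_mask.

BLOCK = 128  # block size from the "multiples of 128" limitation


def pad_to_multiple(seq_len: int, multiple: int = BLOCK) -> int:
    """Round seq_len up to the next multiple (a possible workaround, not a fix)."""
    return ((seq_len + multiple - 1) // multiple) * multiple


def make_left_pad_causal_mask(pad_lens):
    """Causal mask that also hides left padding.

    pad_lens[b] is the number of pad tokens at the start of sequence b.
    Uses `&` rather than `and` so the same body would also work on tensors.
    """
    def mask_mod(b, h, q_idx, kv_idx):
        return (kv_idx <= q_idx) & (kv_idx >= pad_lens[b])
    return mask_mod


def alibi_slopes(n_heads: int) -> list:
    """ALiBi head slopes 2^(-8 * (i + 1) / n_heads); power-of-two head counts only."""
    if n_heads <= 0 or n_heads & (n_heads - 1):
        raise ValueError("this sketch only handles power-of-two head counts")
    return [2.0 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)]


def make_alibi_score_mod(slopes):
    """Add a per-head linear distance penalty (past keys get a negative bias)."""
    def score_mod(score, b, h, q_idx, kv_idx):
        return score + slopes[h] * (kv_idx - q_idx)
    return score_mod


def materialize(mask_mod, b, seq_len):
    """Evaluate a mask_mod densely for one batch element (head 0), for inspection."""
    return [[bool(mask_mod(b, 0, q, k)) for k in range(seq_len)]
            for q in range(seq_len)]


if __name__ == "__main__":
    print(pad_to_multiple(200))  # 256
    mask = make_left_pad_causal_mask([2, 0])
    for row in materialize(mask, 0, 4):
        print(row)
```

In the left-padding case, query rows that are themselves padding have no visible keys at all. This sketch says nothing about how the compiled kernel treats such fully masked rows. Note also that the ALiBi slope table is captured by `score_mod` and its length depends on the head count. A different number of heads therefore means a different captured table.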

Summary: Potentially safe to use for training, not for inference.

This needs the fix for compiling sequence-ID-dependent block masking (https://github.com/pytorch/pytorch/issues/136427), which is in torch nightly. Use this command to install torch nightly:

pip3 install torch==2.6.0.dev20241126+cu124 torchvision==0.20.0.dev20241126+cu124 torchaudio==2.5.0.dev20241126+cu124 --index-url https://download.pytorch.org/whl/nightly/cu124
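Sequence-ID-dependent masking presumably refers to packed batches. In a packed batch, several sequences share one row, and tokens may only attend within their own sequence. The sketch below shows why the resulting block mask depends on the data. It is a plain-Python illustration under that assumption, and `make_doc_causal_mask` and `active_blocks` are hypothetical names. In FlexAttention, the `mask_mod` would index a `[batch, seq_len]` tensor of IDs and be passed to `create_block_mask` from `torch.nn.attention.flex_attention`.

```python
# Plain-Python sketch of a sequence-id (document) mask for packed batches.
# Hypothetical helpers; in FlexAttention the ids would be a [batch, seq_len]
# tensor and the mask_mod would be passed to create_block_mask.


def make_doc_causal_mask(seq_ids):
    """Causal mask restricted to tokens of the same packed sequence.

    seq_ids[b][t] is the id of the sequence that token t of batch element b belongs to.
    """
    def mask_mod(b, h, q_idx, kv_idx):
        same_seq = seq_ids[b][q_idx] == seq_ids[b][kv_idx]
        return (kv_idx <= q_idx) & same_seq
    return mask_mod


def active_blocks(mask_mod, b, seq_len, block):
    """(q_block, kv_block) tiles containing at least one unmasked entry.

    Brute force, O(seq_len^2); a block mask records this kind of tile-level
    sparsity so fully masked tiles can be skipped.
    """
    active = set()
    for q in range(seq_len):
        for kv in range(seq_len):
            if mask_mod(b, 0, q, kv):
                active.add((q // block, kv // block))
    return active


if __name__ == "__main__":
    # Two packings of the same length give different tile patterns.
    a = make_doc_causal_mask([[0, 0, 0, 0, 1, 1, 1, 1]])
    c = make_doc_causal_mask([[0, 0, 0, 1, 1, 1, 1, 1]])
    print(sorted(active_blocks(a, 0, 8, 4)))  # [(0, 0), (1, 1)]
    print(sorted(active_blocks(c, 0, 8, 4)))  # [(0, 0), (1, 0), (1, 1)]
```

Because the sequence boundaries change from batch to batch, the set of active tiles changes too. This means the block mask has to be rebuilt per batch rather than computed once, which is the code path that has to compile.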

ShashankMosaicML, Nov 27 '24 00:11