
[Draft] Support Flash Attention

Open cmathw opened this issue 1 year ago • 4 comments

Description

Resolves #378 by adding support for torch.nn.functional.scaled_dot_product_attention as found here. This function can dispatch to FlashAttention-2 as well as two alternative attention implementations (memory-efficient attention and a C++ implementation of the reference formulation), and PyTorch attempts to automatically select the most efficient one based on the inputs. Thank you @alan-cooney for recommending this implementation!
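For reference, here is a minimal sketch of how the fused call relates to ordinary masked-softmax attention. It assumes the [batch, pos, head_index, d_head] layout TransformerLens uses for q, k, v and z, plain causal attention with the default 1/sqrt(d_head) scaling, and no padding masks. The function names are hypothetical and this is not the PR's code.

```python
import math

import torch
import torch.nn.functional as F


def attention_reference(q, k, v):
    """Explicit causal attention. q, k, v: [batch, pos, head_index, d_head]."""
    d_head = q.shape[-1]
    n_pos = q.shape[1]
    # Attention scores: [batch, head_index, query_pos, key_pos]
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(d_head)
    # Mask out keys that come after the query position (causal mask)
    mask = torch.triu(torch.ones(n_pos, n_pos, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    pattern = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", pattern, v)


def attention_sdpa(q, k, v):
    """Same computation via the fused PyTorch op."""
    # SDPA expects [batch, head, pos, d_head], so move the head axis before pos
    z = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    )
    return z.transpose(1, 2)  # back to [batch, pos, head_index, d_head]


torch.manual_seed(0)
q, k, v = (torch.randn(2, 6, 4, 8) for _ in range(3))
z_ref = attention_reference(q, k, v)
z_fast = attention_sdpa(q, k, v)
print((z_ref - z_fast).abs().max().item())
```

In float32 the two paths should agree to within ordinary floating-point noise; the transposes are the only layout change needed.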

This is still a draft because the differences between a model that uses a fast attention implementation and one that does not are larger than I'd like, so the comparison tests currently need fairly loose tolerances. This is likely because the inputs are cast to float16 before calling scaled_dot_product_attention (FlashAttention only accepts half-precision inputs) and then cast back to float32 after the fused attention. I will look into this further and see whether there is an improvement to be made here.
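A rough way to gauge how much error the float16 round trip alone introduces is sketched below. It assumes PyTorch 2.x on CPU and only emulates rounding q, k and v to float16; the attention itself still runs in float32, so it likely underestimates the discrepancy of a real half-precision kernel. The shapes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# [batch, head, pos, d_head], the layout scaled_dot_product_attention expects
q, k, v = (torch.randn(1, 4, 16, 64) for _ in range(3))

# Baseline: everything stays in float32
z_fp32 = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Emulate the cast to float16 and back on the inputs only
q16, k16, v16 = (t.half().float() for t in (q, k, v))
z_roundtrip = F.scaled_dot_product_attention(q16, k16, v16, is_causal=True)

max_err = (z_fp32 - z_roundtrip).abs().max().item()
print(f"max abs error from float16 rounding of inputs: {max_err:.2e}")
```

Errors of this order accumulate across layers, which is consistent with the end-to-end tolerances needing to be looser than a pure float32 comparison.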

Type of change


  • [x] New feature (non-breaking change which adds functionality)

Checklist:

  • [x] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have not rewritten tests relating to key interfaces which would affect backward compatibility

cmathw • Jan 30 '24

Is there any update on this, and any idea on what can be done to get the tolerance to an acceptable level?

bryce13950 • Apr 13 '24


I haven't taken a further look yet. Is this currently blocking other features?

cmathw • Apr 25 '24

Nope. I've just been going through open PRs, closing anything that can be closed and helping get the rest finished. If you need some help with this, let me know; if I have time, I'd be happy to help.

bryce13950 • Apr 25 '24

This fails with Llama-3:

  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformer_lens/components.py", line 591, in forward
    z = self.calculate_z_with_sdpa(q, k, v)  # [batch, pos, head_index, d_head]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformer_lens/components.py", line 785, in calculate_z_with_sdpa
    z = F.scaled_dot_product_attention(query, key, value, is_causal=True)
RuntimeError: The size of tensor a (32) must match the size of tensor b (8) at non-singleton dimension 1

This is likely a grouped-query attention (GQA) issue, since Llama-2 doesn't hit the same error: the query has 32 heads while the keys and values have only 8.

winglian • May 03 '24
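One possible way to handle the grouped-query case is sketched below, under stated assumptions: q is in [batch, n_heads, pos, d_head] layout, k and v have n_kv_heads heads, and each key/value head is shared by n_heads // n_kv_heads consecutive query heads (the Llama convention). Repeating k and v along the head dimension before the call makes the shapes match. The helper name sdpa_with_gqa is hypothetical. Newer PyTorch releases (2.5 and later) also accept enable_gqa=True in scaled_dot_product_attention, which avoids materialising the repeated tensors.

```python
import torch
import torch.nn.functional as F


def sdpa_with_gqa(q, k, v):
    """Causal SDPA that tolerates fewer key/value heads than query heads.

    q: [batch, n_heads, pos, d_head]; k, v: [batch, n_kv_heads, pos, d_head].
    """
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    if n_heads != n_kv_heads:
        if n_heads % n_kv_heads != 0:
            raise ValueError(f"n_heads={n_heads} is not a multiple of n_kv_heads={n_kv_heads}")
        n_rep = n_heads // n_kv_heads
        # Query heads [i*n_rep, (i+1)*n_rep) all use key/value head i
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


# Llama-3-8B-like head counts: 32 query heads, 8 key/value heads
torch.manual_seed(0)
q = torch.randn(1, 32, 5, 16)
k = torch.randn(1, 8, 5, 16)
v = torch.randn(1, 8, 5, 16)
z = sdpa_with_gqa(q, k, v)
print(z.shape)  # torch.Size([1, 32, 5, 16])
```

repeat_interleave (rather than repeat) keeps each group of consecutive query heads paired with the same key/value head; models that group heads differently would need a different expansion.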