Easy-Transformer
[Draft] Support Flash Attention
Description
Resolves #378 by adding support for torch.nn.functional.scaled_dot_product_attention (see the PyTorch documentation). This implementation includes FlashAttention-2, as well as two other, potentially faster, attention implementations; PyTorch attempts to automatically select the most efficient one based on the inputs. Thank you @alan-cooney for recommending this approach!
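For reference, here is a minimal sketch of what the SDPA path looks like in spirit (illustrative only, not the exact code in this PR), assuming q, k, and v are in TransformerLens's usual [batch, pos, head_index, d_head] layout:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # SDPA expects [batch, head_index, pos, d_head], so move the head axis forward.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    # is_causal=True applies the standard autoregressive mask; PyTorch picks the
    # backend (FlashAttention-2, memory-efficient, or plain math) automatically.
    z = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Back to [batch, pos, head_index, d_head] for the rest of the block.
    return z.transpose(1, 2)
```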
This is still a draft because the tolerances between a model that uses a fast attention implementation and one that does not are a bit high. This is likely because scaled_dot_product_attention requires casting to float16 (which we then cast back to float32 after the fused attention). I will look into this further and see whether there is an improvement to be made here.
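The float16 round-trip looks roughly like this (an assumed pattern, not necessarily the exact PR code); the cast back to float32 cannot undo the precision lost inside the half-precision kernel, which is the likely source of the looser tolerances:

```python
import torch
import torch.nn.functional as F

def sdpa_attention_half(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    orig_dtype = q.dtype  # typically torch.float32 in TransformerLens
    # Fused/flash kernels generally require half precision on CUDA.
    q, k, v = (t.transpose(1, 2).to(torch.float16) for t in (q, k, v))
    z = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Cast back for the rest of the (float32) forward pass.
    return z.transpose(1, 2).to(orig_dtype)
```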
Type of change
Please delete options that are not relevant.
- [x] New feature (non-breaking change which adds functionality)
Checklist:
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
- [x] I have not rewritten tests relating to key interfaces which would affect backward compatibility
Is there any update on this, and any idea on what can be done to get the tolerance to an acceptable level?
I haven't taken a further look yet; is this currently blocking any other features?
Nope, I have just been going through PRs, closing out anything that can be closed, and helping get everything else over the line. If you need some help with this, let me know; if I have time, I would be happy to help out.
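For whoever picks this up, one way to quantify the gap is to compare logits from a model with and without the fast-attention path at progressively tighter tolerances (a rough sketch; the use_sdpa keyword below is a hypothetical stand-in for however this PR actually toggles the feature):

```python
import torch
from transformer_lens import HookedTransformer

prompt = "The quick brown fox jumps over the lazy dog"
reference = HookedTransformer.from_pretrained("gpt2")
fast = HookedTransformer.from_pretrained("gpt2", use_sdpa=True)  # hypothetical flag

with torch.no_grad():
    ref_logits = reference(prompt)
    fast_logits = fast(prompt)

# Start around fp16-level tolerances and tighten until this fails to see how
# large the discrepancy really is and where it shows up.
torch.testing.assert_close(fast_logits, ref_logits, atol=1e-3, rtol=1e-3)
```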
Fails with Llama-3:
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformer_lens/components.py", line 591, in forward
z = self.calculate_z_with_sdpa(q, k, v) # [batch, pos, head_index, d_head]
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformer_lens/components.py", line 785, in calculate_z_with_sdpa
z = F.scaled_dot_product_attention(query, key, value, is_causal=True)
RuntimeError: The size of tensor a (32) must match the size of tensor b (8) at non-singleton dimension 1
Likely a grouped-query attention issue, since Llama-2 doesn't hit the same error.
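A sketch of one possible fix (an assumption on my part, not code from this PR): with grouped-query attention, k and v have fewer heads than q (8 vs. 32 for Llama-3-8B), so they need to be repeated along the head dimension before calling SDPA:

```python
import torch
import torch.nn.functional as F

def sdpa_with_gqa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: [batch, n_heads, pos, d_head]; k, v: [batch, n_kv_heads, pos, d_head]
    n_rep = q.shape[1] // k.shape[1]  # e.g. 32 // 8 = 4 for Llama-3-8B
    if n_rep > 1:
        # Repeat each key/value head so every query head has a matching one.
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Newer PyTorch releases also expose an enable_gqa argument on scaled_dot_product_attention that handles this internally, if the minimum supported torch version allows relying on it.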