Metal shaders for efficient self attention on large sequences
Proposed changes
Implements Metal shaders for:
o = mx.fast.scaled_dot_product_attention(queries, keys, values, scale, mask)
Supports fp16 and fp32 dtypes; the hidden dimension is flexible and currently templated for 64, 80, and 128.
Causal masking for prompt encoding is not yet implemented; this shader is currently focused on full self-attention.
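For illustration, here is a minimal sketch of calling the op against an unfused reference path (the shapes, the keyword form of `scale`, and the tolerance are assumptions for this example, not part of this change):

```python
import math
import mlx.core as mx

# Illustrative shapes: (batch, num_heads, sequence_length, head_dim)
B, H, L, D = 2, 38, 2048, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = 1.0 / math.sqrt(D)

# Fused Metal kernel path
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# Unfused reference path for comparison
ref = mx.softmax(scale * (q @ k.transpose(0, 1, 3, 2)), axis=-1) @ v
print(mx.allclose(o, ref, atol=1e-4))
```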
Checklist
Put an x in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
Marking as draft: I'm currently working through some numerical issues in a separate workflow, and will add the CPU-side bindings and dispatch, tests, and docs. Sharing the current status here and will update this PR.
Hi folks,
Attaching some graphs of measured latency on M3 Max and estimated memory savings per attention block (empirically verified at several data points; the graphs here were obtained via formulas).
There is some room for improvement in latency on larger sequences, with a divergence after sequence length ~2300, though the memory savings exceed 1 GB around sequence length 2k and approach 5 GB at 4250 (the SD3 8B use case).
All measurements were taken with batch size 2, heads = 38, hidden dim = 64, and float32 on an M3 Max / 48 GB.
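For context on the memory numbers, a rough back-of-envelope sketch; the assumption here (mine, not stated in this PR) is that the savings roughly track the B x H x L x L fp32 score matrix that an unfused implementation materializes per attention block:

```python
# Back-of-envelope size of the unfused attention score matrix per block.
# Assumes one B x H x L x L fp32 allocation; batch and heads as above.
B, H, itemsize = 2, 38, 4  # batch, heads, bytes per float32

def score_matrix_gib(L):
    return B * H * L * L * itemsize / 1024**3

for L in (2048, 2300, 4250):
    print(f"L={L}: {score_matrix_gib(L):.2f} GiB")
# L=2048: ~1.19 GiB, L=2300: ~1.50 GiB, L=4250: ~5.11 GiB,
# which is in the same ballpark as the savings quoted above.
```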
@bpkeene, I left a few minor comments. Could you address them? Once updated, we can run the tests and get this merged.
Updated with the requested changes; thank you for the prompt review!