
Metal shaders for efficient self attention on large sequences

Open bpkeene opened this pull request 1 year ago • 1 comment

Proposed changes

Implements metal shaders for:

o = mx.fast.scaled_dot_product_attention(queries, keys, values, scale, mask)

Supports fp16 and fp32 dtypes with a flexible hidden dimension, currently templated for head dimensions of 64, 80, and 128.

Causal masking for prompt encoding is not yet implemented; at present this shader focuses on full self-attention.
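
A minimal usage sketch of the new op (the shapes, dtypes, and random inputs here are illustrative, not taken from this PR):

import mlx.core as mx

# Illustrative shapes: (batch, heads, sequence length, head dim)
B, H, L, D = 2, 38, 1024, 64
q = mx.random.normal((B, H, L, D)).astype(mx.float16)
k = mx.random.normal((B, H, L, D)).astype(mx.float16)
v = mx.random.normal((B, H, L, D)).astype(mx.float16)

# Standard attention scaling: 1 / sqrt(head_dim)
scale = 1.0 / (D ** 0.5)

o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
mx.eval(o)  # force evaluation; MLX is lazy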

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have updated the necessary documentation (if needed)

bpkeene avatar Apr 06 '24 03:04 bpkeene

Marking as draft: I'm currently working through some numerical issues via a separate workflow, and will add CPU-side bindings and dispatch, tests, and docs. Sharing the current status and will update this PR.

bpkeene avatar Apr 06 '24 04:04 bpkeene

Hi folks,

[Attached graphs: "Latency - Self Attention SDPA" and "Memory - Self Attention SDPA"]

Attaching some graphs of measured latency on an M3 Max and estimated memory savings per attention block (the savings were empirically observed at several data points; the graphs themselves were obtained via formulas).

bpkeene avatar May 21 '24 19:05 bpkeene

There is some room for improvement in latency on larger sequences, with a divergence after a sequence length of ~2300; however, the memory savings exceed 1 GB at ~2k sequence length and approach 5 GB at a sequence length of 4250 (the SD3 8B use case).

All measurements were taken with batch size 2, 38 heads, hidden dimension 64, and float32 on an M3 Max (48 GB).
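
For intuition on where the savings come from: unfused attention materializes a full L x L score matrix per head, so its extra memory is roughly batch * heads * L^2 * bytes_per_element, while the fused kernel never builds that matrix. A back-of-envelope sketch (my own estimate, not the exact formulas behind the graphs above):

def naive_score_matrix_bytes(batch, heads, seq_len, bytes_per_el=4):
    # Memory needed to materialize the full (seq_len x seq_len)
    # attention score matrix across all heads in float32.
    return batch * heads * seq_len * seq_len * bytes_per_el

# Settings from the measurements above: batch 2, 38 heads, float32.
print(naive_score_matrix_bytes(2, 38, 2000) / 1e9)  # ~1.2 GB, consistent with ">1 GB at ~2k"
print(naive_score_matrix_bytes(2, 38, 4250) / 1e9)  # ~5.5 GB, in line with "~5 GB at 4250"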

bpkeene avatar May 21 '24 19:05 bpkeene

@bpkeene left a few minor comments. Could you address them? Once updated we can run the tests and get this merged.

awni avatar May 22 '24 16:05 awni

Updated with the requested changes; thank you for the prompt review!

bpkeene avatar May 23 '24 22:05 bpkeene