[Enhancement Request] SDPA does not support head dimension of size 192 (capped at 128)
Describe the bug: SDPA currently supports only head dimensions of 64, 96, and 128; any other size (for example 192) is not handled by the fused kernel.
To Reproduce: Fused attention falls back to the regular (unfused) implementation whenever the head dimension is not one of 64, 96, or 128. See https://github.com/ml-explore/mlx/blob/main/mlx/fast.cpp#L644-L645
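The routing described above can be sketched in plain Python to make the behaviour concrete. This is a hypothetical illustration, not MLX's code: `fused_kernel_supported`, `sdpa_reference`, and `sdpa` are names invented here, and both paths run the same reference math because only the dispatch decision is being shown. The fallback computes ordinary scaled dot-product attention, softmax(Q·Kᵀ·scale)·V with scale = 1/√d, which works for any head dimension d.

```python
import math

# Head dimensions the fused kernel accepts, as reported in this issue
# (the real check in mlx/fast.cpp may differ between MLX versions).
SUPPORTED_HEAD_DIMS = (64, 96, 128)


def fused_kernel_supported(head_dim: int) -> bool:
    """Hypothetical stand-in for the head-dimension check in mlx/fast.cpp."""
    return head_dim in SUPPORTED_HEAD_DIMS


def softmax(row):
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q, k, v, scale):
    """Unfused attention for a single head: softmax(q @ k^T * scale) @ v.

    q is a list of query vectors; k and v are lists of key/value vectors.
    Every vector has length head_dim, and any head_dim works here.
    """
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights = softmax(scores)
        # Output row = attention-weighted sum of the value vectors.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


def sdpa(q, k, v):
    """Pick the 'fused' path or the fallback, mirroring the reported behaviour."""
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    path = "fused" if fused_kernel_supported(head_dim) else "fallback"
    # Both branches compute the same result in this sketch; only routing differs.
    return path, sdpa_reference(q, k, v, scale)


if __name__ == "__main__":
    d = 192
    path, out = sdpa([[0.1] * d], [[0.1] * d, [0.2] * d], [[1.0] * d, [0.0] * d])
    print(path, out[0][:3])
```

With a head dimension of 192 the sketch takes the fallback branch, which is the gap this request asks to close.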
Request: Extend SDPA to support less common head dimensions such as 192 (still multiples of 32).
Desktop:
- OS Version: macOS 15.2
- MLX Version: 0.20.0
Additional context: Awni's suggestion was to «generalize it so that any even head dim or maybe multiple of 32 is supported».
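The difference between the current whitelist and the proposed rule can be sketched as two predicates. Both function names are hypothetical and this is not MLX's implementation; `multiple=32` corresponds to the "multiple of 32" request, and `multiple=2` to "any even head dim". Whether a given size also fits the kernel's on-chip memory and tiling constraints is a separate question that this sketch does not model.

```python
CURRENT_HEAD_DIMS = (64, 96, 128)


def fused_supported_current(head_dim: int) -> bool:
    # Behaviour reported in this issue: only an explicit whitelist is accepted.
    return head_dim in CURRENT_HEAD_DIMS


def fused_supported_generalized(head_dim: int, multiple: int = 32) -> bool:
    # Hypothetical relaxed rule: accept any positive multiple of `multiple`.
    # multiple=32 matches the request; multiple=2 means "any even head dim".
    return head_dim > 0 and head_dim % multiple == 0


if __name__ == "__main__":
    for d in (64, 80, 96, 128, 160, 192, 256):
        print(d,
              fused_supported_current(d),
              fused_supported_generalized(d),
              fused_supported_generalized(d, multiple=2))
```

Under the multiple-of-32 rule, 160, 192, and 256 would become eligible while 80 would not; the even-dimension rule would admit 80 as well.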