candle
Integrate MLX SDPA kernels with mask
This PR integrates kernel developments from: https://github.com/ml-explore/mlx/pull/1924.
Specifically, our candle_nn::ops::sdpa function now dispatches to optimized implementations for both the prompt (prefill) and decode (single-token) cases. There is also an option for causal masking, which removes the need to materialize the mask as a tensor.
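As a rough sketch of how this would be called (the exact parameter list may differ in the final API; in particular, the second scalar argument for softmax capping and the tensor shapes below are assumptions for illustration):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_metal(0)?;

    // Shapes are (batch, heads, seq_len, head_dim).
    // Decode phase: q has seq_len 1 while k/v hold the cached context.
    let q = Tensor::randn(0f32, 1., (1, 32, 1, 128), &dev)?;
    let k = Tensor::randn(0f32, 1., (1, 32, 64, 128), &dev)?;
    let v = Tensor::randn(0f32, 1., (1, 32, 64, 128), &dev)?;

    // Standard attention scaling: 1 / sqrt(head_dim).
    let scale = 1f32 / (128f32).sqrt();

    // Fused SDPA; on Metal this dispatches to the MLX-derived kernels.
    // The trailing argument is assumed here to be a softmax capping value.
    let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
    println!("{:?}", out.dims()); // expected (1, 32, 1, 128)
    Ok(())
}
```

Because the kernel handles causal masking internally, the caller never builds the (seq_len, seq_len) mask tensor that a naive implementation would require.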
Overall, this means that we can fuse the attention operation on Metal for prompt and decode phases!
I will update this PR with benchmarks; it is already tested and working in my fork via mistral.rs.