
Integrate MLX SDPA kernels with mask

Open · EricLBuehler opened this issue 8 months ago · 0 comments

This PR integrates kernel developments from: https://github.com/ml-explore/mlx/pull/1924.

Specifically, our `candle_nn::ops::sdpa` function now dispatches to optimized implementations for both the prompt case and the decode (non-prompt) case. There is also an option for causal masking, which removes the need to materialize a mask tensor.
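To make the dispatch and causal-masking description concrete, here is a minimal CPU reference sketch in Rust. It is not candle's or MLX's API: `SdpaPath`, `choose_path`, `sdpa`, and the rule that a query length of 1 means decode are all hypothetical. The causal mask assumes queries are aligned to the end of the key/value sequence (the usual KV-cache convention). The point is that a causal mask can be expressed as a loop bound derived from token indices, so no `[q_len, kv_len]` mask tensor is ever allocated.

```rust
/// Hypothetical kernel paths, loosely mirroring a prompt/decode split.
#[derive(Debug, PartialEq)]
enum SdpaPath {
    /// Single-token decode step.
    Decode,
    /// Multi-token prompt processing.
    Prompt,
}

/// Assumed dispatch rule (not taken from candle or MLX):
/// one query token means decode, anything longer means prompt.
fn choose_path(q_len: usize) -> SdpaPath {
    if q_len == 1 {
        SdpaPath::Decode
    } else {
        SdpaPath::Prompt
    }
}

/// Naive single-head scaled dot-product attention.
/// `q` is [q_len][d], `k` is [kv_len][d], `v` is [kv_len][d_v].
/// With `causal`, query i may attend to key j only when
/// j <= i + (kv_len - q_len): the mask comes from indices, not a tensor.
fn sdpa(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32, causal: bool) -> Vec<Vec<f32>> {
    let (q_len, kv_len) = (q.len(), k.len());
    assert!(kv_len > 0 && kv_len >= q_len && kv_len == v.len());
    let d_v = v[0].len();
    // Assumption: queries are the last q_len positions of the sequence.
    let offset = kv_len - q_len;
    let mut out = vec![vec![0.0f32; d_v]; q_len];
    for i in 0..q_len {
        // Causal masking reduces to a shorter loop bound over the keys.
        let visible = if causal { i + offset + 1 } else { kv_len };
        let scores: Vec<f32> = k[..visible]
            .iter()
            .map(|kj| scale * q[i].iter().zip(kj).map(|(a, b)| a * b).sum::<f32>())
            .collect();
        // Numerically stable softmax over the visible scores only.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let denom: f32 = weights.iter().sum();
        // Weighted sum of the visible value rows.
        for (w, vj) in weights.iter().zip(v) {
            for (o, x) in out[i].iter_mut().zip(vj) {
                *o += w / denom * x;
            }
        }
    }
    out
}

fn main() {
    // Two tokens, head dim 1. Zero queries give equal scores, so softmax is uniform.
    let q: Vec<Vec<f32>> = vec![vec![0.0], vec![0.0]];
    let k: Vec<Vec<f32>> = vec![vec![0.0], vec![0.0]];
    let v: Vec<Vec<f32>> = vec![vec![1.0], vec![3.0]];
    println!("{:?}", choose_path(q.len())); // Prompt
    println!("causal: {:?}", sdpa(&q, &k, &v, 1.0, true)); // [[1.0], [2.0]]
    println!("full:   {:?}", sdpa(&q, &k, &v, 1.0, false)); // [[2.0], [2.0]]
}
```

A real Metal kernel would tile this work across threadgroups; the sketch only shows what is computed, not how the optimized kernels schedule it.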

Overall, this means we can fuse the attention operation on Metal for both the prompt and decode phases!
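Fusion here means computing the scores, the softmax, and the weighted sum of values in one kernel instead of as separate ops with intermediate tensors. A common way to do this, and the one sketched below as an assumption (it is not claimed to be the exact MLX kernel), is an online softmax that keeps a running max and running denominator while streaming over the keys. The function name `fused_decode_attention` is hypothetical, and it covers a single query, as in a decode step.

```rust
/// Single-query attention computed in one pass over the keys using an
/// online softmax: a running max `m`, running denominator `l`, and running
/// weighted sum `acc` are rescaled whenever a larger score appears, so the
/// full score vector is never stored.
fn fused_decode_attention(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Vec<f32> {
    assert!(!k.is_empty() && k.len() == v.len());
    let mut m = f32::NEG_INFINITY; // running max of the scores seen so far
    let mut l = 0.0f32; // running sum of exp(score - m)
    let mut acc = vec![0.0f32; v[0].len()]; // running sum of exp(score - m) * v_j
    for (kj, vj) in k.iter().zip(v) {
        let s = scale * q.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>();
        let m_new = m.max(s);
        // Rescale earlier partial results to the new max (exp(-inf) = 0 on the first key).
        let correction = (m - m_new).exp();
        let p = (s - m_new).exp();
        l = l * correction + p;
        for (a, x) in acc.iter_mut().zip(vj) {
            *a = *a * correction + p * x;
        }
        m = m_new;
    }
    // Normalize once at the end.
    acc.iter().map(|a| a / l).collect()
}

fn main() {
    let q = [1.0f32];
    let k = vec![vec![0.0f32], vec![2.0]];
    let v = vec![vec![1.0f32], vec![3.0]];
    // Two-pass reference: softmax([0, 2]) applied to values [1, 3].
    let e = 2.0f32.exp();
    let reference = (1.0 + 3.0 * e) / (1.0 + e);
    let out = fused_decode_attention(&q, &k, &v, 1.0);
    println!("fused = {:?}, reference = {}", out, reference);
}
```

The single pass matches the two-pass softmax up to floating-point rounding, which is what lets scores, softmax, and the value reduction live in one kernel.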

I will update this PR with benchmarks; it is already tested and working in my fork through mistral.rs.

EricLBuehler, Mar 22 '25 01:03