Crash triggered by mlx::core::fast::ScaledDotProductAttention::eval_gpu in 0.25.3 on M1 and M2
In 0.25.3, we're getting reports of Dia (our new macOS app) users on M1 and M2 devices crashing in mlx::core::fast::ScaledDotProductAttention::eval_gpu.
With Metal shader validation enabled, this is an assertion failure:
```
validateComputeFunctionArguments:1056: failed assertion
Compute Function(sdpa_vector_float_64_64_floatmask_qt_nc): missing buffer binding at index 11 for mask_head_stride[0].
```
inside the call to dispatch_threadgroups in sdpa_vector.
With shader validation disabled, this is a later fatal error caught in ErrorHandler.swift:
```
MLX/ErrorHandler.swift: 332: Fatal error: [metal::Device] Unable to load function sdpa_vector_float_64_64
Function sdpa_vector_float_64_64 was not found in the library at ..... mlx-c/mlx/c/transforms.cpp:73
```
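To toggle between those two failure modes we set Metal's shader validation environment variable before launch. A minimal sketch, with repro.py as a stand-in name for the actual repro:

```python
import os
import subprocess
import sys

# Relaunch a repro script (repro.py is a stand-in name) with Metal shader
# validation enabled. MTL_SHADER_VALIDATION has to be set before process
# launch, which is why we re-exec rather than setting it in-process.
env = dict(os.environ, MTL_SHADER_VALIDATION="1")
subprocess.run([sys.executable, "repro.py"], env=env, check=False)
```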
In 0.25.2, these same M1 and M2 users (with shader validation disabled) were crashing later, in mlx::core::metal::check_error.
For now we've rolled back to 0.23.1, which resolves the crashes.
I suspect the underlying cause is the change that added float-mask support to the SDPA vector kernel not playing nicely with M1 and M2, and that a subsequent change made us crash sooner in 0.25.3.
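For reference, this is the kind of call that should select the floatmask variant of the vector kernel named in the assertion (the shapes, scale, and mask below are illustrative guesses, not taken from Dia):

```python
import mlx.core as mx

# Query sequence length 1 should route to the sdpa_vector kernels, and an
# explicit float mask should select the floatmask variant.
q = mx.random.normal((1, 32, 1, 64))   # (batch, heads, q_len, head_dim)
k = mx.random.normal((1, 32, 16, 64))
v = mx.random.normal((1, 32, 16, 64))
mask = mx.zeros((1, 32, 1, 16))        # additive float mask over the 16 keys
r = mx.fast.scaled_dot_product_attention(q, k, v, scale=0.125, mask=mask)
mx.eval(r)
```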
Are you able to provide some code that reproduces this? So far I'm not able to, even when hitting that same kernel. This could also be something that's fixed on main MLX, though I don't recall this ever being an issue.
I have a repro, well, something very similar:
```
./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/phi-2-hf-4bit-mlx
```
with HEAD of mlx-swift-examples. Investigating.
```
>>> import mlx.core as mx
>>> q = mx.random.normal((1, 32, 8, 80))
>>> r = mx.fast.scaled_dot_product_attention(q, q, q, scale=0.25, mask=None)
>>> r
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: [metal::Device] Unable to load function sdpa_vector_float_80_80
Function sdpa_vector_float_80_80 was not found in the library
```
with mlx 0.25.0 (Python).
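As a stopgap, the same attention can be computed with unfused ops, which sidesteps the sdpa_vector_* kernels entirely. A minimal sketch, assuming only the fused kernel path is affected:

```python
import mlx.core as mx

def sdpa_reference(q, k, v, scale, mask=None):
    # Unfused reference: scaled scores, optional additive mask, softmax, values.
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    if mask is not None:
        scores = scores + mask
    return mx.softmax(scores, axis=-1) @ v

q = mx.random.normal((1, 32, 8, 80))
r = sdpa_reference(q, q, q, scale=0.25)
print(r.shape)  # (1, 32, 8, 80)
```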
This may be fixed in mlx::core already, but not in the mlx-swift build yet. @angeloskath do you know if there is a PR for this to track?
It's actually not clear that this one is fixed / I'm not sure what the issue is here yet. The one that I think @angeloskath fixed in https://github.com/ml-explore/mlx/pull/2246 is this:
```python
import mlx.core as mx

q = mx.random.normal((1, 32, 8, 80))
r = mx.fast.scaled_dot_product_attention(q, q, q, scale=0.25, mask=None)
print(r)
```
I wasn't able to repro the one in this issue though, as both the SDPA vector and matrix kernels support head dimension 64.
That is probably the same one I fixed in #2246. The check should make sure that either sdpa_full or sdpa_vector can support the use case, not that both can.
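In pseudocode, the intended check looks roughly like this (the names and head-dim tables are illustrative, not the actual mlx::core source):

```python
# Illustrative head-dim tables, not the real ones in mlx::core.
VECTOR_HEAD_DIMS = {64, 96, 128}
FULL_HEAD_DIMS = {64, 80, 128}

def fused_sdpa_supported(head_dim: int, q_seq_len: int) -> bool:
    vector_ok = head_dim in VECTOR_HEAD_DIMS and q_seq_len == 1
    full_ok = head_dim in FULL_HEAD_DIMS
    # Either kernel being usable is enough to take the fused path; the
    # dispatch then has to pick a kernel that actually supports the case,
    # otherwise it tries to load a variant (e.g. sdpa_vector_float_80_80)
    # that was never compiled.
    return vector_ok or full_ok
```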
I work with @alijuma and just tried to update to 0.25.5 but ran into this:
```
validateComputeFunctionArguments:1056: failed assertion 'Compute Function(sdpa_vector_float_64_64_floatmask_qt_nc): missing buffer binding at index 11 for mask_head_stride[0].'
```
This is happening somewhere deep in mlx_array_eval when using mlx::core::metal::CommandEncoder via mlx::core::fast::ScaledDotProductAttention::eval_gpu.
Verify that this is fixed with #273.