mlx-swift

Crash triggered by mlx::core::fast::ScaledDotProductAttention::eval_gpu in 0.25.3 on M1 and M2

Open • alijuma opened this issue on May 29 '25 • 8 comments

In 0.25.3, we're getting reports that users of Dia (our new macOS app) on M1 and M2 devices are crashing in mlx::core::fast::ScaledDotProductAttention::eval_gpu.

With Metal shader validation enabled, this shows up as an assertion failure:

validateComputeFunctionArguments:1056: failed assertion Compute Function(sdpa_vector_float_64_64_floatmask_qt_nc): missing buffer binding at index 11 for mask_head_stride[0].

inside the call to dispatch_threadgroups in sdpa_vector.
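The assertion says the float-mask kernel variant expects a mask_head_stride argument at buffer index 11 that was never bound. A head stride presumably tells the kernel how many elements apart consecutive heads' slices of the mask are, and a mask shared by all heads has a head stride of 0. The helper below is a hypothetical illustration (not MLX code) of how such strides can be derived, assuming a row-major mask and NumPy-style broadcasting:

```python
def broadcast_strides(mask_shape, target_shape):
    """Per-dimension element strides for reading a row-major mask as if it
    were broadcast to target_shape (NumPy-style broadcasting rules)."""
    ndim = len(target_shape)
    if len(mask_shape) > ndim:
        raise ValueError("mask has more dimensions than the target")
    # Left-pad the mask shape with 1s so both shapes have the same rank.
    padded = [1] * (ndim - len(mask_shape)) + list(mask_shape)
    # Contiguous row-major strides of the mask's own storage.
    own = [0] * ndim
    running = 1
    for i in range(ndim - 1, -1, -1):
        own[i] = running
        running *= padded[i]
    strides = []
    for dim, (m, t) in enumerate(zip(padded, target_shape)):
        if m == t:
            strides.append(own[dim])
        elif m == 1:
            # Broadcast dimension: every index reads the same data.
            strides.append(0)
        else:
            raise ValueError(f"mask dim {dim} (size {m}) cannot broadcast to {t}")
    return strides


# Mask shared by all 32 heads: shape (1, 1, 8, 8) vs scores (1, 32, 8, 8).
strides = broadcast_strides((1, 1, 8, 8), (1, 32, 8, 8))
mask_head_stride = strides[1]  # 0: each head reuses the same mask slice
```

A per-head mask of shape (1, 32, 8, 8) would instead get a head stride of 64 (8 × 8 elements). Whatever the value, the kernel can only read it if the host actually binds it.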

With shader validation disabled, it shows up later as a fatal error caught in ErrorHandler.swift:

MLX/ErrorHandler.swift: 332: Fatal error: [metal::Device] Unable to load function sdpa_vector_float_64_64

Function sdpa_vector_float_64_64 was not found in the library at ..... mlx-c/mlx/c/transforms.cpp:73
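The kernel names in these errors appear to encode the dtype and head dimensions (sdpa_vector_float_80_80 shows up in a later comment). The sketch below only illustrates why such a lookup fails when a variant was never compiled into the Metal library. The naming pattern is inferred from these logs, and the instantiated set and function names are hypothetical:

```python
# Hypothetical set of kernel variants compiled into the Metal library.
# The real set depends on the MLX version; these values are illustrative only.
INSTANTIATED = {
    "sdpa_vector_float_64_64",
    "sdpa_vector_float_96_96",
    "sdpa_vector_float_128_128",
}


def kernel_name(dtype: str, qk_dim: int, v_dim: int) -> str:
    """Build a name following the sdpa_vector_<dtype>_<dim>_<dim> pattern
    seen in the error messages (field meanings are assumed)."""
    return f"sdpa_vector_{dtype}_{qk_dim}_{v_dim}"


def load_function(name: str) -> str:
    """Mimic a library lookup that fails when no variant was instantiated."""
    if name not in INSTANTIATED:
        raise RuntimeError(
            f"Unable to load function {name}\n"
            f"Function {name} was not found in the library"
        )
    return name
```

For example, `load_function(kernel_name("float", 80, 80))` raises `RuntimeError` here because 80 is not in the illustrative set.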

In 0.25.2, these same M1 and M2 users (with shader validation disabled) were instead crashing later, in mlx::core::metal::check_error.

For now we've rolled back to 0.23.1 which has fixed the crashes.

I suspect the underlying cause may be the change that added float-mask support to sdpa_vector not working correctly on M1 and M2, combined with a later change that makes us crash sooner in 0.25.3.
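Background on the float-mask change: the mask passed to SDPA can be a boolean mask or an additive float mask, and the float form is added to the scaled scores before the softmax. Assuming the convention that True means "attend" (check the MLX docs for your version), a boolean mask maps to an additive one as sketched below. The helper names are hypothetical:

```python
import math


def bool_to_additive(mask):
    """Convert a boolean keep-mask (True = attend) to an additive float mask.

    0.0 leaves a score unchanged; -inf drives its softmax weight to 0.
    """
    return [[0.0 if keep else -math.inf for keep in row] for row in mask]


def causal_bool_mask(n):
    """Boolean causal mask: query i may attend to keys 0..i."""
    return [[j <= i for j in range(n)] for i in range(n)]


additive = bool_to_additive(causal_bool_mask(3))
# additive == [[0.0, -inf, -inf], [0.0, 0.0, -inf], [0.0, 0.0, 0.0]]
```

The float path is more general (it can carry arbitrary biases, not just keep/drop), which is why it needs its own kernel variant and extra arguments.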

alijuma • May 29 '25

Are you able to provide some code that reproduces this? So far I haven't been able to, even when hitting that same kernel. This could also be something that's already fixed on MLX main, though I don't recall this ever being an issue.

awni • May 30 '25

I have a repro, or at least something very similar:

./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/phi-2-hf-4bit-mlx

with HEAD of mlx-swift-examples. Investigating.

davidkoski • Jun 06 '25

>>> import mlx.core as mx
>>> q = mx.random.normal((1, 32, 8, 80))
>>> r = mx.fast.scaled_dot_product_attention(q, q, q, scale=0.25, mask=None)
>>> r
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: [metal::Device] Unable to load function sdpa_vector_float_80_80
Function sdpa_vector_float_80_80 was not found in the library

This is with mlx 0.25.0 (Python).
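The repro above uses MLX's (batch, heads, sequence, head_dim) layout, so q of shape (1, 32, 8, 80) is 32 heads of 8 queries with head dimension 80. When checking a fused kernel's output, a slow pure-Python reference of what SDPA computes for one (batch, head) slice, softmax(scale * q @ k^T + mask) @ v, can be handy. This is an illustrative sketch with hypothetical names, not MLX's implementation:

```python
import math


def softmax(row):
    # Subtract the max for numerical stability; -inf entries get weight 0.
    # (A fully masked row would produce NaNs here.)
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q, k, v, scale, mask=None):
    """Single-head attention: softmax(scale * q @ k^T + mask) @ v.

    q: [Lq][D], k: [Lk][D], v: [Lk][Dv] as nested lists of floats;
    mask, if given, is an additive [Lq][Lk] float mask.
    """
    out = []
    for i, qi in enumerate(q):
        # Scaled dot product of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[i])]
        weights = softmax(scores)
        # Weighted sum of value rows, one output element per value dimension.
        out.append([sum(w * vj[d] for w, vj in zip(weights, v))
                    for d in range(len(v[0]))])
    return out
```

With a zero query every key gets equal weight, so `sdpa_reference([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]], scale=0.25)` averages the two value rows.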

davidkoski • Jun 06 '25

This may already be fixed in mlx::core but not yet in the mlx-swift build. @angeloskath, do you know if there is a PR to track for this?

davidkoski • Jun 06 '25

It's actually not clear that this one is fixed; I'm not sure what the issue is here yet. The one that I think @angeloskath fixed in https://github.com/ml-explore/mlx/pull/2246 is this:

import mlx.core as mx
q = mx.random.normal((1, 32, 8, 80))
r = mx.fast.scaled_dot_product_attention(q, q, q, scale=0.25, mask=None)
print(r)

I wasn't able to reproduce the one in this issue, though, since both the SDPA vector and matrix kernels support head dimension 64.

awni • Jun 06 '25

That is probably the same one I fixed in #2246. The check should make sure that either sdpa_full or sdpa_vector can support the use case, not both.
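A minimal sketch of the kind of check described above, assuming each kernel's support predicate includes the sequence-length condition that decides whether that kernel is dispatched, and that the unfused fallback is taken only when neither applies. The head-dimension sets, the query-length threshold, and all names are illustrative assumptions, not MLX's actual values or code:

```python
# Illustrative values only; the real sets and threshold live in MLX's
# C++ source and may differ between releases.
VECTOR_HEAD_DIMS = {64, 96, 128, 256}
FULL_HEAD_DIMS = {64, 80, 128}
VECTOR_MAX_QUERY_LEN = 8


def supports_vector(q_len: int, k_len: int, head_dim: int) -> bool:
    # The vector kernel is only dispatched for short query sequences.
    return (q_len <= VECTOR_MAX_QUERY_LEN and q_len <= k_len
            and head_dim in VECTOR_HEAD_DIMS)


def supports_full(q_len: int, head_dim: int) -> bool:
    # The full kernel is dispatched for longer query sequences.
    return q_len > VECTOR_MAX_QUERY_LEN and head_dim in FULL_HEAD_DIMS


def use_fallback(q_len: int, k_len: int, head_dim: int) -> bool:
    """Use the unfused path unless the kernel that would actually be
    dispatched for this shape supports it."""
    return not (supports_vector(q_len, k_len, head_dim)
                or supports_full(q_len, head_dim))
```

With these values, the Python repro's shape (8 queries, head dimension 80) would go to the vector kernel, which lacks 80, so it falls back; a check that only asked whether some kernel handles head dimension 80 would wrongly pick the fast path.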

angeloskath • Jun 06 '25

I work with @alijuma and just tried to update to 0.25.5 but ran into this:

validateComputeFunctionArguments:1056: failed assertion 'Compute Function(sdpa_vector_float_64_64_floatmask_qt_nc): missing buffer binding at index 11 for mask_head_stride[0].'

This is happening somewhere deep in mlx_array_eval when using mlx::core::metal::CommandEncoder via mlx::core::fast::ScaledDotProductAttention::eval_gpu.

tristanlabelle • Jul 03 '25

Verify that this is fixed with #273.

davidkoski • Sep 26 '25