
[BUG] thread issues with evaluation

Open · acsweet opened this issue 8 months ago · 5 comments

Describe the bug: There seem to be some thread data-access issues introduced in the newest version of mlx.

To Reproduce: This doesn't produce exactly the same error I was experiencing, but it does produce a similar one (I think).

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import threading
import time
import traceback
import faulthandler

faulthandler.enable()

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(10, 32)
        self.dense2 = nn.Linear(32, 16)
        self.dense3 = nn.Linear(16, 1)
        
    def __call__(self, x):
        x = mx.maximum(0, self.dense1(x))
        x = mx.maximum(0, self.dense2(x))
        return self.dense3(x)

def loss_fn(model, x, y):
    pred = model(x)
    return mx.mean((pred - y) ** 2)

def train_and_convert(thread_id):
    try:
        print(f"Thread {thread_id} starting")
        
        model = SimpleModel()
        optimizer = optim.Adam(learning_rate=0.001)

        x = mx.random.normal((128, 10))
        y = mx.random.normal((128, 1))
        
        for i in range(10):
            print(f"Thread {thread_id}, Iteration {i} starting")
            loss, grads = mx.value_and_grad(loss_fn)(model, x, y)
            
            print(f"Thread {thread_id}, Iteration {i} - evaluating gradients")
            for g in grads.values():
                mx.eval(g)
            
            print(f"Thread {thread_id}, Iteration {i} - updating model")
            optimizer.update(model, grads)
            
            print(f"Thread {thread_id}, Iteration {i} - evaluating loss")
            mx.eval(loss)
            
            print(f"Thread {thread_id}, Iteration {i} - converting loss to numpy")
            # Convert loss to numpy
            np_loss = np.array(loss)
            
            print(f"Thread {thread_id}, Iteration {i} - converting gradients to numpy")
            for name, g in list(grads.items())[:2]:
                print(f"Thread {thread_id}, Iteration {i} - converting gradient {name}")
                np_grad = np.array(g)
                print(f"Thread {thread_id}, Iteration {i}, Grad {name} shape: {np_grad.shape}")
            
            print(f"Thread {thread_id}, Iteration {i}, Loss: {np_loss}")
            
            # small delay to increase chances of thread overlap
            time.sleep(0.01)
    except Exception as e:
        print(f"Thread {thread_id} failed with exception: {e}")
        print(traceback.format_exc())

threads = []
for i in range(3): # Try with more threads
    t = threading.Thread(target=train_and_convert, args=(i,))
    t.daemon = False
    threads.append(t)
    t.start()
    # Small delay between thread starts
    time.sleep(0.05)

try:
    for t in threads:
        t.join()
    print("All threads completed successfully")
except KeyboardInterrupt:
    print("Interrupted by user")
except Exception as e:
    print(f"Exception in main thread: {e}")
    print(traceback.format_exc())

Expected behavior: When I run this with mlx 0.23.2 it executes fine, but with 0.24.0 and up it fails with either a segmentation fault or a Python fatal error like:

python(15651,0x171ba3000) malloc: Double free of object 0x14ea0dbb0
python(15651,0x171ba3000) malloc: *** set a breakpoint in malloc_error_break to debug
Fatal Python error: Aborted

I'd expect it to execute with no issues, or at least to know whether this is expected/unavoidable behavior given the changes from version 0.24.0 onwards.

Desktop (please complete the following information):

  • OS Version: macOS 15.2
  • Version: 0.23.2 (works) and 0.24.0 and up (fails)

Additional context: This is related to the mlx backend for Keras (https://github.com/keras-team/keras/issues/19571). I initially ran across this error with a simple model.fit(): it occurred when the progress bar was updating (which involves a cast to a NumPy array to display the loss). In that case the code also ran to completion with 0.23.2, but with 0.24.2 it failed with a segmentation fault or an error like this:

-[AGXG16XFamilyCommandBuffer tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1091: failed assertion `A command encoder is already encoding to this command buffer'
Fatal Python error: Aborted

While I was trying to replicate that last error, I got the error above instead. I can keep trying to reproduce the last error with a simple code block if that's helpful, but hopefully the above is enough to identify what's happening.

acsweet, Apr 11 '25 23:04

I think I'll see if I can work on a fix for this.

acsweet, Apr 18 '25 22:04

MLX is, in general, not thread-safe. There isn't an easy fix for this. You can play around with it and see what you think, but in general we'll probably need to make some design decisions about if and how to support this kind of use case. We can use this issue to discuss further.

awni, Apr 18 '25 22:04
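Since MLX is not thread-safe in general, one user-side workaround is to confine all MLX work to a single dedicated thread and have other threads hand work to it. This is a sketch under stated assumptions, not something proposed in this thread and not an MLX API. The helper `run_on_mlx_thread` (a hypothetical name) uses the standard library's `ThreadPoolExecutor` with `max_workers=1`, so submitted calls run one at a time, in order, always on the same thread.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical helper: a single worker thread that owns all MLX work.
# Assumption: every MLX call, including implicit evaluation such as
# np.array(arr) or arr.item(), is routed through this executor.
_mlx_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx-owner")


def run_on_mlx_thread(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) on the dedicated MLX thread.

    Blocks the calling thread until fn finishes. Calls from different
    threads are queued and executed one at a time, in submission order.
    Exceptions raised by fn are re-raised in the caller.
    """
    future = _mlx_executor.submit(fn, *args, **kwargs)
    return future.result()
```

In the repro, each caller thread could submit a whole training iteration, e.g. `run_on_mlx_thread(train_step, model, x, y)` for a hypothetical `train_step` that calls `mx.eval` and returns `np.array(loss)`. Returning NumPy copies rather than MLX arrays is an assumption aimed at keeping MLX arrays from being touched outside the worker. The trade-off is that MLX work from all threads is fully serialized.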

If I have an idea, should I submit a PR so we can discuss it?

acsweet, Apr 20 '25 08:04

Yeah, that's totally fine.

awni, Apr 20 '25 13:04

Maybe it's a bit late to reply to your issue, but I had a similar error in Rust.

MLX arrays are not thread-safe. See: Doc Swift

To work around your issue, you need to ensure that only one thread sends work to the GPU at a time. The simplest way to achieve this is a mutex (lock) that makes other threads wait until the current one finishes.

Here’s an example using Python:

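from threading import Lock

mutex = Lock()  # one lock shared by all threads

...  # SimpleModel, loss_fn, etc. unchanged

def train_and_convert(thread_id):
    ...  # model setup and value_and_grad unchanged
    # Hold the lock while evaluating so only one thread
    # sends work to the GPU at a time
    with mutex:
        mx.eval(loss)
    ...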

Do the same for the other mx.eval calls. I didn't test it, but it should work.

Armanoide, Aug 05 '25 15:08
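The lock suggestion above can be packaged as a small decorator so that every evaluation point goes through the same lock. This is a minimal sketch under two assumptions: that serializing evaluation is enough for this workload (the commenter notes the approach is untested), and that implicit evaluations need the lock too, since converting an unevaluated array with `np.array(...)`, calling `.item()`, or printing it also evaluates the graph. `mlx_serialized` and `_MLX_LOCK` are hypothetical names, not MLX APIs.

```python
import threading
from functools import wraps

# One process-wide lock shared by every thread that evaluates MLX arrays.
# RLock (not Lock) so a serialized function may call another serialized
# function on the same thread without deadlocking.
_MLX_LOCK = threading.RLock()


def mlx_serialized(fn):
    """Wrap fn so that at most one thread runs any wrapped call at a time."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with _MLX_LOCK:
            return fn(*args, **kwargs)
    return wrapper
```

For example, `safe_eval = mlx_serialized(mx.eval)` and `safe_to_numpy = mlx_serialized(np.array)` would replace the bare `mx.eval(...)` and `np.array(loss)` calls in the repro. If crashes persist, holding the lock around the whole iteration (graph construction, `optimizer.update`, evaluation, and conversion) is the more conservative option.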