Large ops with multiple processes fail with an internal error
Run the following with:
mpirun -n 4 -- python <script>
import socket

import mlx.core as mx

hostname = socket.gethostname()
world = mx.distributed.init()
rank = world.rank()
num_processes = world.size()  # reuse the group instead of calling init() a second time
print(f"Distributed available: {mx.distributed.is_available()}")
print(f"Hostname: {hostname} rank: {rank}")
print(f"Hostname: {hostname} num_processes: {num_processes}")

DIM = 200000
# Two large float32 matrices: 1600 x DIM (~1.28 GB) and 512 x DIM (~0.41 GB)
data = mx.zeros((1600, DIM))
w = mx.zeros((512, DIM))
mx.eval(w, data)

# Evaluate the same large matmul repeatedly
for it in range(100):
    mx.eval(data @ w.T)
    print(it)
Fails on an M2 Ultra with:
libc++abi: terminating due to uncaught exception of type std::runtime_error: [METAL] Command buffer execution failed: Internal Error (0000000e:Internal Error)
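For a sense of scale, the cost of one iteration of the repro can be worked out by hand. This is a minimal sketch, assuming `mx.zeros` produces float32 arrays (4 bytes per element) and counting each multiply-add as 2 FLOPs. `matmul_cost` is a hypothetical helper, not an MLX function.

```python
# Back-of-the-envelope size of the repro's matmul.
# Assumptions: float32 inputs (4 bytes per element); one multiply-add = 2 FLOPs.
BYTES_PER_FLOAT32 = 4


def matmul_cost(m: int, k: int, n: int) -> dict:
    """Memory and FLOP cost of an (m, k) @ (k, n) matmul with float32 operands."""
    return {
        "lhs_bytes": m * k * BYTES_PER_FLOAT32,
        "rhs_bytes": k * n * BYTES_PER_FLOAT32,
        "out_bytes": m * n * BYTES_PER_FLOAT32,
        "flops": 2 * m * k * n,
    }


DIM = 200000
# data is (1600, DIM) and w.T is (DIM, 512)
cost = matmul_cost(1600, DIM, 512)
print(f"data: {cost['lhs_bytes'] / 1e9:.2f} GB")   # 1.28 GB
print(f"w:    {cost['rhs_bytes'] / 1e9:.2f} GB")   # 0.41 GB
print(f"out:  {cost['out_bytes'] / 1e6:.2f} MB")   # 3.28 MB
print(f"work: {cost['flops'] / 1e12:.3f} TFLOP per iteration")  # 0.328 TFLOP
```

Every process in the repro issues this matmul, so running four copies at once puts four such workloads on the same GPU.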
Hmm, unfortunately this has nothing to do with distributed processing. The following also fails with the same error:
for i in {0..3}; do python <path to script> & done; wait
I guess since this is somewhat rare and there isn't much, if anything, we can do in MLX, I will close this for now.
Some possible workarounds:
The internal error signifies that the GPU kernel timed out. One way to fix that is to decrease the size of the operations:
For example, if you are using LoRA, try:
- decreasing the batch size
- decreasing the sequence length
- decreasing the model size
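A single large matmul can also be shrunk by hand, by splitting its rows into pieces and evaluating each piece separately. The sketch below only plans the split. `plan_row_chunks` and the 50 GFLOP budget are hypothetical, and it assumes kernel run time grows roughly with FLOP count.

```python
# Hypothetical helper: split the rows of an (m, k) @ (k, n) matmul into
# contiguous chunks of at most max_flops each (a single row that already
# exceeds the budget still gets its own chunk).
# Assumption: kernel run time grows roughly linearly with FLOPs.


def plan_row_chunks(m: int, k: int, n: int, max_flops: int) -> list[tuple[int, int]]:
    """Return half-open (start, stop) row ranges that together cover [0, m)."""
    flops_per_row = 2 * k * n  # one output row = n dot products of length k
    rows_per_chunk = max(1, max_flops // flops_per_row)
    return [
        (start, min(start + rows_per_chunk, m))
        for start in range(0, m, rows_per_chunk)
    ]


# The repro's matmul, (1600, 200000) @ (200000, 512), split into <= 50 GFLOP pieces
chunks = plan_row_chunks(1600, 200_000, 512, max_flops=50_000_000_000)
print(len(chunks), chunks[0], chunks[-1])  # 7 (0, 244) (1464, 1600)
```

In MLX, each range could then be evaluated on its own, e.g. `mx.eval(data[start:stop] @ w.T)`, with the pieces joined via `mx.concatenate(parts, axis=0)` if the full product is needed. Whether that is enough to avoid the timeout has not been checked here. Batch size and sequence length play a similar role inside a model: together they set the row count `m` of the linear layers' matmuls, so halving either roughly halves the work of those kernels.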
Another possible fix is to decrease the number of operations per command buffer. This may work in some cases, but it may also slow things down. You can do that like so:
MLX_MAX_OPS_PER_BUFFER=1 <training command here>
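When the training run is launched from Python rather than from a shell, the same variable can be placed in the child process's environment. This is a minimal sketch using only the standard library. `env_with_max_ops` and `run_with_max_ops` are hypothetical helpers, and the only MLX-specific part is the `MLX_MAX_OPS_PER_BUFFER` name from the command above.

```python
import os
import subprocess
import sys
from typing import Mapping, Optional


def env_with_max_ops(max_ops: int, base: Optional[Mapping[str, str]] = None) -> dict:
    """Return a copy of the environment with MLX_MAX_OPS_PER_BUFFER set.

    Equivalent to prefixing a shell command with MLX_MAX_OPS_PER_BUFFER=<max_ops>.
    The base mapping (os.environ by default) is not modified.
    """
    env = dict(os.environ if base is None else base)
    env["MLX_MAX_OPS_PER_BUFFER"] = str(max_ops)
    return env


def run_with_max_ops(cmd: list[str], max_ops: int = 1) -> int:
    """Run cmd as a child process with the variable set; return its exit code."""
    return subprocess.run(cmd, env=env_with_max_ops(max_ops)).returncode


if __name__ == "__main__":
    # Sanity check: the child sees the variable we set.
    print(run_with_max_ops([sys.executable, "-c",
                            "import os; print(os.environ['MLX_MAX_OPS_PER_BUFFER'])"]))
```

A call such as `run_with_max_ops(["python", "train.py"])`, where `train.py` stands in for the actual training script, mirrors the shell prefix. Setting the variable on the child process also avoids depending on when MLX reads it inside an already-running interpreter.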