Awni Hannun
The code should definitely still run on the GPU. It's possible you are being bottlenecked by communication latency, so it looks like the GPU is not used. One thing to...
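As a quick sanity check (a minimal sketch, separate from your training script), you can confirm which device MLX is actually using:

```python
import mlx.core as mx

# MLX defaults to the GPU on Apple silicon; this just confirms it.
print(mx.default_device())

# Or pin it explicitly to rule the device out as the issue.
mx.set_default_device(mx.gpu)
```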
Yea exactly, so it looks like either communication latency or bandwidth is a bottleneck for you. Did you see this [section in the docs on tuning all reduce](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#tuning-all-reduce)? In general...
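For context, the naive version of what that section is tuning is just averaging the gradients across ranks every step, roughly like this sketch (the `average_gradients` helper name and the pytree layout of `grads` are placeholders, not your code or the docs' exact recipe):

```python
import mlx.core as mx
from mlx.utils import tree_map

world = mx.distributed.init()

def average_gradients(grads):
    # Sum each gradient array across ranks, then divide by the number of
    # processes so every rank ends up with the same averaged gradients.
    if world.size() == 1:
        return grads
    return tree_map(lambda g: mx.distributed.all_sum(g) / world.size(), grads)
```

Doing an `all_sum` on every small gradient array each step is exactly where communication latency/bandwidth tends to show up.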
Great, keep us posted on how it goes!
That kind of hang is most likely from a deadlock. Without seeing the rest of your code it's hard to know. But usually this sort of thing happens when the...
This is the problem:

```
if rank == 0:
    print(f"Epoch {epoch + 1}, Loss: {mx.distributed.all_sum(epoch_loss) / num_batches}")
```

Every process has to participate in an `all_sum`, otherwise there is a...
Ah interesting, that's a bit of a gotcha; sorry, I think I told you the wrong thing. Try the following (notice I moved the `item`):

```python
loss = mx.distributed.all_sum(epoch_loss) / num_batches
print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
```
Although you say this hangs as well? That's unexpected...

```
loss = mx.distributed.all_sum(epoch_loss) / num_batches
print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
```
Hmm, I'm not getting a hang there if I use the code like this:

```
loss = mx.distributed.all_sum(epoch_loss) / num_batches
loss = loss.item()
print(f"Epoch {epoch + 1}, Loss: {loss}")
```

...
Just to be sure we are running the same thing, could you send exactly what you are running (with the print fix) and the command you used to run it...
Also what version of MLX are you using?
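One quick way to check (assuming your install exposes `mx.__version__`):

```python
import mlx.core as mx

# Print the installed MLX version.
print(mx.__version__)
```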