Distributed inference example

angeloskath opened this issue • 18 comments

Simple distributed inference on top of https://github.com/ml-explore/mlx/pull/1270. Again, this is a draft PR so we can iterate on the design. The communication will be very latency-bound (probably impractical), so no need to get particularly excited yet.

angeloskath, Jul 15 '24 21:07

Thanks @angeloskath!

This is very timely, as I had been looking for such an example for a couple of days.

Blaizzy, Jul 15 '24 23:07

Amazing, time to buy a second M2 Ultra :P

mzbac, Jul 16 '24 00:07

@angeloskath, please correct me if I am wrong. Looking at the implementation, it seems like we are sharding vertically, i.e. splitting each layer's weights across nodes. For o_proj, we have to wait for all nodes to complete their part of the forward pass before moving on to the next layer. This would create a bottleneck, as the slowest node would slow down the entire process. Would it be better to shard by layers instead?

Edit: I think I understand it now. It makes sense to shard the model across identical hardware connected by a fast interconnect, which maximizes parallelization. This should be a good fit for MoE models. For dense models, maybe we have to do something similar to exo: shard the model over layers and run inference sequentially.
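A minimal, hypothetical sketch of why o_proj forces a synchronization point, assuming a Megatron-style split in which o_proj is sharded along its input dimension: each node computes a partial output from its slice of the weight rows, and only the element-wise sum over all nodes (an all-reduce) yields the real layer output. Plain Python lists stand in for tensors, and `matmul_vec`, `shard_inputs`, and `all_sum` are illustrative names, not functions from the PR.

```python
# Hypothetical illustration of a row-sharded (input-dimension) projection such
# as o_proj in a tensor-parallel layer. Plain lists stand in for tensors.

def matmul_vec(x, w):
    """Compute y = x @ W for a vector x and W of shape (in_features, out_features)."""
    out = [0.0] * len(w[0])
    for xi, row in zip(x, w):
        for j, wij in enumerate(row):
            out[j] += xi * wij
    return out


def shard_inputs(x, w, num_nodes):
    """Give each node a contiguous slice of x and the matching rows of W."""
    size = -(-len(w) // num_nodes)  # ceiling division
    return [(x[i:i + size], w[i:i + size]) for i in range(0, len(w), size)]


def all_sum(partials):
    """Local stand-in for an all-reduce: element-wise sum of every node's partial output."""
    return [sum(values) for values in zip(*partials)]


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # 4 inputs -> 2 outputs
x = [1.0, 0.5, -1.0, 2.0]

# Each node can compute its partial output independently...
partials = [matmul_vec(x_s, w_s) for x_s, w_s in shard_inputs(x, W, num_nodes=2)]
print(partials)  # [[2.5, 4.0], [9.0, 10.0]]

# ...but the layer output only exists once every partial has been summed,
# so the slowest node decides when the next layer can start.
print(all_sum(partials), matmul_vec(x, W))  # [11.5, 14.0] [11.5, 14.0]
```

In a real tensor-parallel block the input slice would already live on each node (produced by the preceding column-sharded projections); here it is sliced directly to keep the example small.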

mzbac, Jul 18 '24 05:07

I think you might be correct here @mzbac!

Though I would also like to benchmark @angeloskath's approach.

I have been researching this topic for weeks to support it in FastMLX. According to the paper I read and the Accelerate docs, layer-group sharding is the best approach for distributed inference and training.

However, it requires every node/machine to have fast on-device access to its model weights/shard.

Blaizzy, Jul 18 '24 19:07

They are just different approaches really. Pipelining gives perfect scaling in throughput but not latency.

This means that if you are running evaluations or simply running batch generations, it is perfect. But it will still take the same amount of time to see the first output. Basically, for a single generation, and assuming the model fits on one device, it doesn't provide any speedup. Another way to say it is that the tokens per second per client are not sped up, though the aggregate tokens per second scale pretty much perfectly.
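The latency-versus-throughput trade-off of pipelining can be checked with a tiny simulation. This is a hypothetical model, not MLX code: layers are split evenly into `num_stages` stages, communication is free, requests are served first-come-first-served, and each request is autoregressive (its next token cannot enter stage 0 until the previous one leaves the last stage). `model_time` is the arbitrary time one device needs to run the whole model for one token.

```python
# Hypothetical model of a layer-pipelined (pipeline-parallel) decoder.
# Assumptions: even split into stages, zero communication cost, FIFO scheduling,
# and autoregressive requests (token k+1 waits for token k to finish).

def simulate(num_stages, num_requests, tokens_per_request, model_time=2.0):
    """Return the time at which every request has produced all of its tokens."""
    stage_time = model_time / num_stages  # time one stage needs per token
    stage_free = [0.0] * num_stages       # when each stage can take new work
    ready = [0.0] * num_requests          # when each request's next token may start
    for _ in range(tokens_per_request):
        for r in range(num_requests):     # requests take turns (FIFO)
            t = ready[r]
            for s in range(num_stages):
                start = max(t, stage_free[s])
                t = start + stage_time
                stage_free[s] = t
            ready[r] = t
    return max(ready)


for stages, requests in [(1, 1), (2, 1), (1, 2), (2, 2)]:
    total = simulate(stages, requests, tokens_per_request=10)
    print(f"{stages} stage(s), {requests} request(s): "
          f"{requests * 10 / total:.2f} tokens/unit aggregate, "
          f"{10 / total:.2f} tokens/unit per request")
```

With two stages, a single request still gets 0.50 tokens per time unit, exactly as on one device, while two concurrent requests raise the aggregate to about 0.95 tokens per unit and each client still sees roughly the single-device rate.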

The approach in this PR is called model parallelism or tensor parallelism. The goal is to improve latency as well as throughput. However, it depends heavily on the latency of the interconnect, so over Ethernet this will probably not achieve speedups (we are looking into it). Indeed, we need to communicate 2 * num_layers times to produce a single output, whereas with pipelining we only communicate num_shards times per output. So if the communication latency is large, pipelining may be the better approach.
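A back-of-the-envelope sketch of the communication counts above. It assumes (this is not stated in the PR) that the factor of 2 comes from one all-reduce after the attention output projection and one after the MLP in each layer, and it uses a hypothetical 200 µs per message and the 126 layers of Llama 3.1 405B; only latency is modeled, not bandwidth.

```python
# Hypothetical latency-only model of per-token communication.
# Assumption: a Megatron-style tensor-parallel layer synchronizes twice
# (after attention and after the MLP).

def comm_rounds_per_token(scheme, num_layers, num_shards):
    if scheme == "tensor":
        return 2 * num_layers  # two all-reduces in every layer
    if scheme == "pipeline":
        return num_shards      # one hand-off per shard, counting the return to the first
    raise ValueError(f"unknown scheme: {scheme}")


def comm_overhead_us(scheme, num_layers, num_shards, latency_us):
    """Time spent waiting on messages for one generated token, ignoring bandwidth."""
    return comm_rounds_per_token(scheme, num_layers, num_shards) * latency_us


# 126 layers, 2 machines, hypothetical 200 us per message.
for scheme in ("tensor", "pipeline"):
    print(scheme, comm_rounds_per_token(scheme, 126, 2),
          comm_overhead_us(scheme, 126, 2, latency_us=200), "us")
```

Under these assumptions tensor parallelism pays 252 round trips (about 50 ms) per token against 2 (0.4 ms) for pipelining, which is why interconnect latency matters so much more for the approach in this PR.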

angeloskath, Jul 18 '24 19:07

@angeloskath, thank you for the detailed explanation. I may try to get another M2 Ultra and test it via the Thunderbolt 4 connection :)

mzbac, Jul 19 '24 00:07

IMO tensor parallelism is exactly what we need in the long run.

In the short term, the hype is around the 400B Llama, but that will fade eventually. Latency optimization is what I think fits with the overall MLX ethos.

fblissjr, Jul 21 '24 21:07

I tried clustering an M2 Ultra 192GB with an M2 Ultra 128GB, splitting the Llama 3 405B weights into 160GB and 67GB shards (layer sharding, not tensor parallelism). I got around 0.3 t/s, but I expected it to be closer to 1 or 2 t/s. I'm not sure whether this is related to MLX or to some system-level issue.

P.S. I tried running sudo sysctl iogpu.disable_wired_collector=1 but got the error sysctl: unknown oid 'iogpu.disable_wired_collector'. That could be part of the issue.

mzbac, Jul 30 '24 01:07

Was this over Wi-Fi or Thunderbolt 4, @mzbac?

Blaizzy, Jul 30 '24 05:07

TB4. I ran some tests, and I suspect a memory issue: once MLX's memory consumption reaches a certain limit, the tokens per second drop to 0.x. I am not exactly sure what the issue is, but sharding DeepSeek Coder V2 4-bit worked fine (60+ VRAM and up to 1xx RAM cache).

mzbac, Jul 30 '24 05:07

Which OS are you on? A couple of things that might help:

  1. Restart the machine(s).
  2. Upgrade to macOS Sequoia (15.0).
  3. Set some sysctls:
       sudo sysctl iogpu.wired_limit_mb=200000
       sudo sysctl iogpu.disable_wired_collector=1

The iogpu.disable_wired_collector sysctl requires macOS 15.0+. With that combination I was able to get DeepSeek Coder V2 large (236B params) running pretty fast on a single M2 Ultra.

awni, Jul 30 '24 13:07

one M2 Ultra 192GB with another M2 Ultra 128GB, splitting the weights to 160GB and 67GB

Maybe putting more on the 128GB machine will also help, e.g. 140GB and 87GB or so.
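A small, hypothetical helper for turning a memory split like this into a layer boundary for two-machine layer sharding. It assumes every transformer layer is the same size, ignores the embedding and output head, uses the 126 layers of Llama 3.1 405B, and treats the first number as the share on the 192GB machine (which holds the first layers). `layer_boundary` is an illustrative name, not part of any sharding tool.

```python
# Hypothetical helper: convert a target memory split into a layer boundary
# for a two-machine, layer-sharded (pipeline) setup.
# Assumes uniform layer sizes and ignores embeddings and the output head.

def layer_boundary(num_layers, first_gb, second_gb):
    """Index of the first layer that goes to the second machine."""
    return round(num_layers * first_gb / (first_gb + second_gb))


NUM_LAYERS = 126  # Llama 3.1 405B
for first_gb, second_gb in [(160, 67), (140, 87)]:
    b = layer_boundary(NUM_LAYERS, first_gb, second_gb)
    print(f"{first_gb}GB / {second_gb}GB -> layers [0, {b}) and [{b}, {NUM_LAYERS})")
```

Under the uniform-layer assumption, the original 160GB/67GB split corresponds to a boundary at layer 89, and the suggested 140GB/87GB split to a boundary at layer 78.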

awni, Jul 30 '24 13:07

@awni Thanks for the pointers. I will try upgrading macOS; it's currently on version 14.5.

mzbac, Jul 30 '24 13:07

Just to share an update: upgrading to macOS 15.0 helped solve the memory issue, and I am now able to run 405B 4-bit at around 3.4 t/s, not bad at all.

https://www.youtube.com/watch?v=_9vP7CS3TI4
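A rough sanity check on that number, under loudly stated assumptions: single-stream decoding is memory-bandwidth bound, every weight byte is read once per token, the two shards run one after the other, communication and the KV cache are ignored, and each M2 Ultra provides about 800 GB/s of memory bandwidth. The 160GB and 67GB shard sizes are the split reported earlier in the thread; `max_tokens_per_sec` is a hypothetical helper.

```python
# Rough upper bound on single-stream decode speed for a layer-sharded model.
# Assumptions: bandwidth-bound decoding, each weight byte read once per token,
# shards run sequentially, no communication or KV-cache cost, ~800 GB/s per machine.

def max_tokens_per_sec(shard_gb, bandwidth_gb_per_s):
    seconds_per_token = sum(gb / bw for gb, bw in zip(shard_gb, bandwidth_gb_per_s))
    return 1.0 / seconds_per_token


# The 160GB + 67GB split of the 405B 4-bit weights from earlier in the thread.
print(f"{max_tokens_per_sec([160, 67], [800, 800]):.2f} t/s")
```

Under these assumptions the ceiling is about 3.5 t/s, close to the reported 3.4 t/s.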

mzbac, Jul 31 '24 02:07

Nice!! Did you keep the sharding you had or rebalance it? I wonder if we could make it faster with a more even balance 🤔. But 3.4 t/s is a great start. Only faster from here 💪

awni, Jul 31 '24 02:07

I added a bit more weight to the 128GB machine in my layer-sharding configuration, as you suggested:
  Shard server (128GB machine): mlx-sharding-server --model Meta-Llama-3.1-405B-Instruct-4bit-mlx -s 70 -e 126
  API server (192GB machine): mlx-sharding-api --model mlx_sharding/Meta-Llama-3.1-405B-Instruct-4bit-mlx -sl 0 -el 70 -s <tb4 ip>:49112 --host 0.0.0.0
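A hypothetical sanity check for a sharding plan like this one. It assumes, without confirmation from this thread, that the start/end flags describe half-open layer ranges [start, end), which would explain why layer 70 appears as the end of one shard and the start of the next. The size estimate further assumes uniform layers and the roughly 227GB total (160GB + 67GB) mentioned earlier; `check_ranges` and `shard_sizes_gb` are illustrative names.

```python
# Hypothetical sanity check for a layer-sharding plan.
# Assumption: start/end flags mean half-open ranges [start, end).

def check_ranges(ranges, num_layers):
    """Raise if the shards leave a gap, overlap, or miss layers at the end."""
    expected = 0
    for start, end in ranges:
        if start != expected or end <= start:
            raise ValueError(f"bad shard {start}-{end}: expected start {expected}")
        expected = end
    if expected != num_layers:
        raise ValueError(f"shards cover {expected} of {num_layers} layers")


def shard_sizes_gb(ranges, num_layers, total_gb):
    """Estimate each shard's size, assuming all layers are the same size."""
    return [round(total_gb * (end - start) / num_layers, 1) for start, end in ranges]


ranges = [(0, 70), (70, 126)]  # API server (192GB), shard server (128GB)
check_ranges(ranges, num_layers=126)
print(shard_sizes_gb(ranges, num_layers=126, total_gb=227))
```

Under these assumptions the two shards come to roughly 126.1GB and 100.9GB, with all 126 layers covered exactly once.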

mzbac, Jul 31 '24 03:07

Any update on speed since? I got my hands on two 192GB machines and am getting ready to run some tests over the weekend.

DamascusGit, Aug 17 '24 04:08

Nothing new on the mlx-sharding side. I am still waiting for MLX to support pipeline parallelism over MPI. Once that is supported, there may be some performance improvement compared to using gRPC.

mzbac, Aug 17 '24 05:08

LFG 🚀🔥

Blaizzy, Nov 05 '24 21:11

Closing as this branch is available in mlx-lm: https://github.com/ml-explore/mlx-lm/tree/distributed-layers

awni, Mar 17 '25 15:03