
[HARD] Support arbitrary tensor parallel splits

Open rltakashige opened this issue 4 months ago • 14 comments

Context: Tensor parallelism allows a model to be split across multiple devices by distributing each layer's weights among them. The auto_parallel strategies for tensor parallelism use MLX LM's shard_linear and shard_inplace functions to split tensors evenly. These functions assume that the sharding dimensions are exactly divisible by the group size.

Issue: There are several limitations to this approach:

  1. Certain models cannot be supported by tensor parallelism at all, particularly quantized models (see the GLM Air 4.5 config.json with group-size 32 quantization).
  2. Most models cannot be parallelised across 3 nodes, since their intermediate size is not divisible by 3.
  3. This issue is a blocker for smarter tensor parallelism on heterogeneous devices.

A potential start is smart padding.
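A rough sketch of what smart padding could look like (plain NumPy for illustration; `pad_and_shard` is a hypothetical helper, not something in the codebase): pad the shard dimension up to the next multiple of the group size, split evenly, and keep track of the padding so it can be sliced off or masked after the collective.

```python
import numpy as np

def pad_and_shard(weight: np.ndarray, group_size: int, axis: int = 0):
    """Pad `weight` along `axis` up to the next multiple of `group_size`,
    then split it into `group_size` equal shards. Returns the shards plus
    the padding amount so the extra rows can be removed after compute."""
    dim = weight.shape[axis]
    pad = (-dim) % group_size                 # 0 if already divisible
    if pad:
        pad_spec = [(0, 0)] * weight.ndim
        pad_spec[axis] = (0, pad)
        weight = np.pad(weight, pad_spec)     # zero-pad the shard dimension
    return np.split(weight, group_size, axis=axis), pad

# Example: intermediate size 11008 is not divisible across 3 devices.
w = np.zeros((11008, 4096), dtype=np.float16)
shards, pad = pad_and_shard(w, group_size=3, axis=0)
print([s.shape for s in shards], "padding:", pad)   # 3 shards of (3670, 4096), padding: 2
```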

rltakashige avatar Dec 21 '25 22:12 rltakashige

I would also add a small benchmark test to work out how to split the model among the devices instead of splitting it equally, as I understand it is done now. I would like to work on this problem, but with someone else, so I can get to know the code base a bit better.

The dimensions don't need to be exactly divisible by the group size if we implement a system where percentages or fractions of the model are distributed. As far as I know, the models are always evenly divided today, right?

DevKiDCosmo avatar Dec 22 '25 06:12 DevKiDCosmo

yup! we have serious issues running TP on 3 devices because most model hidden dimensions aren't divisible by it.

We do have a fixed split we need to implement, but our TP sharding code is a lot more involved than our Pipeline code so it's not just an issue of rounding some percentages better.

Evanev7 avatar Dec 22 '25 11:12 Evanev7

Then I don't see a problem. Probably I misunderstood something.

With TP, a layer or multiple layers are split across devices. But how are they split across those devices? I think I might get it: instead of splitting the model vertically (by layer), it is split horizontally (by neurons). But not every layer has exactly the same number of neurons. Is this the problem? And is every layer split horizontally?

DevKiDCosmo avatar Dec 22 '25 15:12 DevKiDCosmo

Is there also something hybrid, where powerful devices take whole layers while less powerful devices split the work via TP?

DevKiDCosmo avatar Dec 22 '25 15:12 DevKiDCosmo

Correct me if I misunderstood your reply - a model is composed of multiple transformer blocks (as well as some other modules we don't care too much about).

In pipeline parallelism, some of the transformer blocks are assigned to one device, some others to a second device, etc. In tensor parallelism, we shard the linear modules inside the transformer blocks. All of the linear modules are split across their intermediate dimension (although which dimension this is depends on the linear module).

The problem isn't that layers have different numbers of parameters, but rather that the tensor parallel code expects computation and communication on tensors of the same size.
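For concreteness, a toy NumPy sketch of the usual MLP sharding pattern (illustrative only, not the actual exo/MLX code): the up projection is split along its output dimension, the down projection along its input dimension, and the per-device partial outputs are summed, which is what the all-reduce does.

```python
import numpy as np

hidden, intermediate, tp = 8, 12, 2        # toy sizes; 12 divides evenly by 2

x = np.random.randn(1, hidden)
w_up = np.random.randn(intermediate, hidden)    # column-parallel: split output dim
w_down = np.random.randn(hidden, intermediate)  # row-parallel: split input dim

up_shards = np.split(w_up, tp, axis=0)
down_shards = np.split(w_down, tp, axis=1)

# Each "device" computes its partial result; the all-reduce is the final sum.
partials = [(x @ up_i.T) @ down_i.T for up_i, down_i in zip(up_shards, down_shards)]
y_tp = sum(partials)

y_ref = (x @ w_up.T) @ w_down.T
assert np.allclose(y_tp, y_ref)
```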

Is there also something hybrid, where powerful devices take whole layers while less powerful devices split the work via TP?

There is an argument to be made for this (combining pipeline + tensor parallelism). However, as it stands, in practice, the machines capable of doing effective TP have RDMA, and you would gain better performance from simply doing TP on all the machines. While it's outside the scope of this issue, it's a nice thing to explore!

rltakashige avatar Dec 22 '25 16:12 rltakashige

For pipeline parallelism, with a 4 GB model I see a rise of 4 GB in RAM usage, which is strange. Why is it not about 1.3 GB per device across 3 devices?

DevKiDCosmo avatar Dec 22 '25 17:12 DevKiDCosmo

For pipeline parallelism, with a 4 GB model I see a rise of 4 GB in RAM usage, which is strange. Why is it not about 1.3 GB per device across 3 devices?

That is strange - is it loading on all 3 devices or just 1?

Evanev7 avatar Dec 22 '25 18:12 Evanev7

I will create a recording tomorrow. If it happens again I will upload it.

DevKiDCosmo avatar Dec 22 '25 18:12 DevKiDCosmo

Back to the original issue. Say a model block has 82 neurons per layer; then with 3 devices, two should get 27 each and one should get 28. To make this easy we can just use integer division ($\left\lfloor \frac{n}{g} \right\rfloor$) instead of float division ($\frac{n}{g}$) and distribute the remainder.
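A tiny illustration of that split in plain Python (hypothetical helper name):

```python
def uneven_split(dim: int, group_size: int) -> list[int]:
    """Every device gets floor(dim / group_size); the remainder is spread
    one extra unit at a time over the first (dim % group_size) devices."""
    base, rem = divmod(dim, group_size)
    return [base + (1 if i < rem else 0) for i in range(group_size)]

print(uneven_split(82, 3))   # [28, 27, 27] -> sums back to 82
```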

DevKiDCosmo avatar Dec 22 '25 19:12 DevKiDCosmo

I still don't understand the problem. Can you @rltakashige create a recording so that others and I can understand?

DevKiDCosmo avatar Dec 22 '25 20:12 DevKiDCosmo

I made a thinking error. Unbalanced load can lead to performance loss and a bottleneck. Smart padding might be the intelligent option. But instead of adding padding, we could have the less strained devices compute the sum (or other ops) after the plain TP step.

DevKiDCosmo avatar Dec 22 '25 20:12 DevKiDCosmo

Back to the original issue. Say a model block has 82 neurons per layer; then with 3 devices, two should get 27 each and one should get 28. To make this easy we can just use integer division ($\left\lfloor \frac{n}{g} \right\rfloor$) instead of float division ($\frac{n}{g}$) and distribute the remainder.

this is the direction of the changes needed - but I'd suggest you get much more familiar with tensor parallel sharding and the structure of these LLMs if you want to tackle this problem; a correct solution could be quite involved.

Evanev7 avatar Dec 22 '25 20:12 Evanev7

Here are two solution proposals.

The first may be much more complete than the second. The second is also much easier to implement than the first, but the performance boost from the first should be stronger.

1. Performance-Based Sharding

Introduce benchmark-driven, uneven tensor sharding where faster devices receive proportionally larger tensor slices, aiming to equalize per-layer execution time (execution stays synchronous).

How it would work

  • Benchmark devices (GEMM throughput, memory bandwidth, attention latency).
  • Assign each device a performance weight.
  • Compute shard sizes as $\text{shard}_i \approx \text{dim} \times \text{perf}_i / \sum_j \text{perf}_j$. Example: $\text{dim} = 82$, $\text{perf} = [1.0, 1.0, 1.2]$ → shards $[26, 26, 30]$ (see the sketch after this list).
  • Faster devices intentionally get more work so all devices finish compute at roughly the same time.
  • Optional slack so faster devices are not blocked early.
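A minimal sketch of that shard-size computation (plain Python, hypothetical helper name). Note that the exact split depends on the rounding rule: largest-remainder rounding gives [26, 25, 31] rather than the [26, 26, 30] in the example above.

```python
def perf_weighted_shards(dim: int, perf: list[float]) -> list[int]:
    """Split `dim` proportionally to per-device performance weights, with
    largest-remainder rounding so the shard sizes still sum to `dim`."""
    total = sum(perf)
    exact = [dim * p / total for p in perf]
    shards = [int(e) for e in exact]               # floor everything first
    leftover = dim - sum(shards)
    # Give the remaining units to the devices with the largest remainders.
    by_remainder = sorted(range(len(perf)), key=lambda i: exact[i] - shards[i], reverse=True)
    for i in by_remainder[:leftover]:
        shards[i] += 1
    return shards

print(perf_weighted_shards(82, [1.0, 1.0, 1.2]))   # [26, 25, 31] -> sums to 82
```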

Why this seems attractive

  • Enables TP=3 and arbitrary TP sizes
  • Can theoretically handle heterogeneous devices
  • Avoids rejecting models whose dimensions are not divisible

Major challenges

  • TP collectives (all-gather / all-reduce) expect identical tensor shapes
  • Uneven shard sizes break kernel fusion and optimized attention kernels
  • Communication still synchronizes on the slowest participant
  • Backward pass and optimizer sharding become much more complex
  • High implementation and maintenance cost with uncertain performance gains

2. Homogeneous TP Groups + Pipeline Parallelism

Do not force tensor parallelism to handle uneven or heterogeneous splits. Instead:

  • Keep TP only for homogeneous groups where dimensions divide cleanly
  • Handle heterogeneity and odd device counts at the pipeline level

How it would work

  • Benchmark devices.
  • Automatically group similar devices into TP groups (e.g. same GPU type).
  • Use TP only inside those groups (power-of-two sizes).
  • Use pipeline parallelism across groups.
  • Faster groups get more transformer blocks; slower groups get fewer.

Example with 3 devices (a rough planning sketch in code follows):

  • 2 similar GPUs β†’ TP group
  • 1 remaining GPU β†’ separate PP stage
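A rough planning sketch of the grouping idea (plain Python; the function, device names, and hardware labels are hypothetical): cluster devices by hardware class into TP groups, then assign transformer blocks to each group in proportion to its total benchmark score.

```python
from collections import defaultdict

def plan_hybrid(devices: dict[str, float], num_blocks: int) -> dict:
    """Group devices by hardware class into TP groups, then hand each group
    a share of the transformer blocks proportional to its total benchmark
    score (pipeline parallelism across groups). Keys are "name:class"."""
    groups: dict[str, list[tuple[str, float]]] = defaultdict(list)
    for name, score in devices.items():
        groups[name.split(":")[1]].append((name, score))

    scores = {g: sum(s for _, s in members) for g, members in groups.items()}
    total = sum(scores.values())

    plan, assigned = {}, 0
    items = list(scores.items())
    for i, (g, score) in enumerate(items):
        # The last group absorbs the rounding remainder so all blocks are assigned.
        n = num_blocks - assigned if i == len(items) - 1 else round(num_blocks * score / total)
        plan[g] = {"members": [m for m, _ in groups[g]], "blocks": n}
        assigned += n
    return plan

# Two similar GPUs form a TP group; the third becomes its own pipeline stage.
print(plan_hybrid({"mac1:m3ultra": 1.0, "mac2:m3ultra": 1.0, "mac3:m2": 0.5}, 32))
```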

Why this works

  • Preserves TP assumptions (equal shapes, fast collectives)
  • Avoids kernel and attention regressions
  • Works with quantized models
  • Naturally supports heterogeneous hardware
  • Much simpler to implement and maintain

Tradeoffs

  • Pipeline bubbles for small batch sizes
  • Slightly higher activation memory
  • Less “pure” than full TP, but far more practical

Conclusion

This approach aligns with how large systems handle heterogeneity in practice and is likely the safest production path.

DevKiDCosmo avatar Dec 23 '25 06:12 DevKiDCosmo

The first also needs a scheduler to keep track of processes. The second is much more straightforward.

DevKiDCosmo avatar Dec 23 '25 06:12 DevKiDCosmo