
[HARD] Support arbitrary tensor parallel splits

Open rltakashige opened this issue 4 days ago • 14 comments

Context: Tensor parallelism splits a model across multiple devices by distributing its weights among them. The auto_parallel strategies for tensor parallelism use MLX LM's shard_linear and shard_inplace functions to split tensors evenly; these functions assume that the sharded dimensions are exactly divisible by the group size.
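To make the divisibility assumption concrete, here is a hypothetical stand-in for the even-split logic (even_shard_sizes and the sizes are illustrative, not exo's or MLX LM's actual code):

```python
# Hypothetical stand-in for the even-split assumption (not real exo/MLX code).
def even_shard_sizes(dim: int, group_size: int) -> list[int]:
    if dim % group_size != 0:
        raise ValueError(f"{dim} is not divisible by group size {group_size}")
    return [dim // group_size] * group_size

print(even_shard_sizes(11008, 2))  # [5504, 5504] -- an even split exists
print(even_shard_sizes(11008, 3))  # ValueError -- a 3-node split is rejected
```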

Issue: There are several limitations to this approach:

  1. Certain models cannot be supported by tensor parallelism at all, particularly quantized models, where the quantization group size further constrains where shard boundaries can fall (see GLM Air 4.5 config.json with gs 32 quantization, and the arithmetic sketch after this list).
  2. Most models cannot be parallelised across 3 nodes, since their intermediate size is not divisible by 3.
  3. This issue is a blocker for smarter tensor parallelism on heterogeneous devices, which call for unequal shard sizes in the first place.
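
To make the interaction with quantization concrete, here is a small arithmetic sketch; the sizes are illustrative rather than taken from the GLM config:

```python
# Illustrative numbers, not from the GLM Air 4.5 config.
dim, q_group, devices = 11008, 32, 3

assert dim % q_group == 0  # 11008 / 32 = 344 quantization groups: a valid gs=32 weight
print(dim % devices)       # 1 -> no even split across 3 nodes exists

# An uneven, group-aligned split does exist, which is exactly what
# arbitrary splits would unlock: 115 + 115 + 114 groups of 32 elements.
groups = dim // q_group
per_device = [groups // devices + (1 if i < groups % devices else 0)
              for i in range(devices)]
print([g * q_group for g in per_device])  # [3680, 3680, 3648]
```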

A potential starting point is smart padding: pad the affected dimension up to the next compatible multiple, shard evenly, and arrange the padding so it cancels out in the forward pass (see the sketch below).
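A minimal sketch of that idea, using NumPy stand-ins instead of the real MLX sharding functions (pad_to_multiple is a hypothetical helper):

```python
import numpy as np

def pad_to_multiple(w: np.ndarray, axis: int, multiple: int) -> np.ndarray:
    """Zero-pad `w` along `axis` up to the next multiple of `multiple`."""
    extra = (-w.shape[axis]) % multiple
    if extra == 0:
        return w
    pad = [(0, 0)] * w.ndim
    pad[axis] = (0, extra)
    return np.pad(w, pad)

# Column-parallel up-projection: 11008 output features across 3 devices.
w_up = np.random.randn(11008, 4096).astype(np.float32)
w_up = pad_to_multiple(w_up, axis=0, multiple=3)  # (11010, 4096)
shards = np.split(w_up, 3, axis=0)                # three (3670, 4096) shards

# The padded rows are zeros, so (for activations with f(0) = 0, e.g. SiLU
# or GELU) they produce zero hidden activations; zero-padding the matching
# input columns of the row-parallel down-projection then leaves the final
# output unchanged, so no masking is needed in between.
```

For quantized weights the padding would additionally need to be a whole number of quantization groups so the packed scales and biases stay aligned with the shard boundaries.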

rltakashige · Dec 21 '25 22:12