[HARD] Support arbitrary tensor parallel splits
Context:
Tensor parallelism splits a model across multiple devices by distributing its weights among them.
auto_parallel strategies for tensor parallelism use MLX LM's shard_linear and shard_inplace functions to split tensors evenly. These functions assume that the sharded dimension is exactly divisible by the group size.
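For illustration, here is a minimal sketch (plain NumPy, not exo or MLX LM code; `even_shard` is a hypothetical helper) of the even-split assumption and where an indivisible dimension fails:

```python
# Sketch of even sharding: each rank takes size // world_size slices,
# which only works when size % world_size == 0.
import numpy as np

def even_shard(weight: np.ndarray, rank: int, world_size: int, axis: int = 0) -> np.ndarray:
    size = weight.shape[axis]
    if size % world_size != 0:
        # This is the failure mode described above, e.g. an intermediate
        # size that is not divisible by 3 nodes.
        raise ValueError(f"dim {size} is not divisible by group size {world_size}")
    shard = size // world_size
    return np.take(weight, range(rank * shard, (rank + 1) * shard), axis=axis)

# Example: an 11008 x 4096 MLP weight cannot be split into 3 even shards
# along axis 0, since 11008 % 3 != 0.
```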
Issue: this approach has several limitations:
- Some models cannot be supported by tensor parallelism at all, particularly quantized models (see GLM-4.5-Air's config.json with group-size-32 quantization).
- Most models cannot be parallelised across 3 nodes, since their intermediate size is not divisible by 3.
- This issue is a blocker for smarter tensor parallelism on heterogeneous devices.
A potential starting point is smart padding, sketched below.
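A hedged sketch of the idea, assuming we zero-pad the sharded dimension up to the next multiple that both divides evenly across ranks and keeps quantization groups intact; `padded_shard` and its signature are illustrative, not part of exo or MLX LM:

```python
# Smart padding sketch: pad the sharded dimension to the next multiple of
# lcm(world_size, quant_group), split evenly, and record how much padding
# to discard after the sharded computation.
import math
import numpy as np

def padded_shard(weight: np.ndarray, rank: int, world_size: int,
                 axis: int = 0, quant_group: int = 1) -> tuple[np.ndarray, int]:
    size = weight.shape[axis]
    step = math.lcm(world_size, quant_group)   # keep quantization groups whole
    padded = math.ceil(size / step) * step
    pad = padded - size

    pad_width = [(0, 0)] * weight.ndim
    pad_width[axis] = (0, pad)                 # zero-pad the tail of the axis
    padded_w = np.pad(weight, pad_width)

    shard = padded // world_size
    piece = np.take(padded_w, range(rank * shard, (rank + 1) * shard), axis=axis)
    # The padded rows/cols are zeros, so the extra outputs can be sliced off
    # after gathering (or they contribute nothing to reductions).
    return piece, pad

# Example: intermediate size 11008 across 3 nodes with 32-wide quant groups
# pads to 11040 and gives each node a 3680-wide shard, with 32 padded units
# to drop afterwards.
```

Heterogeneous splits (different shard sizes per device) would need more than this, but padding to a common multiple removes the hard divisibility requirement.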