
Distributed layers

angeloskath opened this issue 1 year ago • 1 comment

Adds linear layers that allow training and inference of a model sharded across several devices. The main things added are:

  • float16/bfloat16 reductions for MPI
  • AllToShardedLinear and its quantized sibling
  • ShardedToAllLinear and its quantized sibling

Simply changing the linear layers to the above results in a model that works out of the box for distributed inference and training.
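The two layer types can be sketched as a plain-Python simulation over a few hypothetical devices. The sketch assumes the usual tensor-parallel split that the names suggest. AllToShardedLinear sees the full input and owns a block of output rows. ShardedToAllLinear owns a block of input columns, and its partial results need an all-sum to become the full output. The helpers (`matvec`, `all_sum`, `all_to_sharded`, `sharded_to_all_partial`) are illustrative only and are not the MLX API.

```python
# Simulation of the two sharded linear layers on `size` "devices".
# Assumption: AllToSharded splits the weight's output rows, and ShardedToAll
# splits its input columns and sums the partial results (column/row tensor
# parallelism). This is a sketch, not the MLX implementation.

def matvec(W, x):
    """y = W @ x for a list-of-rows matrix W and a vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def all_sum(partials):
    """Stand-in for an all-reduce sum: every device ends up with the total."""
    total = [sum(vals) for vals in zip(*partials)]
    return [list(total) for _ in partials]

def all_to_sharded(W, x, rank, size):
    """The device holds a block of output rows and sees the full input."""
    rows = len(W) // size
    W_local = W[rank * rows:(rank + 1) * rows]
    return matvec(W_local, x)  # this device's slice of the output

def sharded_to_all_partial(W, x_local, rank, size):
    """The device holds a block of input columns and its slice of x."""
    cols = len(W[0]) // size
    W_local = [row[rank * cols:(rank + 1) * cols] for row in W]
    return matvec(W_local, x_local)  # partial result, still needs all_sum

W = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
x = [1, 0, -1, 2]
size = 2
cols = len(W[0]) // size
full = matvec(W, x)

# AllToSharded: concatenating the per-device outputs recovers W @ x.
shards = [all_to_sharded(W, x, r, size) for r in range(size)]
assert sum(shards, []) == full

# ShardedToAll: each device only sees its slice of x; all_sum completes it.
x_slices = [x[r * cols:(r + 1) * cols] for r in range(size)]
partials = [sharded_to_all_partial(W, x_slices[r], r, size) for r in range(size)]
assert all(out == full for out in all_sum(partials))
```

In this sketch, AllToSharded needs no communication but leaves the output split across devices. ShardedToAll consumes a split input and pays for one reduction.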

I am starting it as a draft so that we can iterate a bit on the design. The downside of this design is that we have yet another linear layer to think about when implementing LoRA and friends, or weird new quantizations, for instance. Perhaps it would be better to build the above layers around an internal linear layer, so that model surgery that swaps linear layers would still work out of the box.

angeloskath · Jul 15 '24 21:07

I kind of like this design. I like that it's all quite simple and easy to follow, and that we have a lot of control over how to shard the model (as in https://github.com/ml-explore/mlx-examples/pull/890). We could possibly find a way to reduce the code needed to add a new custom linear-like layer, but the simplicity is nice and I wouldn't want to give that up.

awni · Jul 17 '24 15:07

I am marking this ready for review. The main things that are new since I started the branch:

Exposing mx.contiguous. This ensures both that the array is contiguous and that it occupies at most x.size() * x.itemsize() + 16384 bytes. In particular, a contiguous slice of a larger array is still going to be copied.
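A small worked example shows why the size bound forces a copy of a contiguous slice. The helper `needs_copy` is hypothetical and only models the rule stated above. It assumes that a slice keeps a reference to its parent's whole buffer.

```python
# Hypothetical model of the copy rule described for mx.contiguous.
# Assumption: an array is returned without copying only if it is already
# contiguous AND its backing buffer is at most size * itemsize + 16384 bytes.
SLACK_BYTES = 16384

def needs_copy(is_contiguous, buffer_nbytes, size, itemsize):
    """Return True if the rule above would require materializing a copy."""
    if not is_contiguous:
        return True
    return buffer_nbytes > size * itemsize + SLACK_BYTES

# A 4096 x 4096 float16 weight owns a 32 MiB buffer.
full_nbytes = 4096 * 4096 * 2

# The full array is contiguous and tight, so there is no copy.
print(needs_copy(True, full_nbytes, 4096 * 4096, 2))  # False

# A contiguous slice of the first 512 rows still points at the whole 32 MiB
# buffer, far more than 512 * 4096 * 2 + 16384 bytes, so it gets copied.
print(needs_copy(True, full_nbytes, 512 * 4096, 2))  # True
```

This matters for sharding. Without the bound, a rank that keeps only its slice of a weight could silently hold the entire unsharded buffer alive.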

shard_linear convenience function and shard_inplace. The first one simply creates the appropriate sharded linear layer, quantized or not. The second actually shards the parameters in place, which lets us shard any layer and apply the collective operations as we see fit. For instance, it is used to shard the single-stream transformer blocks in FLUX while performing only one communication (https://github.com/ml-explore/mlx-examples/pull/1325).
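Sharding parameters in place makes a single communication possible because consecutive layers can be split so that the intermediate activations never need to be gathered. The toy simulation below illustrates this with a two-layer MLP. It is not the FLUX block. The helpers `shard_rows` and `shard_cols` are hypothetical stand-ins for slicing weights in place.

```python
# Toy simulation: shard two consecutive linear layers so that only ONE
# communication (a sum at the very end) is needed. Hypothetical helpers that
# mimic the idea behind shard_inplace, not its implementation.

def matvec(W, x):
    return [sum(a * b for a, b in zip(row, x)) for row in W]

def relu(v):
    return [max(0, t) for t in v]

def shard_rows(W, rank, size):
    """Keep this rank's block of output rows."""
    n = len(W) // size
    return W[rank * n:(rank + 1) * n]

def shard_cols(W, rank, size):
    """Keep this rank's block of input columns."""
    n = len(W[0]) // size
    return [row[rank * n:(rank + 1) * n] for row in W]

W1 = [[1, -1], [2, 0], [0, 3], [-1, 1]]  # 4 x 2 "up" projection
W2 = [[1, 2, 3, 4], [0, 1, 0, -1]]       # 2 x 4 "down" projection
x = [2, 1]
size = 2

reference = matvec(W2, relu(matvec(W1, x)))

partials = []
for rank in range(size):
    # Purely local: the elementwise relu needs no communication because each
    # device owns complete hidden units.
    h_local = relu(matvec(shard_rows(W1, rank, size), x))
    partials.append(matvec(shard_cols(W2, rank, size), h_local))

# The single collective: sum the partial outputs across devices.
y = [sum(vals) for vals in zip(*partials)]
assert y == reference
```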

The sharding functions now also take a groups argument. It assumes the linear layer is a fused one and splits it according to the groups argument (evenly or by percentage). I think the argument name may need improving here.
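One plausible reading of "evenly or by percentage" is sketched below. An int splits the fused output dimension into that many equal segments. A list of fractions gives cumulative split points, which suits cases such as grouped-query attention where Q is larger than K and V. The helper `split_segments` and the cumulative-fraction convention are assumptions, not the signature used in the PR.

```python
# Sketch of computing segment boundaries for a fused weight's output rows.
# Assumption: an int means equal segments; a list of floats gives cumulative
# split points, e.g. [0.5, 0.75] splits at 50% and 75% of the rows.

def split_segments(num_rows, segments):
    """Return (start, stop) row ranges for each segment."""
    if isinstance(segments, int):
        if num_rows % segments != 0:
            raise ValueError("rows must divide evenly into segments")
        step = num_rows // segments
        bounds = [i * step for i in range(1, segments)]
    else:
        bounds = [int(f * num_rows) for f in segments]
    edges = [0] + bounds + [num_rows]
    return list(zip(edges[:-1], edges[1:]))

print(split_segments(12, 3))            # [(0, 4), (4, 8), (8, 12)]
print(split_segments(16, [0.5, 0.75]))  # [(0, 8), (8, 12), (12, 16)]
```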

angeloskath · Mar 06 '25 23:03

The sharding functions now also take a groups argument. This assumes the linear layer is a fused one and splits it according to the groups argument (evenly or percentage wise)

The purpose there is to allow uneven shardings? I think it would be good to find a name that is more clearly distinct from group.

awni · Mar 11 '25 19:03

The purpose there is to allow uneven shardings?

Totally agree that we should name it something different. It isn't for uneven shardings in the sense that one node takes 70% of the computation; that isn't supported by this API. It is for weights that consist of several concatenated weights. In this case, for the sharded linear to be valid, we need to split, shard and concatenate. Otherwise one node will get all the queries and no keys, and so on.
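The failure mode is easy to see with labelled rows standing in for a fused QKV weight. The sketch compares sharding the fused rows directly with the split-shard-concatenate approach described above. The `shard` helper is hypothetical.

```python
# Why a fused weight must be split, sharded, then re-concatenated.
# Rows are labels instead of numbers so the assignment is visible.
# Hypothetical illustration, not the MLX code.

def shard(rows, rank, size):
    """Keep this rank's contiguous block of rows."""
    n = len(rows) // size
    return rows[rank * n:(rank + 1) * n]

# A fused QKV projection with 4 output rows per projection (12 rows total).
fused = [f"{name}{i}" for name in "qkv" for i in range(4)]
size = 2

# Naive: shard the fused rows directly.
naive = [shard(fused, r, size) for r in range(size)]
# naive[0] == ['q0', 'q1', 'q2', 'q3', 'k0', 'k1'], so rank 0 gets no values

# Segment-aware: split into q/k/v, shard each segment, then concatenate.
segments = [fused[0:4], fused[4:8], fused[8:12]]
aware = [sum((shard(seg, r, size) for seg in segments), []) for r in range(size)]
# aware[0] == ['q0', 'q1', 'k0', 'k1', 'v0', 'v1']
```

With the segment-aware split, each rank holds matching slices of Q, K and V, so it can compute its own attention heads locally.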

angeloskath · Mar 11 '25 20:03

Otherwise one node will get all the queries and no keys and so on.

Ah that makes sense now. Some suggestions on alternative names:

  • shards
  • segments
  • sections
  • splits

Maybe it makes sense to prefix one of those with sub, like sub_shards?

awni · Mar 11 '25 22:03

This should be ready for review again. I added a bunch of tests and changed the API to support a sharding function so we can shard layers in one go.

I don't think a solution like the nn.quantize one is feasible, because the functionality changes based on the sharding type. A single linear layer maps to (at least) two sharded linear layers. As a result, we can't have a to_sharded function, because the right choice depends on the layer holding the linear, and so on.
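To illustrate why the choice depends on context, the sketch below uses a hypothetical sharding function that picks the sharded type from a layer's path in the model. Two layers of the same class map to different sharded layers. The function `choose_sharding` and the Llama-style path names are assumptions for illustration, not the API added in this PR.

```python
# Hypothetical sharding function: the sharded type of a linear layer depends
# on where it sits in the block, not on the layer itself. Path names follow a
# common Llama-style convention and are assumptions.

def choose_sharding(path):
    """Return 'all-to-sharded', 'sharded-to-all', or None for a layer path."""
    leaf = path.split(".")[-1]
    if leaf in {"q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"}:
        return "all-to-sharded"  # input replicated, output split
    if leaf in {"o_proj", "down_proj"}:
        return "sharded-to-all"  # input split, output summed across devices
    return None                  # leave the layer replicated

paths = [
    "layers.0.self_attn.q_proj",
    "layers.0.self_attn.o_proj",
    "layers.0.mlp.up_proj",
    "layers.0.mlp.down_proj",
    "lm_head",
]
plan = {p: choose_sharding(p) for p in paths}
for p, kind in plan.items():
    print(f"{p:28s} {kind}")
```

A per-layer `to_sharded` method only sees the linear layer itself. It would have no way to tell `q_proj` apart from `o_proj`, and that distinction is exactly what this sketch relies on.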

angeloskath · Mar 20 '25 23:03

I am merging as is. When we use it a bit more and it settles, I will add the layers to the docs (maybe with a dedicated page of their own... not sure) and also expose shard_linear and shard_inplace in mlx.nn instead of mlx.nn.layers.distributed.

angeloskath · Mar 21 '25 20:03