
SPMD multi core implementation

Open mfatih7 opened this issue 1 year ago • 2 comments

Hello

Without SPMD, we can train 8 replicas of a model on the 8 TPU cores of the same v3 device and optimize the model weights concurrently.
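
For reference, the non-SPMD setup I mean is the usual one-process-per-core data-parallel pattern, roughly like this (the toy model, loss, and optimizer below are just placeholders):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    # One full copy of the model per TPU core (8 replicas on a v3-8).
    model = torch.nn.Linear(128, 128).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    inputs = torch.randn(32, 128).to(device)   # each replica trains on its own batch
    targets = torch.randn(32, 128).to(device)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    # All-reduces the gradients across the 8 replicas before updating the weights.
    xm.optimizer_step(optimizer)
    xm.mark_step()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())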

With SPMD, is it possible to place 4 copies of a bigger model on 8 cores on the same v3 device and train them concurrently?

If possible, is there an example?

mfatih7 avatar Jan 25 '24 18:01 mfatih7

Hey @mfatih7, it sounds like your use case is to group the devices in pairs and train data parallel across those groups, correct?

This is achievable with SPMD. As an example, for FSDP-style sharding within each group, you can define a (4, 2) mesh with axes ('replica', 'fsdp'), shard your model parameters along the fsdp axis, and shard your inputs across all devices, e.g.

import numpy as np
import torch_xla.distributed.spmd as xs

# 8 devices, 4 replicas, 2-way FSDP within each replica.
mesh = xs.Mesh(np.array(range(8)), (4, 2), ('replica', 'fsdp'))

# Shard the model parameters FSDP-style along the `fsdp` axis.
# Parameters are replicated along every unspecified mesh axis (the `replica` axis here).
xs.mark_sharding(model.weight, mesh, ('fsdp', None))

# Shard the inputs' batch dimension along all mesh axes.
xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None))
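
To show where these calls fit in a full step, here's a minimal end-to-end sketch, assuming SPMD is enabled with xr.use_spmd(); the linear model, batch size, and optimizer are just placeholders, not part of your actual setup:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Enable the SPMD execution mode before building any XLA computation.
xr.use_spmd()
device = xm.xla_device()

# Illustrative toy model; replace with your own module.
model = torch.nn.Linear(128, 128).to(device)

# 8 devices, 4 replicas, 2-way FSDP within each replica.
mesh = xs.Mesh(np.array(range(8)), (4, 2), ('replica', 'fsdp'))

# Shard the linear layer's parameters along the `fsdp` axis.
xs.mark_sharding(model.weight, mesh, ('fsdp', None))
xs.mark_sharding(model.bias, mesh, ('fsdp',))

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Global batch; its batch dimension is sharded 8-way across both mesh axes.
inputs = torch.randn(32, 128).to(device)
targets = torch.randn(32, 128).to(device)
xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None))

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
xm.mark_step()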

jonb377 avatar Feb 07 '24 23:02 jonb377

It would also be worth checking out the FSDPv2 wrapper if you just want to train a bigger model using all devices: https://github.com/pytorch/xla/issues/6379
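
For a flavor of what that looks like, here's a rough FSDPv2 sketch paraphrased from the RFC; the toy module, optimizer, and exact constructor arguments are assumptions and may differ between torch_xla releases:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

xr.use_spmd()

# FSDPv2 expects a mesh with an axis named `fsdp`; here all devices shard the weights.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('fsdp', 'model'))

# Illustrative toy module; wrap your own nn.Module instead.
model = FSDPv2(torch.nn.Linear(128, 128).to(xm.xla_device()), mesh=mesh)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Shard the input batch along the `fsdp` axis as well.
x = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(x, mesh, ('fsdp', None))

loss = model(x).sum()
loss.backward()
optimizer.step()
xm.mark_step()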

jonb377 avatar Feb 07 '24 23:02 jonb377

To truly run 4 copies of the model (2-way model parallel within each copy), you would need to shard your input along the replica axis only (4-way data parallel); otherwise the model (or the data) will be all-gathered during the computation. Note, though, that for your question ("With SPMD, is it possible to place 4 copies of a bigger model on 8 cores on the same v3 device and train them concurrently?") you will need an extra all-reduce within each group (pair of devices) -- SPMD inserts it for you. cc @jonb377
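
Concretely, the replica-only input sharding would look something like this, reusing the mesh and inputs from the snippet above:

# Pure 4-way data parallelism with 2-way model (FSDP) parallelism inside each pair:
# shard the batch dimension along the `replica` axis only, so both devices in a pair
# work on the same shard of the batch while splitting the parameters between them.
xs.mark_sharding(inputs, mesh, ('replica', None))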

I am closing this issue, @mfatih7

yeounoh avatar Mar 16 '24 00:03 yeounoh