SPMD multi-core implementation
Hello
Without SPMD, we can train 8 duplicates of a model on 8 TPU cores on the same v3 device and optimize model weights concurrently.
With SPMD, is it possible to place 4 copies of a bigger model on 8 cores on the same v3 device and train them concurrently?
If possible, is there any example?
Hey @mfatih7, it sounds like your use case is to group the devices in pairs and train data parallel across those groups, correct?
This is achievable with SPMD. As an example for FSDP sharding within the groups, you can define your mesh like:
mesh = xs.Mesh(range(8), (4, 2), ('replica', 'fsdp'))
and shard your model parameters across the `fsdp` axis and your inputs across all devices, e.g.
# Assumes the SPMD sharding API that lives under torch_xla.distributed.spmd.
import torch_xla.distributed.spmd as xs

# 8 devices, 4 replicas, 2-way FSDP within each replica.
mesh = xs.Mesh(range(8), (4, 2), ('replica', 'fsdp'))
# Shard the model parameters FSDP-style across the `fsdp` axis.
# The parameters will be replicated along all unspecified mesh axes (i.e. the `replica` axis in this case).
xs.mark_sharding(model.weight, mesh, ('fsdp', None))
# Shard the inputs' batch dimension along all mesh axes.
xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None))
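In a full training loop you can also attach this input sharding to the device loader instead of calling xs.mark_sharding on every batch by hand. A minimal sketch, assuming torch_xla's MpDeviceLoader and xs.ShardingSpec, and that train_loader is a placeholder for your existing DataLoader yielding 2-D input batches:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs

device = xm.xla_device()
# Shard each incoming batch along its batch dimension over all mesh axes,
# mirroring xs.mark_sharding(inputs, mesh, (('replica', 'fsdp'), None)) above.
train_device_loader = pl.MpDeviceLoader(
    train_loader,  # assumed: your torch.utils.data.DataLoader of 2-D batches
    device,
    input_sharding=xs.ShardingSpec(mesh, (('replica', 'fsdp'), None)))
for inputs in train_device_loader:
    ...  # forward/backward as usual; SPMD inserts the needed collectives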
It would also be worth checking out the FSDPv2 wrapper if you just want to train a bigger model using all devices: https://github.com/pytorch/xla/issues/6379
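For reference, a rough sketch of that FSDPv2 path (hedged: it assumes the SpmdFullyShardedDataParallel wrapper available in recent torch_xla releases, and MyModel is a placeholder for your own module):
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

xr.use_spmd()  # enable the SPMD execution mode

# One model sharded across all cores: a 1-D mesh with a single `fsdp` axis.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(range(num_devices), (num_devices,), ('fsdp',))
xs.set_global_mesh(mesh)

# MyModel is a placeholder; FSDPv2 shards its parameters along the `fsdp` axis.
model = FSDPv2(MyModel().to(xm.xla_device()))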
To fully run 4 copies of the model (2-way model parallel), you would need to shard your input along the `replica` axis (4-way data parallel); otherwise, the model (or the data) will be all-gathered during the computation. Note, though, regarding the original question ("With SPMD, is it possible to place 4 copies of a bigger model on 8 cores on the same v3 device and train them concurrently?"): you will need an extra all-reduce within each group (pair of devices) -- SPMD does that for you.
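Concretely, the difference from the earlier example is the partition spec of the inputs: sharding them only along the `replica` axis gives each pair of devices its own batch shard, while the parameters stay sharded within the pair. A hedged sketch reusing the mesh and model from above:
# 4-way data parallel: each pair of devices gets its own slice of the batch.
xs.mark_sharding(inputs, mesh, ('replica', None))
# Parameters stay 2-way sharded within each pair (replicated across `replica`);
# SPMD inserts the extra all-reduce within the pair for you.
xs.mark_sharding(model.weight, mesh, ('fsdp', None))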
cc @jonb377
I am closing this issue, @mfatih7.