Does torchtune support multi-node training?
Does torchtune support multi-node training? For example, in a SLURM environment?
If so, would it be possible to get an example config?
Hi @tginart, if you just want to get a very, very basic multi-node setup running, it actually shouldn't be too hard. We wrap around torchrun which supports multi-node.
Just curious - what's your use case here? Large models, faster training, more data? Do you already have access to a multi-node setup running w/ SLURM? Or are you considering one?
Hi @joecummings
Current use-case is just faster fine-tuning over larger datasets.
I do have access to multi-node SLURM already and have trained using other frameworks. For various reasons I've used torchtune recently for some small models on a single node and was just wondering if it supports multi-node.
What file should I take a look at?
Ah in that case, I'd recommend two things:
- Instead of utilizing `tune run`, which we artificially constrain to 1 node for now, `tune cp` the recipe and config you want and then launch directly with `torchrun`, e.g. `torchrun --nnodes 2 --nproc-per-node 8 full_finetune_distributed.py --config llama3_2/3B_full.yaml`. This alone should enable you to run fine-tuning over larger datasets w/ FSDP. (Just make sure you modify the sharding strategy to be HYBRID. You can do that by modifying the `fsdp_kwargs` here to include an item for `"sharding_strategy": ShardingStrategy.HYBRID_SHARD`.)
- Add a very basic tensor parallel configuration. Right now, we just use FSDP for distributed training, which will likely be very slow on multi-node b/c it will all-gather everything needed for backprop. Tensor parallel should actually achieve the speedup you need. TP is not quite as basic as step number 1. For simplicity's sake, I'd recommend modifying our `shard_model` code to do something like the following pseudocode:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

def shard_model(...):
    ...
    # 2D mesh over 2 nodes x 8 GPUs: a data-parallel dim and a tensor-parallel dim
    mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

    # Parallelize input and output first
    parallelize_module(
        model,
        mesh_2d["tp"],
        {
            "token_embeddings": PARALLEL_STRATEGY,
            "output": PARALLEL_STRATEGY,
        },
    )

    # Iterate over layers and parallelize each one
    for layer in model.layers:
        layer_plan = {
            "attention.wq": PARALLEL_STRATEGY,
            "attention.wk": PARALLEL_STRATEGY,
            ...
        }
        parallelize_module(layer, mesh_2d["tp"], layer_plan)

    # Shard the model normally (FSDP) but using the mesh_2d["dp"] object
Torchtitan has a great example of this FSDP + TP work here.
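For concreteness, here is a minimal sketch (not torchtune's actual API) of what PARALLEL_STRATEGY could be, using the built-in ColwiseParallel/RowwiseParallel styles from torch.distributed.tensor.parallel. The submodule names follow the Llama-style naming in the pseudocode above and are assumptions; match them to your model's actual module names, and expect to tune the input/output layouts (torchtitan's TP plan is the reference):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 2 nodes x 8 GPUs: "dp" for FSDP, "tp" for tensor parallel within each node
mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]

# Embedding sharded row-wise (over the vocab dim), output projection column-wise.
# Depending on how the loss is computed, the output may need output_layouts=Replicate().
parallelize_module(
    model,  # the TransformerDecoder passed into shard_model
    tp_mesh,
    {
        "token_embeddings": RowwiseParallel(),
        "output": ColwiseParallel(),
    },
)

# Per layer: q/k/v and the first feed-forward projections column-wise,
# the attention output and last feed-forward projections row-wise.
for layer in model.layers:
    parallelize_module(
        layer,
        tp_mesh,
        {
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w3": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(),
        },
    )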
I am using LSF to launch `torchrun --nnodes 2 --nproc-per-node 8 full_finetune_distributed.py`; the full error trace is attached:
full_trace.txt
8B_full_distributed.txt
Can you take a look at how TorchTitan launches multi-node training? The key piece I think was missing in the suggestions above is that you need to specify a rendezvous backend. You can read more about that here.
Is this something that you guys are thinking of fixing/implementing?
I think this is what you suggested as a first try, but it still fails (full_trace.txt attached):
torchrun \
--nproc_per_node \$GPU_PER_HOST \
--nnodes \$NUM_HOSTS \
--rdzv-backend c10d \
--rdzv_endpoint \$MASTER_ADDR:\$MASTER_PORT \
src/full_finetune_distributed.py --config \
config_files/8B_full_distributed.yaml \
optimizer_in_bwd=False
training.shard_model(
    model=model,
    shard_conditions=fsdp_shard_conditions,
    cpu_offload=fsdp_cpu_offload,
    reshard_after_forward=reshard_after_forward,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
modified shard_model
def shard_model(
    model: TransformerDecoder,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
) -> None:
    fsdp_kwargs = {
        "reshard_after_forward": reshard_after_forward,
        "sharding_strategy": sharding_strategy,
    }
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

    num_layers_sharded = 0
    for n, m in reversed(list(model.named_modules())):
        if any([shard_condition(n, m) for shard_condition in shard_conditions]):
            fully_shard(m, **fsdp_kwargs)
            num_layers_sharded += 1

    if num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are working as expected."
        )

    fully_shard(model, **fsdp_kwargs)
In the end I managed to get it to run with mpirun, but I get the following error:
[rank3]: File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-31tjdLhK-py3.11/lib/python3.11/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank3]: updated = func(inp_module, *args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: TypeError: fully_shard() got an unexpected keyword argument 'sharding_strategy'
which makes sense, given the signature:
def fully_shard(
    module: Union[nn.Module, List[nn.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
):
@fabiogeraci Ahh! I see in FSDP2, they remap HYBRID_SHARD to reshard_after_forward=True. Give that a go :)
Would you mind explaining, please? I can see that, but how do I know which FSDP is used?
Yep - there's a great guide here.
HYBRID_SHARD is the same as reshard_after_forward=True which "determines whether parameters are resharded (freed) after forward. If True, then they are re-all-gathered in backward. This trades off saving memory at the cost of extra communication."
All of torchtune and torchtitan use FSDP2.
How would I switch from a 1D mesh to a 2D mesh?
mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
This will create a 2D mesh with 2 nodes and 8 GPUs per node.
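To make that concrete, here is a minimal sketch of how the two sub-meshes might then be wired up, assuming a tp_plan dict like the pseudocode earlier in the thread and FSDP2's fully_shard (model, tp_plan, and the module layout are assumptions, not torchtune's actual code):

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))

# Tensor parallel within each node, over the "tp" sub-mesh
parallelize_module(model, mesh_2d["tp"], tp_plan)  # tp_plan: see the sketch above

# FSDP2 (fully_shard) across the "dp" sub-mesh
for layer in model.layers:
    fully_shard(layer, mesh=mesh_2d["dp"])
fully_shard(model, mesh=mesh_2d["dp"])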
Hi Joe @joecummings, I was checking torchtitan for multi-node practice as you suggested. I found they seem to use a pure FSDP approach with a 64-GPU setting: here.
So I'm trying to understand the bottleneck of FSDP in multi-node tuning. Does the all-gather happen at every layer across all workers, so it may be bounded by communication? Is it a mistake that torchtitan uses pure FSDP with 64 GPUs? And I also wonder whether TP will help in this case if I am training small models like Llama 7B.
Thanks!
> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
> This will create a 2D mesh with 2 nodes and 8 GPUs per node.
I implemented your suggestion and got a massive improvement in speed on a multi-node, multi-GPU setup via OpenMPI. I had to tweak shard_model.
May I ask why torchtune does not support multi-node, multi-GPU out of the box?
> Hi Joe @joecummings, I was checking `torchtitan` for multi-node practice as you suggested. I found they seem to use a pure FSDP approach with a 64-GPU setting: here. So I'm trying to understand the bottleneck of FSDP in multi-node tuning. Does the all-gather happen at every layer across all workers, so it may be bounded by communication? Is it a mistake that `torchtitan` uses pure FSDP with 64 GPUs? And I also wonder whether TP will help in this case if I am training small models like Llama 7B. Thanks!
Great questions! I'd really recommend reading through this issue on PyTorch where @awgu discusses why multi-node FSDP is usually slower than single node. The TL;DR is that communication can take longer between different nodes and since FSDP needs all-gather for parameters and reduce-scatter for gradient reduction, this can be a bottleneck in training.
It's probably not a mistake that torchtitan uses only FSDP for that specific config b/c it's the smallest model and my guess is that they're just trying to show that it's possible. If you look at the torchtitan configs for their larger models, you'll see that they use TP.
If you have good interconnect speed between nodes, FSDP will work faster. Regardless, TP + a "hybrid shard" will likely be faster than FSDP in a multi-node setup b/c it's not as communication bound.
Hope this explanation helps!
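As a rough, illustrative back-of-envelope (all bandwidth and model-size numbers below are assumptions, not measurements), you can see why the inter-node link tends to dominate:

# Per training step, FSDP all-gathers parameters for forward, again for
# backward (when reshard_after_forward=True), and reduce-scatters gradients.
params = 8e9                 # e.g. an ~8B-parameter model (assumed)
param_bytes = params * 2     # bf16
passes = 3                   # ~2 all-gathers + 1 reduce-scatter

intra_node_bw = 300e9        # ~NVLink-class bandwidth in B/s (assumed)
inter_node_bw = 12.5e9       # ~100 Gb/s Ethernet/IB in B/s (assumed)

print(param_bytes * passes / intra_node_bw)  # ~0.16 s of comm per step
print(param_bytes * passes / inter_node_bw)  # ~3.8 s of comm per step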
> I implemented your suggestion and got a massive improvement in speed on a multi-node, multi-GPU setup via OpenMPI. I had to tweak shard_model.
> May I ask why torchtune does not support multi-node, multi-GPU out of the box?
That's awesome! Would love to take a peek at your code if you want to post a gist so other users can see how you did it.
We actually do plan to support multi-node OOTB sometime soon. The biggest reason we haven't so far is just due to our own bandwidth constraints. We're a small-ish team and there are lots of new models and techniques coming out every day! We wanted to be sure that we provided a great single-node experience before tackling multi-node, but like I mentioned, we'll probably have a canonical example in torchtune soon.
OpenMPI launch script (CLI):
mpirun \
-np $TOTAL_NUM_GPUS \
-H \$MPI_HOST_STRING \
-x PATH \
-bind-to none \
-map-by slot \
--mca pml ob1 --mca btl ^openib \
--display-allocation \
--display-map \
python3 src/full_finetune_distributed.py \
--config config_files/8B_full_distributed.yaml \
optimizer_in_bwd=False
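Note: since this launches with mpirun rather than torchrun, the usual torch.distributed environment variables (RANK, WORLD_SIZE, LOCAL_RANK) are not set automatically. A sketch of the kind of mapping typically needed before the process group is initialized (the exact handling in this setup is an assumption; MASTER_ADDR/MASTER_PORT are expected to come from the launch script):

import os

# Map OpenMPI's environment variables to the ones torch.distributed expects
os.environ.setdefault("RANK", os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
os.environ.setdefault("WORLD_SIZE", os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
os.environ.setdefault("LOCAL_RANK", os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))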
full_finetune_distributed.py
if int(os.environ.get("NUM_NODES")) > 1:
    from torch.distributed._tensor import init_device_mesh

    # "dp" spans the nodes, "tp" the GPUs within a node
    # (WORLD_SIZE // 2 equals GPUs per node only for 2 nodes)
    mesh_2d = init_device_mesh(
        "cuda",
        mesh_shape=(int(os.environ.get("NUM_NODES")),
                    int(os.environ["WORLD_SIZE"]) // 2),
        mesh_dim_names=("dp", "tp"),
    )
else:
    mesh_2d = None

training.shard_model(
    model=model,
    shard_conditions=fsdp_shard_conditions,
    cpu_offload=fsdp_cpu_offload,
    reshard_after_forward=reshard_after_forward,
    mesh=mesh_2d,
)
_distributed.py
def shard_model(
    model: TransformerDecoder,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mesh: Optional[DeviceMesh] = None,  # <-- Add this line
) -> None:
    ...
    if mesh is not None:  # <-- Add this line
        fsdp_kwargs["mesh"] = mesh  # <-- Add this line
Would I be able to make a PR with this code? ;)
We added multi-node support in #2301.