`freqs` in the updated Rotary encoder is of type `torch.complex64` which is not supported in DistributedDataParallel
Describe the bug:
The updated `RotaryEncoder` implementation introduced in this PR changes the dtype of `freqs` to `torch.complex64`. However, complex tensors are not supported by PyTorch's `DistributedDataParallel`: specifically, it errors out during `_sync_module_states` in DDP's `__init__`.
Describe how to reproduce:
Here's a simple script to reproduce the error. Make sure that you're in a GPU-enabled environment.
```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairseq2.models.llama import load_llama_model

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.init_process_group(
    backend="nccl",
    init_method="env://",
    rank=rank,
    world_size=world_size,
)

torch.cuda.set_device(rank)

model = load_llama_model("llama_7b", device=torch.device("cuda"))

# Wrapping the model in DDP triggers the error below.
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
```
Running it with `torchrun --nproc_per_node 1 test.py` will produce the following error:
```
Traceback (most recent call last):
  File "/mnt/fsx-home/chaoweihuang/fairseq2-public-fork/test.py", line 19, in <module>
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
  File "/fsx-ust/chaoweihuang/miniconda3/envs/seamless-fairseq2/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 676, in __init__
    _sync_module_states(
  File "/fsx-ust/chaoweihuang/miniconda3/envs/seamless-fairseq2/lib/python3.10/site-packages/torch/distributed/utils.py", line 142, in _sync_module_states
    _sync_params_and_buffers(
  File "/fsx-ust/chaoweihuang/miniconda3/envs/seamless-fairseq2/lib/python3.10/site-packages/torch/distributed/utils.py", line 160, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Input tensor data type is not supported for NCCL process group: ComplexFloat
```
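Inspecting the buffer directly confirms the dtype that NCCL rejects. This is a quick check against the `model` loaded in the script above, using the module path from the workaround below (assuming `decoder.layers` is index-able as usual):

```python
# The rotary encoder's non-persistent `freqs` buffer is complex-valued,
# which is what DDP's startup broadcast over NCCL chokes on.
print(model.decoder.layers[0].self_attn.pos_encoder.freqs.dtype)  # torch.complex64
```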
A workaround for this is to set the `_ddp_params_and_buffers_to_ignore` attribute on the root module that is going to be wrapped with DDP, since we don't really need to sync `freqs` across ranks:
```python
model._ddp_params_and_buffers_to_ignore = [
    f"decoder.layers.{i}.self_attn.pos_encoder.freqs"
    for i in range(len(model.decoder.layers))
]
```
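A more general variant is to skip every complex-valued buffer instead of hardcoding the module path. A minimal sketch using only standard `nn.Module` APIs, assuming none of the complex buffers actually need to be synced:

```python
# Collect the names of all complex-valued buffers so that DDP's startup
# state sync skips them.
model._ddp_params_and_buffers_to_ignore = [
    name for name, buf in model.named_buffers() if buf.is_complex()
]
```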
Thanks for the report @chaoweihuang. We have a similar issue in the Gumbel-Softmax quantizer as well. As you already found out, the right solution would be to let DDP ignore those non-persistent buffers. Let me think of a clean way to do that. The last time I checked, the DDP API for that was still experimental.
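For reference, the experimental API alluded to is presumably the private static method sketched below (naming as in recent PyTorch releases); it does the same thing as setting the attribute by hand:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

# Private, still-experimental hook: it records the given fully-qualified names
# on the module so that DDP skips them during its startup state sync.
# `freq_buffer_names` is the same list built in the workaround above.
DDP._set_params_and_buffers_to_ignore_for_model(model, freq_buffer_names)

ddp_model = DDP(model, device_ids=[rank])
```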