DistributedDataParallel error - uninitialized parameters
I'm using Flux quickstart settings with fp8 quantization on 4x3090s. The same settings work on 1x3090.
export TRAINING_NUM_PROCESSES=2
export ACCELERATE_EXTRA_ARGS="--multi_gpu"
The error occurs on this line: results = accelerator.prepare(primary_model)
RuntimeError: Modules with uninitialized parameters can't be used with DistributedDataParallel. Run a dummy forward pass to correctly initialize the modules
I have tried passing the DDP argument find_unused_parameters=True, and printing any modules whose parameters have requires_grad=True and grad=None, but there aren't any.
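For reference, this is roughly how I'm forwarding that flag through accelerate (via kwargs_handlers):

import accelerate
from accelerate import DistributedDataParallelKwargs

# forward find_unused_parameters=True to the underlying DDP wrapper
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs])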
Oh... well, actually, I haven't tried multi-GPU quantised training yet. I assumed it would just work, since we're not really messing with much other than dtypes. cc @sayakpaul
I am guessing you can't test without quantisation to see?
"Run a dummy forward pass to correctly initialize the modules"
Did this help, or isn't it possible at all?
There is a multi-GPU training example with FP8, but it uses ao:
https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_multi_gpu.py
Everything torch does has a worse interface than everything Hugging Face does - ao looks like it will work, but jesus lord, why is it so ugly lol
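For what it's worth, the general shape of the ao approach combined with DDP is roughly the following. This is a sketch, not lifted from that benchmark script, and convert_to_float8_training is the name in recent torchao releases (older ones used a different API), so treat the exact call as an assumption:

import torch
import torch.nn as nn
import accelerate
from torchao.float8 import convert_to_float8_training  # name assumed; check your torchao version

accelerator = accelerate.Accelerator()

# plain bf16 model; the linear weights stay as ordinary high-precision tensors
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))
model = model.to(accelerator.device, dtype=torch.bfloat16)

# swap nn.Linear for float8 training linears *before* DDP wrapping; the weights
# remain fully materialized tensors (fp8 casting happens at matmul time), which
# is why DDP can still bucket and broadcast them
convert_to_float8_training(model)

# note: the fp8 matmul itself needs Ada/Hopper-class hardware at runtime
model = accelerator.prepare(model)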
"Run a dummy forward pass to correctly initialize the modules"
Did this help, or isn't it possible at all?
I tried the following, and the same error happened at prepare():
tpacked_noisy_latents = torch.randn(1, 4320, 64, dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768, dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096, dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3, dtype=weight_dtype, device=accelerator.device)
timg_ids = torch.randn(1, 4320, 3, dtype=weight_dtype, device=accelerator.device)
with torch.no_grad():
    model_pred = transformer(
        hidden_states=tpacked_noisy_latents,
        timestep=ttimesteps,
        guidance=tguidance,
        pooled_projections=tpooled_projections,
        encoder_hidden_states=tencoder_hidden_states,
        txt_ids=ttxt_ids,
        img_ids=timg_ids,
        joint_attention_kwargs=None,
        return_dict=False,
    )
transformer = accelerator.prepare(transformer)
Okay. This is helpful. Would you be able to turn the above into a fuller reproducer and provide your accelerate config and launch command?
Will try to look into it tomorrow.
@sayakpaul any luck?
Same error here. Could you please share some possible ideas about multi-GPU quantised training? Maybe I can try to work on it.
This doesn't happen with LORA_TYPE=lycoris and fp8-quanto on 2x 3090.
@sayakpaul i got u fam
accelerate launch --multi_gpu test.py
import torch, accelerate
from diffusers import FluxTransformer2DModel
from optimum.quanto import quantize, qint8, freeze

weight_dtype = torch.bfloat16
accelerator = accelerate.Accelerator()
bfl_model = 'black-forest-labs/FLUX.1-dev'
transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")
# you might need 'with accelerator.main_process_first()' if your server lacks system mem
print('quantizing')
quantize(transformer, qint8)
print('freezing')
freeze(transformer)
tpacked_noisy_latents = torch.randn(1, 1024, 64, dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768, dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096, dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3, dtype=weight_dtype, device=accelerator.device)
timg_ids = torch.randn(1, 1024, 3, dtype=weight_dtype, device=accelerator.device)  # sequence length matches tpacked_noisy_latents
# the dummy forward pass is left commented out; prepare() fails either way
#with torch.no_grad():
#    model_pred = transformer(
#        hidden_states=tpacked_noisy_latents,
#        timestep=ttimesteps,
#        guidance=tguidance,
#        pooled_projections=tpooled_projections,
#        encoder_hidden_states=tencoder_hidden_states,
#        txt_ids=ttxt_ids,
#        img_ids=timg_ids,
#        joint_attention_kwargs=None,
#        return_dict=False,
#    )
transformer = accelerator.prepare(transformer)
Same issue here:

transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")
if accelerator.is_main_process:
    print('quantizing')
    quantize(transformer, qint8)
    print('freezing')
    freeze(transformer)
    print('waiting..')
accelerator.wait_for_everyone()
For now, DDP works with LyCORIS. I will close this, and eventually we will receive an upstream fix when there is time for them to focus on it again.
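For anyone who lands here: the LyCORIS path avoids the error because only the adapter, whose parameters are ordinary fully-initialized tensors, goes through accelerator.prepare(); the quantized transformer stays frozen outside DDP. Roughly the shape is as follows, continuing from the reproducer above (`transformer`, `accelerator`, `weight_dtype`); the create_lycoris arguments are illustrative, not the exact settings we use:

import torch
from lycoris import create_lycoris  # lycoris-lora package; exact arguments below are illustrative

transformer.requires_grad_(False)        # quantized base stays frozen
transformer.to(accelerator.device)

# the adapter is a plain nn.Module with ordinary, fully-initialized parameters
lycoris_net = create_lycoris(transformer, 1.0, linear_dim=64, linear_alpha=32, algo="lokr")
lycoris_net.apply_to()
lycoris_net.to(accelerator.device, dtype=weight_dtype)

optimizer = torch.optim.AdamW(lycoris_net.parameters(), lr=1e-4)

# only the adapter (and optimizer) go through prepare(), so DDP never sees
# the quantized, "uninitialized" parameters of the frozen transformer
lycoris_net, optimizer = accelerator.prepare(lycoris_net, optimizer)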