
DistributedDataParallel error - uninitialized parameters

Open · OhioT opened this issue 1 year ago · 12 comments

I'm using the Flux quickstart settings with fp8 quantization on 4x 3090s. The same settings work on a single 3090.

TRAINING_NUM_PROCESSES=2
export ACCELERATE_EXTRA_ARGS="--multi_gpu"

The error is raised at this line:
results = accelerator.prepare(primary_model
RuntimeError: Modules with uninitialized parameters can't be used with DistributedDataParallel. Run a dummy forward pass to correctly initialize the modules

I have tried passing find_unused_parameters=True to DDP, and I printed every parameter with requires_grad=True and grad=None, but there aren't any.

OhioT, Aug 05 '24 11:08
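A minimal diagnostic sketch, assuming DDP's check is an isinstance() test against torch.nn.parameter.UninitializedParameter over all module parameters (not only those with requires_grad=True). If that assumption holds, it would explain why filtering on requires_grad and grad=None found nothing. find_uninitialized_params is a hypothetical helper, and the LazyLinear toy model stands in for the real transformer:

```python
import torch
from torch.nn.parameter import UninitializedParameter


def find_uninitialized_params(model: torch.nn.Module) -> list[tuple[str, str]]:
    """Return (name, type name) for every parameter that isinstance() reports
    as UninitializedParameter, ignoring requires_grad and .grad entirely."""
    flagged = []
    for name, param in model.named_parameters():
        if isinstance(param, UninitializedParameter):
            flagged.append((name, type(param).__name__))
    return flagged


# Stand-in model: LazyLinear's weight and bias stay uninitialized until the
# first forward pass, so they should be flagged; the plain Linear should not.
model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.LazyLinear(2))
for name, type_name in find_uninitialized_params(model):
    print(name, type_name)
```

On the real model one would call find_uninitialized_params(transformer) just before accelerator.prepare(); the printed type names could hint at whether the flagged entries are quantized tensor subclasses rather than genuinely lazy weights.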

Oh... well... actually, I haven't tried multi-GPU quantised training yet. I assumed it would just work, since we're not really changing much other than dtypes. cc @sayakpaul

bghira, Aug 05 '24 11:08

I'm guessing you can't test without quantisation to check?

bghira, Aug 05 '24 11:08

Run a dummy forward pass to correctly initialize the modules

Did this help or isn't it possible at all?

sayakpaul, Aug 05 '24 11:08
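For reference, the error message's advice targets lazily initialized modules such as torch.nn.LazyLinear, which only get real weights once a forward pass reveals their input size. A small sketch with an illustrative toy model (not the Flux transformer); count_uninitialized is a hypothetical helper:

```python
import torch
from torch.nn.parameter import UninitializedParameter


def count_uninitialized(model: torch.nn.Module) -> int:
    # Number of parameters still typed as UninitializedParameter.
    return sum(isinstance(p, UninitializedParameter) for p in model.parameters())


model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.LazyLinear(2))
before = count_uninitialized(model)  # LazyLinear's weight and bias

# Dummy forward pass: LazyLinear infers in_features=4 from its input and
# materializes its weight and bias as ordinary Parameters.
with torch.no_grad():
    model(torch.randn(1, 3))

after = count_uninitialized(model)
print(before, after)  # 2 0
```

If the parameters DDP objects to in the quantized Flux model are not lazy weights at all, a dummy pass would not be expected to change anything.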

There is a multi-GPU training example with FP8, but it uses torchao (ao): https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_multi_gpu.py

sayakpaul, Aug 05 '24 11:08

Everything torch does has such a worse interface than what Hugging Face does. ao looks like it will work, but jesus lord, why is it so ugly lol

bghira, Aug 05 '24 11:08

Run a dummy forward pass to correctly initialize the modules

Did this help or isn't it possible at all?

I tried the following, and the same error was raised at prepare():

# transformer, accelerator and weight_dtype come from the surrounding training script
tpacked_noisy_latents = torch.randn(1, 4320, 64, dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768, dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096, dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3, dtype=weight_dtype, device=accelerator.device)
timg_ids = torch.randn(1, 4320, 3, dtype=weight_dtype, device=accelerator.device)

with torch.no_grad():
    model_pred = transformer(
        hidden_states=tpacked_noisy_latents,
        timestep=ttimesteps,
        guidance=tguidance,
        pooled_projections=tpooled_projections,
        encoder_hidden_states=tencoder_hidden_states,
        txt_ids=ttxt_ids,
        img_ids=timg_ids,
        joint_attention_kwargs=None,
        return_dict=False,
    )
transformer = accelerator.prepare(transformer)

OhioT, Aug 05 '24 12:08
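The dummy tensor shapes in these snippets follow from the image resolution. A sketch, assuming FLUX.1's 8x VAE downsampling, 16 latent channels and 2x2 patch packing (which together give the 64 channels used here); flux_packed_shapes is a hypothetical helper:

```python
def flux_packed_shapes(height: int, width: int, batch: int = 1,
                       vae_scale: int = 8, latent_channels: int = 16,
                       patch: int = 2) -> tuple[tuple[int, int, int], tuple[int, int, int]]:
    """Return (hidden_states shape, img_ids shape) for packed Flux latents.

    Assumes an 8x VAE, 16 latent channels and 2x2 patches, so each packed
    token carries 16 * 2 * 2 = 64 values and img_ids has one row per token.
    """
    step = vae_scale * patch  # pixels covered by one packed token along each axis
    if height % step or width % step:
        raise ValueError(f"height and width must be multiples of {step}")
    seq_len = (height // step) * (width // step)
    channels = latent_channels * patch * patch
    return (batch, seq_len, channels), (batch, seq_len, 3)


print(flux_packed_shapes(512, 512))   # ((1, 1024, 64), (1, 1024, 3))
print(flux_packed_shapes(1152, 960))  # ((1, 4320, 64), (1, 4320, 3))
```

1152x960 is just one resolution that yields 4320 tokens; any size with (H/16) * (W/16) = 4320 would do.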

Okay. This is helpful. Would you be able to turn the above into a fuller reproducer and provide your accelerate config and launch command?

Will try to look into it tomorrow.

sayakpaul, Aug 05 '24 12:08

@sayakpaul any luck?

bghira, Aug 18 '24 08:08

Same error here. Could you share some ideas about how multi-GPU quantised training could be made to work? Maybe I can try to work on it.

matabear-wyx, Aug 20 '24 07:08

This doesn't happen with LORA_TYPE=lycoris and fp8-quanto on 2x 3090s.

bghira, Aug 21 '24 18:08

@sayakpaul I got you fam, here's a reproducer:

accelerate launch --multi_gpu test.py

import torch, accelerate
from diffusers import FluxTransformer2DModel
from optimum.quanto import quantize, qint8, freeze
weight_dtype = torch.bfloat16

accelerator = accelerate.Accelerator()

bfl_model = 'black-forest-labs/FLUX.1-dev'
transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")

# you might need 'with accelerator.main_process_first()' if your server lacks system memory
print('quantizing')
quantize(transformer, qint8)
print('freezing')
freeze(transformer)

tpacked_noisy_latents = torch.randn(1, 1024, 64, dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768, dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096, dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3, dtype=weight_dtype, device=accelerator.device)
# img_ids needs one row per latent token, so it must match the 1024 tokens above
timg_ids = torch.randn(1, 1024, 3, dtype=weight_dtype, device=accelerator.device)

#with torch.no_grad():
#    model_pred = transformer(
#        hidden_states=tpacked_noisy_latents,
#        timestep=ttimesteps,
#        guidance=tguidance,
#        pooled_projections=tpooled_projections,
#        encoder_hidden_states=tencoder_hidden_states,
#        txt_ids=ttxt_ids,
#        img_ids=timg_ids,
#        joint_attention_kwargs=None,
#        return_dict=False,
#    )
transformer = accelerator.prepare(transformer)

bghira, Aug 21 '24 19:08
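Since the dummy forward pass in this reproducer is commented out, its inputs are never exercised. If it is re-enabled, the token counts must agree: img_ids needs one row per hidden_states token and txt_ids one row per encoder_hidden_states token. A small pre-flight check, assuming the 3-D id layout used in this thread; check_flux_dummy_shapes is a hypothetical helper, and the example flags 1024 latent tokens paired with 4320 img_ids rows:

```python
def check_flux_dummy_shapes(shapes: dict[str, tuple[int, ...]]) -> list[str]:
    """Return a list of problems in a dict of dummy Flux input shapes.

    Assumed layout, matching the snippets in this thread:
      hidden_states          (B, img_seq, 64)
      img_ids                (B, img_seq, 3)
      encoder_hidden_states  (B, txt_seq, 4096)
      txt_ids                (B, txt_seq, 3)
    """
    problems = []
    img_seq, img_rows = shapes["hidden_states"][1], shapes["img_ids"][1]
    if img_seq != img_rows:
        problems.append(f"img_ids has {img_rows} rows but hidden_states has {img_seq} tokens")
    txt_seq, txt_rows = shapes["encoder_hidden_states"][1], shapes["txt_ids"][1]
    if txt_seq != txt_rows:
        problems.append(f"txt_ids has {txt_rows} rows but encoder_hidden_states has {txt_seq} tokens")
    # Every input should share the same leading batch dimension.
    batches = {shape[0] for shape in shapes.values()}
    if len(batches) > 1:
        problems.append(f"inconsistent batch sizes: {sorted(batches)}")
    return problems


dummy = {
    "hidden_states": (1, 1024, 64),
    "img_ids": (1, 4320, 3),
    "encoder_hidden_states": (1, 512, 4096),
    "txt_ids": (1, 512, 3),
}
for problem in check_flux_dummy_shapes(dummy):
    print(problem)
```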

Same issue here when only the main process quantizes:

transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")
if accelerator.is_main_process:
    print('quantizing')
    quantize(transformer, qint8)
    print('freezing')
    freeze(transformer)
print('waiting..')
accelerator.wait_for_everyone()

bghira, Aug 21 '24 19:08

For now, DDP works with Lycoris. I will close this; eventually we should receive an upstream fix when there is time for them to focus on it again.

bghira, Aug 27 '24 01:08