How to handle gradient accumulation with multiple models?

patil-suraj opened this issue 2 years ago • 9 comments

To do gradient accumulation with accelerate, we wrap the model in the accelerator.accumulate context. But what would be the right way to achieve this when multiple models are involved?
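
For reference, the usual single-model pattern looks roughly like this (a sketch based on the docs; variable names are illustrative):

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(model):
        loss = model(**batch).loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()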

For example, when training latent diffusion models we have three separate models: a VAE, a text encoder, and a UNet, as you can see in this script. Of these, only the text_encoder is being trained (though the others could be trained as well).

The obvious way to do this would be to create a wrapper model, but I'm curious to know whether this can be achieved without one.

cc @muellerzr

patil-suraj avatar Aug 31 '22 12:08 patil-suraj

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Sep 30 '22 15:09 github-actions[bot]

Re-opening this issue. For grad accum with accelerator.accumulate and two models (both being trained), can we use two context managers, like this?

with accelerator.accumulate(model1), accelerator.accumulate(model2):
    training_step()

patil-suraj avatar Oct 17 '22 09:10 patil-suraj

Re-opening this issue. For grad accum with accelerator.accumulate and two models (both being trained), can we use two context managers, like this?

with accelerator.accumulate(model1), accelerator.accumulate(model2):
    training_step()

Does this currently work? Or is that a feature request, meaning it currently wouldn't work, but would work in the future? Sorry, I got a bit confused by the feature request tag and your comment.

Lime-Cakes avatar Nov 10 '22 09:11 Lime-Cakes

Very interested in this. I'm training two models at once and can only use batch sizes of less than 5 on my machine... So gradient accumulation would be great.

pfeatherstone avatar Dec 05 '22 13:12 pfeatherstone

I "solved" it by creating one Accelerator per model. If you use only one and register the models via accelerator.prepare(model1, ..., modelN) at least one of the models is not learning anything. This might be a bug.

LvanderGoten avatar Jan 19 '23 10:01 LvanderGoten

I think writing the accumulation yourself may be more flexible in this case; Accelerator.accumulate() is not necessary. Just write code like:

# average the loss over the accumulation window
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# step, schedule, and reset only every gradient_accumulation_steps batches
if (index + 1) % gradient_accumulation_steps == 0:
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
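
One caveat with the manual approach, as I understand it: under DDP, every backward() still all-reduces gradients across processes, so to actually save communication you would also skip the sync on non-step iterations. A rough sketch (model here is the prepared model):

import contextlib

is_step = (index + 1) % gradient_accumulation_steps == 0
# skip the DDP gradient all-reduce on iterations that don't call optimizer.step()
ctx = contextlib.nullcontext() if is_step else accelerator.no_sync(model)
with ctx:
    accelerator.backward(loss)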

meng-wenlong avatar Feb 21 '23 09:02 meng-wenlong

This is pretty simple, actually: just use "with accelerator.accumulate(model1), accelerator.accumulate(model2):". That's how "with" works; the following code runs inside both contexts, so you can simply join them with a comma.

PrinceRay7 avatar Mar 11 '23 12:03 PrinceRay7

This is pretty simple, actually: just use "with accelerator.accumulate(model1), accelerator.accumulate(model2):". That's how "with" works; the following code runs inside both contexts, so you can simply join them with a comma.

This apparently does not work. I printed AdamW statistics for parameter groups from different models, and one of them goes out of sync between GPUs with this setup, which in my view should not happen in DDP.

Quoting from the torch forums:

That said, if a single call to backward involves gradient accumulation for more than 1 DDP wrapped module, then you’ll have to use a different process group for each of them to avoid interference.

Format:

Sync [GPU ID]
Min Max Mean of parameter group exp_sq_avg.sqrt()
...

[screenshot: per-GPU AdamW statistics diverging between GPUs]

After wrapping the models in a SuperModel module, it no longer goes out of sync.

[screenshot: per-GPU AdamW statistics staying in sync after wrapping]

eliphatfs avatar Jun 02 '23 23:06 eliphatfs

TL;DR: don't do gradient accumulation with multiple separately prepared models. Wrap them in a single wrapper model and do all the accelerator calls with it. Move the relevant forward logic inside the wrapper model.

Edit: creating an Accelerator for each model, as @LvanderGoten suggests, could also work. Personally I prefer the wrapper model.

eliphatfs avatar Jun 02 '23 23:06 eliphatfs

@eliphatfs would you please show your solution (wrapping the models together) in pseudocode? I am working on training ControlNet + SD modules together.

mahdip72 avatar Jun 29 '23 17:06 mahdip72

Basically, if you have this in your main training loop:

states = text_encoder(input_ids)
pred = unet(noisy_latents, states, timesteps)
loss = F.mse_loss(pred, targets)
# now loss.backward() will corrupt gradients if you are using accumulation on multi-gpu

Change it into:

class SuperModel(nn.Module):
    def __init__(self, unet: UNet2DConditionModel, text_encoder: nn.Module) -> None:
        super().__init__()
        self.unet = unet
        self.text_encoder = text_encoder

    def forward(self, input_ids, noisy_latents, timesteps):
        states = self.text_encoder(input_ids)
        return self.unet(noisy_latents, states, timesteps)

After constructing the individual modules, build a SuperModel from them. When calling accelerator.prepare, only pass the SuperModel. Do the same for the optimizer and grad norm clipping (though maybe those are less important). Then, in the main loop, replace the two lines with a single call to the SuperModel forward (a wiring sketch follows the saving snippet below):

pred = supermodel(input_ids, noisy_latents, timesteps)
loss = F.mse_loss(pred, targets)

You may also need to change the final saving:

supermodel: SuperModel = accelerator.unwrap_model(supermodel)
supermodel.text_encoder.save_pretrained(os.path.join(args.output_dir, 'text_encoder'))
supermodel.unet.save_pretrained(os.path.join(args.output_dir, 'unet'))
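
Putting the wiring together, roughly (a sketch; the args.* names are assumed from the original training script):

supermodel = SuperModel(unet, text_encoder)
optimizer = torch.optim.AdamW(supermodel.parameters(), lr=args.learning_rate)
supermodel, optimizer, train_dataloader = accelerator.prepare(
    supermodel, optimizer, train_dataloader
)

# inside the training loop, clip on the wrapper's parameters as well
accelerator.clip_grad_norm_(supermodel.parameters(), args.max_grad_norm)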

I don't yet have a good idea of how to handle LoRA layers. It seems that LoRA layers on multiple modules cause more problems, since only the AttnProcLayers get prepare-d.

eliphatfs avatar Jun 30 '23 02:06 eliphatfs

I "solved" it by creating one Accelerator per model. If you use only one and register the models via accelerator.prepare(model1, ..., modelN) at least one of the models is not learning anything. This might be a bug.

You mean create two Accelerator objects and use nested accumulate in the training loop?

with accel_1.accumulate(model1):
    with accel_2.accumulate(model2):
        training_step()

mahdip72 avatar Jun 30 '23 12:06 mahdip72

Does it now support gradient accumulation for multiple models?

cfeng16 avatar Aug 20 '23 22:08 cfeng16

Does it now support gradient accumulation for multiple models?

I think #1708 should fix it, according to the comments there.

hkunzhe avatar Aug 23 '23 03:08 hkunzhe

Can we use gradient accumulation for multiple models in distributed training?

cfeng16 avatar Oct 29 '23 02:10 cfeng16

Yes, just pass them all to the accumulate function, as shown in the PR linked earlier.
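
Something like this (a sketch; training_step is illustrative):

with accelerator.accumulate(model1, model2):
    loss = training_step(model1, model2, batch)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()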

muellerzr avatar Oct 29 '23 02:10 muellerzr

I think writing the accumulation yourself may be more flexible in this case; Accelerator.accumulate() is not necessary. Just write code like:

# average the loss over the accumulation window
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# step, schedule, and reset only every gradient_accumulation_steps batches
if (index + 1) % gradient_accumulation_steps == 0:
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

Don't forget to delete "with accelerator.accumulate(unet):" and the gradient_accumulation_steps=args.gradient_accumulation_steps argument to Accelerator(...), guys.

Chao0511 avatar Jan 04 '24 14:01 Chao0511