Optimize VRAM use in textual inversion training

Ttl opened this issue 3 years ago · 19 comments

Cast frozen modules to fp16/bf16 when using mixed precision. Add a gradient checkpointing command line option.

Training previously OOMed on my 8 GB VRAM GPU. With these changes and --mixed_precision=fp16 --gradient_checkpointing, VRAM use is 6341 MB and the results look good.

Ttl avatar Sep 30 '22 13:09 Ttl
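
For readers skimming the PR, here is a minimal sketch of the idea, assuming diffusers' models expose enable_gradient_checkpointing(); the helper name, signature, and variable names are illustrative, not the script's actual code:

import torch

def cast_frozen_modules(unet, vae, text_encoder, device,
                        mixed_precision, gradient_checkpointing):
    """Cast the frozen unet and vae to the mixed-precision dtype and,
    if requested, turn on gradient checkpointing for the unet."""
    weight_dtype = torch.float32
    if mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # frozen nets need no fp32 master copy, so store them in half precision
    unet.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    # the text encoder holds the trainable token embedding, so it stays in fp32
    text_encoder.to(device)

    if gradient_checkpointing:
        # recompute unet activations in the backward pass instead of storing them
        unet.enable_gradient_checkpointing()
    return weight_dtype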

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@isamu-isozaki is this the approach you've been using?

keturn avatar Oct 02 '22 18:10 keturn

@keturn Pretty much! I didn't know about the autocast functionality, so I manually moved most of the parts between cpu and cuda. The code is here. One thing I remember: with 6 GB of VRAM it will OOM before the accelerator.accumulate part, so most models are better kept on the cpu.

isamu-isozaki avatar Oct 02 '22 19:10 isamu-isozaki
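
As an aside, a minimal sketch of the manual offloading pattern described above, assuming diffusers' AutoencoderKL API; this is illustrative, not the linked code:

import torch

def encode_with_offload(vae, pixel_values, device="cuda"):
    # keep the frozen model on the CPU and move it to the GPU only for the
    # step that needs it, then move it back to free VRAM
    vae.to(device)
    with torch.no_grad():
        latents = vae.encode(pixel_values.to(device)).latent_dist.sample()
    vae.to("cpu")
    return latents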

Hi! Great PR, @Ttl @keturn. As it stands it doesn't fit in 6 GB of VRAM, but once I added this

# compute attention in slices of heads to reduce peak memory at some speed cost
slice_size = unet.config.attention_head_dim // 2
unet.set_attention_slice(slice_size)

it fits. Thanks for this! Now my training will get way better.

isamu-isozaki avatar Oct 03 '22 14:10 isamu-isozaki

@keturn Sorry, on second thought this is quite different from my approach, but it's much better too!

isamu-isozaki avatar Oct 03 '22 19:10 isamu-isozaki

If you add revision="fp16" to from_pretrained, do you still have to do the conversions to weight_dtype?

keturn avatar Oct 03 '22 19:10 keturn

@patil-suraj can you take a look here?

patrickvonplaten avatar Oct 04 '22 13:10 patrickvonplaten

I'm using locally saved weights and adding revision="fp16" doesn't seem to do anything in that case.

Ttl avatar Oct 04 '22 15:10 Ttl
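
For context, a hedged illustration of the two loading paths being compared; the model id is just an example, and option B only applies when the hub repository actually has an fp16 revision:

import torch
from diffusers import UNet2DConditionModel

model_id = "CompVis/stable-diffusion-v1-4"  # example hub repo

# Option A: load fp32 weights, then cast the frozen module afterwards
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet.to(dtype=torch.float16)

# Option B: load the fp16 branch directly; this does nothing for local fp32 weights
unet_fp16 = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", revision="fp16", torch_dtype=torch.float16
)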

I tried revision="fp16" and got an OOM for some reason. I'll double-check later today.

isamu-isozaki avatar Oct 04 '22 19:10 isamu-isozaki

The idea behind casting the weights of the non-trained nets is that without it, fp32 weights are transferred to VRAM even when training in fp16. Since they are not trained, we don't need to keep an fp32 copy of them in VRAM.

self.training is controlled by calling train() or eval() on the module. Since the unet is set to eval() here, gradient checkpointing is never enabled unless the self.training check is removed. Gradient checkpointing is useful in this case because the unet sits between our trainable weights and the loss calculation, so its activations have to be stored for the backward pass. I checked that enabling it saves 1080 MB of memory.

I can maintain a copy of the script that casts non-trained weights to fp16 locally, but it would be nice if the gradient checkpointing changes would be merged. Would you be fine with that?

Ttl avatar Oct 05 '22 11:10 Ttl
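
A minimal sketch of the self.training gate being described, using a toy module rather than the actual diffusers unet code:

import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(64, 64)
        self.gradient_checkpointing = True

    def forward(self, x):
        # checkpointing only kicks in while the module is in train() mode, so a
        # unet left in eval() never takes this path even with the flag enabled
        if self.training and self.gradient_checkpointing:
            return checkpoint(self.layer, x)
        return self.layer(x)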

Gradient checkpointing is useful in this case because the unet sits between our trainable weights and the loss calculation, so its activations have to be stored for the backward pass. I checked that enabling it saves 1080 MB of memory.

That's a really good observation! Sorry, I rushed the review a bit. In this case keeping the gradient checkpointing changes makes sense; let me try it quickly and get back to you.

Thanks a lot!

patil-suraj avatar Oct 05 '22 12:10 patil-suraj

Also pinging @patrickvonplaten and @anton-l. Are the activations stored even when grads are disabled for the model?

patil-suraj avatar Oct 05 '22 12:10 patil-suraj

I added a commit that removes the autocast. It should work with fp32 and bf16 too, but I can't test those on my GPU. This PR does have a side effect: it saves fp16-quantized weights for the unet and vae, since their fp32 weights are discarded when training in fp16. If you prefer, I can drop the training changes and keep only the gradient checkpointing change.

Ttl avatar Oct 05 '22 13:10 Ttl
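
A rough sketch of what removing autocast implies in the training loop; the helper and variable names are assumptions, and the 0.18215 factor is the usual SD v1 latent scaling, not something taken from this PR:

import torch

def encode_to_latents(vae, pixel_values, weight_dtype):
    # cast the input to the frozen VAE's dtype explicitly instead of relying
    # on torch.autocast, then apply the latent scaling factor
    with torch.no_grad():
        latents = vae.encode(pixel_values.to(dtype=weight_dtype)).latent_dist.sample()
    return latents * 0.18215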

I think this PR is currently blocked by:

  • gradient_checkpointing being a bit flaky when the model is set to train mode
  • not being able to pass dropout down

cc @patil-suraj

patrickvonplaten avatar Oct 07 '22 13:10 patrickvonplaten

Any progress on the blockers?

Thomas-MMJ avatar Oct 30 '22 13:10 Thomas-MMJ

@patil-suraj could you have another look here?

patrickvonplaten avatar Nov 02 '22 13:11 patrickvonplaten

cc @patil-suraj here - could you maybe pick up the PR if the author doesn't reply anymore, so it doesn't get forgotten?

patrickvonplaten avatar Nov 16 '22 07:11 patrickvonplaten

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Dec 10 '22 15:12 github-actions[bot]

cc @patil-suraj - could you maybe post some instructions here on how to proceed? Then someone else could pick it up

patrickvonplaten avatar Dec 13 '22 15:12 patrickvonplaten

Sorry for being late again; I've posted instructions in this comment: https://github.com/huggingface/diffusers/pull/687#pullrequestreview-1165354545

@Ttl LMK if you are busy, then I'll make the necessary changes and merge :)

patil-suraj avatar Dec 26 '22 17:12 patil-suraj

It's been quite long since I last looked at this code and I haven't used textual inversion much anymore. Feel free to make necessary changes to get it merged if you want to.

Ttl avatar Dec 28 '22 10:12 Ttl

Thanks, will open a PR then :)

patil-suraj avatar Dec 28 '22 12:12 patil-suraj