diffusers
Optimize VRAM use in textual inversion training
Cast frozen modules to fp16/bf16 when using mixed precision. Add a gradient checkpointing command line option.
Training previously OOMed on my 8 GB GPU. With these changes and --mixed_precision=fp16 --gradient_checkpointing, VRAM use is 6341 MB and the results look good.
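For anyone following along, here is a minimal sketch of the idea, assuming the usual names from the textual inversion example script (accelerator, vae, unet); it is an illustration, not the exact diff:

import torch

def cast_frozen_and_enable_checkpointing(accelerator, vae, unet, gradient_checkpointing):
    # Match the dtype of the frozen modules to accelerate's mixed-precision setting.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Frozen modules: keep only the half-precision copy in VRAM.
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

    # The new command line option turns on gradient checkpointing for the unet.
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()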
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@isamu-isozaki is this the approach you've been using?
@keturn Pretty much! I didn't know about the autocast functionality, so I manually moved most of the parts between CPU and CUDA. The code is here. One thing I remember is that with 6 GB of VRAM it would OOM before the accelerator.accumulate part, so most models are better kept on the CPU.
Hi! Great PR, @Ttl @keturn! It doesn't fit in 6 GB of VRAM as it is now, but once I did this
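# Sliced attention: compute attention in smaller chunks, trading some speed for lower peak VRAM.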
slice_size = unet.config.attention_head_dim // 2
unet.set_attention_slice(slice_size)
it fits. Thanks for this! Now my training will get way better.
@keturn Sorry, on second thought this is quite different from my approach, but it's way better too!
If you add revision="fp16" to from_pretrained, do you still have to do the conversions to weight_dtype?
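For reference, a hedged example of what that looks like when loading from the Hub (the model id is illustrative): revision selects which checkpoint branch gets downloaded, while torch_dtype controls the dtype the weights are loaded in.

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # example model id with an fp16 branch
    subfolder="unet",
    revision="fp16",
    torch_dtype=torch.float16,
)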
@patil-suraj can you take a look here?
I'm using locally saved weights and adding revision="fp16" doesn't seem to do anything in that case.
I tried revision="fp16" and got OOM for some reason. Will double-check later today.
The idea behind casting the weights of the non-trained nets is that without it, fp32 weights are transferred to VRAM even when training in fp16. Since those nets are not trained, we don't need to keep an fp32 copy of them in VRAM.
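As a rough, back-of-the-envelope illustration (parameter counts are approximate and only the weights themselves are counted):

# Weights-only VRAM for the frozen SD v1 modules in fp32 vs fp16.
for name, params in [("unet", 860_000_000), ("vae", 84_000_000)]:
    fp32_gib = params * 4 / 2**30
    fp16_gib = params * 2 / 2**30
    print(f"{name}: {fp32_gib:.2f} GiB fp32 -> {fp16_gib:.2f} GiB fp16")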
self.training is controlled by the module's train() or eval() call. Since in this case the unet is set to eval() and the self.training check is not removed, gradient checkpointing is not enabled. Gradient checkpointing is useful here because we need to store activations in the unet for the backward pass, since the unet sits between our trainable weights and the loss calculation. I checked that enabling it saves 1080 MB of memory.
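A small sketch of what that amounts to in the script (a sketch only, not the PR diff; the checkpointing path in the diffusers unet blocks is gated on self.training, so the frozen unet has to be in train() mode for it to run):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.requires_grad_(False)            # weights stay frozen, no gradients for them
unet.enable_gradient_checkpointing()  # sets the gradient_checkpointing flag on the blocks
unet.train()                          # self.training = True so the checkpointing path is taken

Note that train() mode also affects dropout behaviour, which is part of the concern raised later in the thread.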
I can maintain a copy of the script that casts non-trained weights to fp16 locally, but it would be nice if the gradient checkpointing changes would be merged. Would you be fine with that?
> Gradient checkpointing is useful here because we need to store activations in the unet for the backward pass, since the unet sits between our trainable weights and the loss calculation. I checked that enabling it saves 1080 MB of memory.
That's a really good observation! Sorry, I rushed the review a bit. In this case keeping the gradient checkpointing changes makes sense, let me try it quickly and get back to you.
Thanks a lot!
Also pinging @patrickvonplaten and @anton-l. Are the activations stored even when grads are disabled for the model?
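A tiny, self-contained PyTorch check of that question: when a trainable tensor feeds into a frozen module, autograd still keeps the intermediate activations inside the frozen module, because it has to backpropagate through it to reach the trainable input.

import torch
import torch.nn as nn

frozen = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
frozen.requires_grad_(False)           # stands in for the frozen unet

emb = nn.Parameter(torch.randn(1, 8))  # stands in for the learned embedding
loss = frozen(emb).sum()
loss.backward()                        # the graph through `frozen` was retained
print(emb.grad is not None)            # True: gradients flow back to the embedding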
I added a commit that removes the autocast. It should work with fp32 and bf16 too, but I can't test those on my GPU. This PR does have a side effect: it saves fp16 weights for the unet and vae, since their fp32 weights are discarded when training in fp16. If you prefer, I can remove the other training changes and only keep the gradient checkpointing change.
I think this PR is currently blocked by:
- gradient_checkpointing being a bit flaky when the model is set to train mode
- us not being able to pass dropout down
cc @patil-suraj
Any progress on the blockers?
@patil-suraj could you have another look here?
cc @patil-suraj here - could you maybe take over the PR if the author doesn't reply anymore, so it isn't forgotten?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
cc @patil-suraj - could you maybe post some instructions here on how to proceed? Then someone else could pick it up
Sorry for being late again, I've posted instructions in this comment https://github.com/huggingface/diffusers/pull/687#pullrequestreview-1165354545
@Ttl LMK if you are busy, then I'll make the necessary changes and merge :)
It's been quite long since I last looked at this code and I haven't used textual inversion much anymore. Feel free to make necessary changes to get it merged if you want to.
Thanks, will open a PR then :)