
How to apply VADER to my own model

Open XuWuLingYu opened this issue 1 year ago • 2 comments

I trained a UNet diffusion model and want to use VADER to fine-tune a single UNet. Following the training script, I attached a peft LoRA adapter to my UNet; it reports that only 0.7% of the parameters are trainable. I then passed those parameters to the optimizer and trained with a denoising loop. Although only a few parameters are trainable, CUDA memory still grows rapidly during the denoising loop. Have you encountered similar problems? Here is part of my preparation script:

lora_config = peft.LoraConfig(
    r=self.cfg.lora_downdim,
    target_modules=["to_k", "to_v", "to_q"],        # only diffusion_model has these modules
    lora_dropout=0.01,
    lora_alpha=8
)
self.unet.requires_grad_(False)
self.unet = peft.get_peft_model(self.unet, lora_config)
# Only the injected LoRA parameters remain trainable after the freeze above.
params_to_optimize = [p for _, p in self.unet.named_parameters() if p.requires_grad]
self.optimizer = optimizer_class(
    params_to_optimize,
    lr=self.cfg.runner.learning_rate,
    betas=(self.cfg.runner.adam_beta1, self.cfg.runner.adam_beta2),
    weight_decay=self.cfg.runner.adam_weight_decay,
    eps=self.cfg.runner.adam_epsilon,
)
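For reference, the trainable fraction can be verified with a direct parameter count. A minimal sketch using a frozen layer plus a low-rank adapter as a stand-in for the LoRA-wrapped UNet (the layer sizes and names here are illustrative, not from VADER):

```python
import torch.nn as nn

# Toy stand-in: a frozen "base" projection plus a small trainable low-rank
# adapter, mimicking what peft's LoRA wrapping adds to to_q/to_k/to_v.
base = nn.Linear(512, 512)
base.requires_grad_(False)
lora_down = nn.Linear(512, 4, bias=False)  # r = 4, analogous to cfg.lora_downdim
lora_up = nn.Linear(4, 512, bias=False)

params = list(base.parameters()) + list(lora_down.parameters()) + list(lora_up.parameters())
trainable = sum(p.numel() for p in params if p.requires_grad)
total = sum(p.numel() for p in params)
print(f"trainable: {trainable} / {total} ({100 * trainable / total:.2f}%)")
```

With peft itself, `self.unet.print_trainable_parameters()` reports the same statistic on the real model.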

Here is the accelerator preparation:

ddp_modules = (
    self.unet,
    self.optimizer,
    self.train_dataloader,
    self.lr_scheduler,
)
ddp_modules = self.accelerator.prepare(*ddp_modules)
(
    self.unet,
    self.optimizer,
    self.train_dataloader,
    self.lr_scheduler,
) = ddp_modules

In training:

with self.accelerator.accumulate(self.unet):
    with self.accelerator.autocast():
        for i, t in enumerate(timesteps):  # timesteps length = 20
            ...  # prepare data
            noise_pred = self.unet(...)
            self.scheduler.step(...)

When i == 2, CUDA out of memory is reported (on an A800). I wonder if I am missing some key point.
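My understanding of the mechanism: even though only the LoRA parameters are trainable, each `self.unet(...)` call inside the loop chains onto the previous step's autograd graph, so activation memory grows with every denoising step. A toy reproduction of that graph growth (the linear layer is an illustrative stand-in for the UNet):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(8, 8)

def count_graph_nodes(t):
    """Count distinct autograd nodes reachable from tensor t's grad_fn."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)

x = torch.randn(1, 8)
for _ in range(20):
    x = net(x)            # each call extends the graph of the previous step
full = count_graph_nodes(x)

y = torch.randn(1, 8)
for _ in range(20):
    y = net(y.detach())   # detaching drops all earlier steps from the graph
short = count_graph_nodes(y)
print(full, short)        # the chained graph is far larger than the detached one
```

On the real model the same growth shows up as per-step CUDA memory, which is why a few trainable parameters do not prevent the OOM.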

XuWuLingYu avatar Nov 26 '24 18:11 XuWuLingYu

@XuWuLingYu Have you solved the problem?

SkylerZheng avatar Jan 21 '25 09:01 SkylerZheng

Hi. You can try implementing truncated backpropagation, subsampling frames, and gradient checkpointing to reduce VRAM usage.
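The first suggestion, truncated backpropagation, can be sketched as follows: run most denoising steps under `torch.no_grad()` and differentiate only through the last few. The module and the `backprop_steps` knob below are illustrative stand-ins, not VADER's actual code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(8, 8)       # stand-in for the LoRA-wrapped UNet
timesteps = list(range(20))
backprop_steps = 2          # hypothetical truncation length: differentiate last 2 steps only

x = torch.randn(1, 8)
for i, t in enumerate(timesteps):
    if i < len(timesteps) - backprop_steps:
        with torch.no_grad():  # early steps build no graph, so no activation memory
            x = net(x)
    else:
        x = net(x)             # only the final steps are differentiated through
loss = x.sum()
loss.backward()
```

This caps activation memory at `backprop_steps` UNet forwards instead of 20. For gradient checkpointing, diffusers UNets expose `enable_gradient_checkpointing()`, which trades recomputation time for activation memory.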

QinOwen avatar Feb 14 '25 22:02 QinOwen