
[Solved] Memory Issue -- Explicitly freeze the UNet

Open · robot0321 opened this issue 7 months ago • 1 comment

Hi Chengzhag,

Thank you very much for kindly sharing the code for your outstanding work.

While experimenting with the code, I encountered a memory issue (which other users have also reported), and this raised a question regarding the UNet training. According to Section 3.5 of your CVPR paper, the SD UNet blocks should be frozen during training. However, when I reviewed the code, it appears that the UNet is also being fine-tuned, resulting in significantly higher memory usage—exceeding 48 GB.

In fact, when I manually froze the UNet, I was able to run the training code using only about 36–37 GB of memory.
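For reference, here is a minimal diagnostic sketch (illustrative only, not from the repository) for checking which branches still require gradients; `model` stands for the instantiated PanoGenerator module and the name is an assumption:

# Count trainable parameters per top-level submodule to see which
# branches still require gradients (and thus cost gradient memory).
def report_trainable(model):
    for name, module in model.named_children():
        total = sum(p.numel() for p in module.parameters())
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        print(f"{name}: {trainable / 1e6:.1f}M trainable / {total / 1e6:.1f}M total")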

Therefore, I would like to clarify: to faithfully reproduce the results presented in the paper, should the UNet be frozen during training, or is there a particular reason for fine-tuning the UNet as well?

Thank you very much for your time and assistance.

robot0321 · Apr 28 '25 10:04

Update:

We have confirmed that the UNet parameters are not updated by the optimizer, which is consistent with the description in the paper.

The issue is that although the UNet parameters are not optimized, they are never explicitly frozen with requires_grad_(False), so PyTorch still computes and stores gradients for them unnecessarily.
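To illustrate why excluding parameters from the optimizer is not enough on its own, here is a small self-contained PyTorch sketch (independent of the PanFusion code):

import torch

# A layer that is merely left out of the optimizer still gets gradients
# computed and stored during backward(), which costs memory.
not_in_optimizer = torch.nn.Linear(4, 4)

# A layer that is explicitly frozen is skipped by autograd entirely.
truly_frozen = torch.nn.Linear(4, 4).requires_grad_(False)

x = torch.randn(2, 4)
loss = not_in_optimizer(x).sum() + truly_frozen(x).sum()
loss.backward()

print(not_in_optimizer.weight.grad is None)  # False: a gradient was computed and stored
print(truly_frozen.weight.grad is None)      # True: no gradient bookkeeping at all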

To fix the memory issue, explicitly freeze the UNet right after it is instantiated in load_branch() in models/pano/PanoGenerator.py:

def load_branch(self, add_lora, train_lora, add_cn):
    unet = UNet2DConditionModel.from_pretrained(
        self.hparams.model_id, subfolder="unet", torch_dtype=torch.float32, use_safetensors=True)

    for param in unet.parameters():  # freeze the UNet so PyTorch skips gradient computation for its parameters
        param.requires_grad_(False)

    unet.enable_xformers_memory_efficient_attention()
    unet.enable_gradient_checkpointing()

...
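As a quick sanity check after applying the change (a sketch, not part of the repository), one can verify that the UNet reports zero trainable parameters; note that nn.Module.requires_grad_ also allows freezing the whole module in a single call:

# `unet` is the module frozen above; after the change it should expose
# zero trainable parameters.
unet_trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
assert unet_trainable == 0, "UNet still has trainable parameters"

# Equivalent one-liner for the freezing loop, if preferred:
# unet.requires_grad_(False)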

robot0321 · Apr 28 '25 10:04