PanFusion
[Solved] Memory Issue -- Explicitly freeze the UNet
Hi Chengzhag,
Thank you very much for kindly sharing the code for your outstanding work.
While experimenting with the code, I encountered a memory issue (which other users have also reported), and it raised a question about the UNet training. According to Section 3.5 of your CVPR paper, the SD UNet blocks should be frozen during training. However, when I reviewed the code, the UNet also appears to be fine-tuned, which results in significantly higher memory usage of more than 48 GB.
In fact, when I manually froze the UNet, I was able to run the training code using only about 36–37 GB of memory.
Therefore, I would like to clarify: to faithfully reproduce the results presented in the paper, should the UNet be frozen during training, or is there a particular reason for fine-tuning the UNet as well?
Thank you very much for your time and assistance.
Update:
We have confirmed that the UNet parameters are not updated by the optimizer, which is consistent with the description in the paper.
The issue was that although the UNet parameters are not optimized, they are never explicitly set with requires_grad_(False), so PyTorch still computes and stores gradients for them unnecessarily.
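For illustration, here is a minimal, standalone PyTorch sketch (not from the repository; the two Linear layers merely stand in for the frozen UNet and the trainable branch). It shows that parameters left out of the optimizer still get gradient buffers allocated unless requires_grad is disabled:

import torch
import torch.nn as nn

frozen = nn.Linear(1000, 1000)     # stand-in for the pretrained UNet
trainable = nn.Linear(1000, 1000)  # stand-in for the LoRA/ControlNet branch

# Only the "trainable" module is handed to the optimizer.
optimizer = torch.optim.Adam(trainable.parameters(), lr=1e-4)

x = torch.randn(8, 1000)
loss = trainable(frozen(x)).sum()
loss.backward()

# The "frozen" module still gets .grad tensors allocated (extra memory),
# because requires_grad is True by default.
print(frozen.weight.grad is not None)  # True

# Disabling requires_grad beforehand avoids these gradient buffers.
frozen.weight.grad = None
frozen.requires_grad_(False)
loss = trainable(frozen(x)).sum()
loss.backward()
print(frozen.weight.grad is None)  # True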
To fix the memory issue, apply the following change: after the UNet is created in load_branch() (in models/pano/PanoGenerator.py), explicitly freeze its parameters:
def load_branch(self, add_lora, train_lora, add_cn):
    unet = UNet2DConditionModel.from_pretrained(
        self.hparams.model_id, subfolder="unet",
        torch_dtype=torch.float32, use_safetensors=True)
    # Explicitly freeze the UNet so no gradients are computed or stored for it
    for param in unet.parameters():
        param.requires_grad_(False)
    unet.enable_xformers_memory_efficient_attention()
    unet.enable_gradient_checkpointing()
    ...
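As a quick sanity check (a sketch only; model is assumed to be the instantiated PanoGenerator module after load_branch() has run), you can count the parameters after this change; only the LoRA/ControlNet weights should remain trainable:

# Sanity-check sketch; `model` stands for the instantiated PanoGenerator module.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable / 1e6:.1f}M params, frozen: {frozen / 1e6:.1f}M params")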