Shape mismatch in backward pass when using vae.enable_gradient_checkpointing()
System Info
I decode latents with the VAE, and I have enabled gradient checkpointing on the VAE.
When I call loss.backward(), I get this error:
Exception has occurred: RuntimeError
Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 256, 1, 136, 240]) and output[0] has a shape of torch.Size([1, 256, 2, 136, 240]).
File "/high_perf_store3/world-model/ailab_vision/wangzepeng5/code/v1_5_pilot_weather_transfer/paper_codes/test.py", line 224, in
My code is below:
`import torch
from diffusers import AutoencoderKLCogVideoX
# LossComputer is defined elsewhere in my scripts (not shown).


def load_vae(pretrained_model_name_or_path):
    vae = AutoencoderKLCogVideoX.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
    )
    return vae


def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    # with torch.no_grad():
    latents = latents.to(vae.dtype).to(vae.device)
    latents = 1 / vae.config.scaling_factor * latents
    frames = vae.decode(latents).sample
    frames = (frames / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return frames


if __name__ == "__main__":
    device = "cuda"
    weight_dtype = torch.float32
    pretrained_model_name_or_path = "pretrain_weights/CogVideoX-Fun-V1.5-5b-InP"
    clip_model = "pretrain_weights/clip/pretrain_weights/ViT-B-32.pt"
    vgg_model_path = "pretrain_weights/vgg/vgg19-dcbb9e9d.pth"
    batch = 1
    nf = 5
    h = 272
    w = 480
    vae = load_vae(pretrained_model_name_or_path).to(device).to(weight_dtype)
    vae.enable_gradient_checkpointing()
    loss_computer = LossComputer(
        device=device,
        clip_model=clip_model,
        vgg_model_path=vgg_model_path,
        content_mse_noise=False,
        content_mse=False,
        content_contrastive=False,
        style_clip_direction_global=False,
        style_clip_direction_patch=False,
        style_clip_align_global=True,
        time_continuous=False,
    )
    # Random latents with shape [batch, channels, latent frames, h // 8, w // 8]
    latents = torch.randn([batch, 16, ((nf - 1) // 4) + 1, h // 8, w // 8]).to(device).to(weight_dtype).requires_grad_(True)
    gen_pixel_values = decode_latents(latents, vae)
    # gen_pixel_values = simple_model(latents)
    tgt_text = ["driving in the daytime"] * batch
    loss_total = loss_computer(
        gen_pixel_values=gen_pixel_values,
        og_pixel_values=None,
        tgt_text=tgt_text,
        src_text=None,
        tgt_noise=None,
        pred_noise=None,
    )
    loss_total.backward()`
I suspect that when gradient checkpointing is enabled, the "conv cache" in the VAE gets deleted. Can anyone help?
Information
- [ ] The official example scripts
- [x] My own modified scripts
Reproduction
As above.
Expected behavior
The bug that occurs when backpropagating gradients through the VAE should be fixed.
Have you solved this problem? When I decoded latents back to pixel space for loss, not only did the video memory usage become very large, but this error also occurred.
No. I think the error may be related to the "conv cache" in the VAE.
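To make the "conv cache" suspicion concrete, here is a toy, pure-Python sketch of how a layer that keeps internal state could produce a different temporal length when its forward pass is run a second time, which is what gradient checkpointing does during backward. `CachedTemporalLayer` and `run_with_recompute` are hypothetical names made up for this illustration. This is not the diffusers implementation, only one possible mechanism that would be consistent with the 1 vs 2 mismatch in the error above.

```python
class CachedTemporalLayer:
    """Hypothetical stand-in for a temporal layer that caches trailing frames."""

    def __init__(self):
        self.cache = []  # frames remembered from the previous forward call

    def forward(self, frames):
        out = self.cache + list(frames)  # prepend whatever is in the cache
        self.cache = list(frames[-1:])   # remember the last frame for the next call
        return out


def run_with_recompute(layer, frames):
    """Mimic checkpointing: one real forward, then a second forward at backward time."""
    first = layer.forward(frames)       # original forward pass (cache is empty)
    recomputed = layer.forward(frames)  # recomputation (cache is now populated)
    return len(first), len(recomputed)


lengths = run_with_recompute(CachedTemporalLayer(), ["f0"])
print(lengths)  # (1, 2)
```

In this toy, the recomputed output has one more frame than the original because the cache was mutated by the first pass. A layer that receives its cache as an argument and returns the updated cache would not diverge this way. Whether that is the actual cause here is not confirmed in this thread.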
This problem can be solved by using the VAE implementation in the latest diffusers code.
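Before retrying, it may help to check which diffusers build is actually installed in the environment. The sketch below uses only the standard library. `installed_version` is a hypothetical helper name, and the thread does not say which release contains the updated VAE, so no version number is assumed here.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints None if diffusers is not installed in the current environment.
print("diffusers:", installed_version("diffusers"))
```

If the reported version is old, upgrading with `pip install -U diffusers` or installing from the project's source repository are the usual ways to pick up newer code.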