
How can I fix tensor shape mismatches when training with DeepSpeed zero_stage 3?

Open RockyLQ1 opened this issue 5 months ago • 16 comments

The same project trains without any shape mismatch under DeepSpeed zero_stage 2, but runs out of GPU memory (OOM). My guess is that zero_stage 3's parameter partitioning changes tensor shapes. Looking for a solution.

RockyLQ1 avatar Jul 14 '25 12:07 RockyLQ1

class CausalConv3d(nn.Conv3d) raises an error:

[screenshot of the error]

RockyLQ1 avatar Jul 14 '25 12:07 RockyLQ1

@Artiprocher looking forward to your answer!

RockyLQ1 avatar Jul 14 '25 12:07 RockyLQ1

Same problem here, watching this thread for a solution.

mrj-taffy avatar Jul 17 '25 10:07 mrj-taffy

I ran into the same problem at the DeepSpeed ZeRO-3 stage. With ZeRO-2 alone I can only fit 17 frames.

zhangquanwei962 avatar Jul 23 '25 06:07 zhangquanwei962

Any solution to this?

pooyafayyazs avatar Jul 29 '25 05:07 pooyafayyazs

Same question. On a 2× A6000 setup, DeepSpeed ZeRO-3 fails at DiffSynth-Studio/diffsynth/models/wan_video_vae.py line 58: x = torch.cat([cache_x, x], dim=2) with RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 640 but got size 160 for tensor number 1 in the list.

VioletOrz avatar Aug 05 '25 02:08 VioletOrz

I ran into the same problem.

[rank0]: File "/workspace/DiffSynth-Studio/diffsynth/models/wan_video_vae.py", line 48, in forward
[rank0]:   x = torch.cat([cache_x, x], dim=2)
[rank0]: RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 384 but got size 96 for tensor number 1 in the list.

WANGCHAO0116 avatar Aug 12 '25 02:08 WANGCHAO0116
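(Editor's note: for readers unfamiliar with the failure mode, here is a minimal standalone illustration, not DiffSynth-Studio code, of why a ZeRO-3-sharded cache tensor breaks this torch.cat. The shapes are made up to mirror the 640-vs-160 error above.)

```python
import torch

# The causal-conv cache concatenates cached frames with the current
# chunk along dim 2 (time). That only works if every other dim matches.
# If ZeRO-3 partitioning leaves one side with a sharded channel dim
# (e.g. 640 -> 160 across 4 ranks), torch.cat raises exactly the
# RuntimeError seen in this thread.

def concat_with_cache(cache_x: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Concatenate a temporal cache with the current chunk along dim 2."""
    return torch.cat([cache_x, x], dim=2)

# Matching channel dims (N, C, T, H, W): works.
cache_x = torch.zeros(1, 640, 2, 8, 8)
x = torch.zeros(1, 640, 5, 8, 8)
out = concat_with_cache(cache_x, x)
assert out.shape == (1, 640, 7, 8, 8)

# Mismatched channel dims (what a sharded tensor looks like): fails.
bad_cache = torch.zeros(1, 160, 2, 8, 8)
try:
    concat_with_cache(bad_cache, x)
except RuntimeError as e:
    print(e)  # e.g. "Sizes of tensors must match except in dimension 2 ..."
```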

Watching; hit the same problem.

DuanCB avatar Aug 15 '25 08:08 DuanCB


I ran into the same problem.

polestarss avatar Sep 23 '25 07:09 polestarss

Watching; same issue. ZeRO-2 runs fine, ZeRO-3 breaks.

Vickeyhw avatar Sep 24 '25 09:09 Vickeyhw

The root cause is that the VAE's feat_cache conflicts with ZeRO-3. But feat_cache is created in too many scattered places, with all kinds of shape manipulation on top, so it is very hard to patch directly...

My workaround (not elegant):

After loading the model, move the VAE out into the outer WanTrainingModule and set pipe.vae to None. Then pass model.pipe to accelerate.prepare instead of model. You also need to modify every PipeUnit that uses the VAE so that the VAE is passed in as an argument, for example:

[screenshot]

The VAE is then successfully excluded from ZeRO-3!

Note: remember to manually move the VAE to the correct device.

MaxwellDing avatar Sep 24 '25 09:09 MaxwellDing
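(Editor's note: the restructuring above can be sketched roughly as follows. All names, WanTrainingModule, pipe, vae, mirror the description but are illustrative; this is not the actual DiffSynth-Studio implementation.)

```python
import torch
import torch.nn as nn

# Sketch: hold the VAE outside the object that accelerator.prepare()
# wraps, so ZeRO-3 never partitions its parameters. Storing it in a
# plain list hides it from nn.Module registration (.parameters(),
# .named_parameters(), state_dict), which is the point here.

class WanTrainingModule(nn.Module):
    def __init__(self, pipe: nn.Module, vae: nn.Module):
        super().__init__()
        self.pipe = pipe    # this part goes through accelerator.prepare()
        self.vae = [vae]    # plain list: excluded from module registration

    def encode(self, video: torch.Tensor) -> torch.Tensor:
        vae = self.vae[0]
        with torch.no_grad():  # VAE stays frozen; no grads needed
            return vae.encode(video)

# Usage sketch (accelerate-specific lines left as comments):
# module = WanTrainingModule(pipe, pipe.vae)
# pipe.vae = None                          # detach the VAE from the pipe
# pipe = accelerator.prepare(module.pipe)  # prepare the pipe, not the module
# module.vae[0].to(accelerator.device)     # move the VAE manually
```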

Hi @MaxwellDing Thank you for sharing your solution. I followed it with zero_3 and use_gradient_checkpointing enabled, and encountered the following error. Do you have any insights to share? Thank you very much!

[rank0]: raise CheckpointError(
[rank0]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank0]: tensor at position 13:
[rank0]:   saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]:   recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}

(The same saved torch.Size([1536]) vs. recomputed torch.Size([0]) mismatch is reported for the tensors at positions 23, 52, 62, 91, 101, 130, and 140; log truncated.)

zoezhou1999 avatar Sep 26 '25 20:09 zoezhou1999
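(Editor's note: a hedged sketch, not a confirmed fix for this thread. A recomputed shape of torch.Size([0]) is typically what a ZeRO-3 partitioned parameter looks like when it was not re-gathered during the checkpoint recompute. One commonly suggested mitigation is switching to PyTorch's non-reentrant checkpointing variant; whether it helps here depends on the torch/DeepSpeed versions in use.)

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Non-reentrant activation checkpointing: use_reentrant=False uses a
# different recompute mechanism than the legacy reentrant path, and is
# the variant PyTorch now recommends. Shown here on a toy block.
block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))

x = torch.randn(2, 16, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()  # gradients flow through the checkpointed block
```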

@zoezhou1999 Hi there, I haven't encountered that error. However, if the VAE is set to be trainable, my workaround might not work.

MaxwellDing avatar Sep 28 '25 06:09 MaxwellDing

(Quoting MaxwellDing's workaround above: the VAE's feat_cache conflicts with ZeRO-3, so move the VAE out of the pipe into WanTrainingModule, prepare model.pipe instead of model, and pass the VAE into each PipeUnit as an argument.)

On that note, for full training, why does TI2V only train the DiT in the middle? The text encoder and VAE don't seem to be trained at all. I don't quite understand this:

[screenshot]

xuxiaoxxxx avatar Oct 29 '25 10:10 xuxiaoxxxx

(Quoting the question above about why TI2V full training only trains the DiT.)

Generally speaking, SFT trains only the DiT. See the paper for the details of the training recipe.

MaxwellDing avatar Oct 31 '25 02:10 MaxwellDing
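(Editor's note: the "SFT trains only the DiT" setup can be sketched as follows. The module names dit, text_encoder, and vae are hypothetical placeholders, not the actual DiffSynth-Studio attributes.)

```python
import torch.nn as nn

def freeze_all_but_dit(dit: nn.Module, text_encoder: nn.Module,
                       vae: nn.Module) -> list:
    """Freeze the text encoder and VAE; return trainable DiT params."""
    for m in (text_encoder, vae):
        m.requires_grad_(False)
        m.eval()  # also disable dropout / norm statistics updates
    dit.requires_grad_(True)
    # Only DiT parameters go to the optimizer.
    return [p for p in dit.parameters() if p.requires_grad]
```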