sd-scripts
SD3.5 fine-tuning crashes due to a dtype mismatch in the VAE call when latents are not cached (potentially DDP-related?)
[rank1]: Traceback (most recent call last):
[rank1]: File "/dockercontainer/sd-scripts/sd3_train.py", line 1200, in <module>
[rank1]: train(args)
[rank1]: File "/dockercontainer/sd-scripts/sd3_train.py", line 871, in train
[rank1]: latents = vae.encode(batch["images"])
[rank1]: File "/dockercontainer/sd-scripts/library/sd3_models.py", line 1435, in encode
[rank1]: hidden = self.encoder(image)
[rank1]: File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/dockercontainer/sd-scripts/library/sd3_models.py", line 1333, in forward
[rank1]: hs = [self.conv_in(x)]
[rank1]: File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 554, in forward
[rank1]: return self._conv_forward(input, self.weight, self.bias)
[rank1]: File "/dockercontainer/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
[rank1]: return F.conv2d(
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
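The error means the image batch reached the VAE's `conv_in` as float32 while the layer's weight and bias were bfloat16. PyTorch convolutions require the input and parameters to share one dtype. Below is a minimal standalone reproduction, with one generic way to avoid it: cast the input to the module's parameter dtype before calling it. This is a sketch only. `encode_in_module_dtype` is a hypothetical helper, the single `nn.Conv2d` stands in for the real VAE encoder, and this is not necessarily how the fix in sd-scripts was done.

```python
import torch
import torch.nn as nn


def encode_in_module_dtype(module: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast `images` to the module's parameter dtype/device
    before the forward call, so float32 batches work with a bfloat16 model."""
    param = next(module.parameters())
    return module(images.to(device=param.device, dtype=param.dtype))


# Stand-in for the VAE encoder's first layer (conv_in), held in bfloat16.
conv_in = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(torch.bfloat16)

# Image batch as a dataloader typically yields it: float32.
images = torch.randn(2, 3, 16, 16)

# Calling the bfloat16 layer directly on float32 input raises a RuntimeError
# complaining that the input and weight/bias types should be the same.
try:
    conv_in(images)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Casting first makes input and parameters agree, so the call succeeds.
out = encode_in_module_dtype(conv_in, images)
```

An alternative is to run the forward pass under an autocast context, but an explicit cast keeps the dtype visible at the call site.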
This also happens with T5 dropout, I think.
Fixed the VAE issue. I can't reproduce the issue with T5XXL, so please let me know how to reproduce it.
@kohya-ss The T5XXL dtype mismatch happens with batch size > 1 and T5XXL dropout enabled, and maybe with DDP too, though that's a stretch. I'll try to get error logs later today.
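The thread doesn't show where the T5XXL mismatch comes from. One plausible mechanism, which is an assumption and not confirmed here, is per-sample dropout that replaces a dropped sample's embedding with a freshly created `torch.zeros(...)` tensor. That tensor is float32 by default, and concatenating it with bfloat16 rows promotes the whole batch to float32. The sketch below shows that pitfall and a dtype-preserving alternative. `drop_text_embeddings` and its arguments are hypothetical names, not sd-scripts code.

```python
import torch


def drop_text_embeddings(embeds: torch.Tensor, drop_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-sample dropout for text-encoder outputs.

    embeds:    (batch, seq_len, dim) tensor, e.g. bfloat16 T5 output.
    drop_mask: (batch,) bool tensor; True means "drop this sample".
    zeros_like keeps dropped rows in embeds.dtype, so the batch stays uniform.
    """
    # Reshape the mask to (batch, 1, 1, ...) so it broadcasts over the other dims.
    mask = drop_mask.view(-1, *([1] * (embeds.dim() - 1)))
    return torch.where(mask, torch.zeros_like(embeds), embeds)


embeds = torch.randn(4, 5, 8, dtype=torch.bfloat16)
drop_mask = torch.tensor([True, False, False, True])
out = drop_text_embeddings(embeds, drop_mask)

# Pitfall: torch.zeros defaults to float32, and torch.cat type-promotes,
# so mixing it with bfloat16 rows yields a float32 batch.
mixed = torch.cat([torch.zeros(1, 5, 8), embeds[1:2]], dim=0)
```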
Hi, can you tell me how you ran the code and how you organized the training data? I'm not able to understand it.