OneTrainer
[Bug]: SDXL DoRA training fails to start when base model weight data type is set to float8
What happened?
When I enable Decomposed Weights (DoRA) training in the "LoRA" tab while the base SDXL model is loaded as float8 in the "Model" tab, training fails to start with RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half.
With the DoRA toggle disabled, training starts as usual.
What did you expect would happen?
DoRA training to start and run normally with a float8 base model.
Relevant log output
epoch: 0%| | 0/100 [00:26<?, ?it/s]
Traceback (most recent call last):
File "/home/user/apps/OneTrainer/scripts/train.py", line 38, in <module>
main()
File "/home/user/apps/OneTrainer/scripts/train.py", line 29, in main
trainer.train()
File "/home/user/apps/OneTrainer/modules/trainer/GenericTrainer.py", line 575, in train
model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/modules/modelSetup/BaseStableDiffusionXLSetup.py", line 467, in predict
predicted_latent_noise = model.unet(
^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1135, in forward
emb = self.time_embedding(t_emb, timestep_cond)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/embeddings.py", line 376, in forward
sample = self.linear_1(sample)
^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/apps/OneTrainer/modules/module/LoRAModule.py", line 374, in forward
WP = self.orig_module.weight + (self.make_weight(A, B) * (self.alpha / self.rank))
~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half
Output of pip freeze
No response