[Bug]: SDXL DoRA training fails to start when base model weight data type is set to float8

Open ikarsokolov opened this issue 6 months ago • 4 comments

What happened?

When I enable Decomposed Weights (DoRA) training in the "LoRA" tab while the base SDXL model is loaded as float8 in the "model" tab, the training process fails to start with RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half.

If the DoRA toggle is deactivated, training starts as usual.

What did you expect would happen?

DoRA training should start and run with a float8 base model, the same way training does when the DoRA toggle is off.

Relevant log output

epoch:   0%|                                                                                                                                                          | 0/100 [00:26<?, ?it/s]
Traceback (most recent call last):                                                                                                                                                            
  File "/home/user/apps/OneTrainer/scripts/train.py", line 38, in <module>                                                                                                               
    main()                                                                                                                                                                                    
  File "/home/user/apps/OneTrainer/scripts/train.py", line 29, in main                                                                                                                   
    trainer.train()                                                                                                                                                                           
  File "/home/user/apps/OneTrainer/modules/trainer/GenericTrainer.py", line 575, in train                                                                                                
    model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)                                                                                              
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                              
  File "/home/user/apps/OneTrainer/modules/modelSetup/BaseStableDiffusionXLSetup.py", line 467, in predict                                                                               
    predicted_latent_noise = model.unet(                                                                                                                                                      
                             ^^^^^^^^^^^                                                                                                                                                      
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                       
    return self._call_impl(*args, **kwargs)                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                   
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                               
    return forward_call(*args, **kwargs)                                                                                                                                                      
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                      
  File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1135, in forward                                                            
    emb = self.time_embedding(t_emb, timestep_cond)                                                                                                                                           
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                           
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                       
    return self._call_impl(*args, **kwargs)                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                   
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                               
    return forward_call(*args, **kwargs)                                                                                                                                                      
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                      
  File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/embeddings.py", line 376, in forward                                                                          
    sample = self.linear_1(sample)                                                                                                                                                            
             ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                            
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                       
    return self._call_impl(*args, **kwargs)                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                   
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/apps/OneTrainer/modules/module/LoRAModule.py", line 374, in forward
    WP = self.orig_module.weight + (self.make_weight(A, B) * (self.alpha / self.rank))
         ~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half

Output of pip freeze

No response

ikarsokolov, Aug 21 '24 03:08