diffusers
Training process fails with a Jax library related issue
Describe the bug
Training process fails with a Jax library related issue.
This is the Python code in the notebook cell that fails:
!python3 train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--seed=1337 \
--resolution=512 \
--train_batch_size=1 \
--train_text_encoder \
--mixed_precision="fp16" \
--use_8bit_adam \
--gradient_accumulation_steps=1 \
--learning_rate=1e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=50 \
--sample_batch_size=4 \
--max_train_steps=800 \
--save_interval=10000 \
--save_sample_prompt="photo of narrow gate" \
--concepts_list="concepts_list.json"
Attached is a screenshot of the error.
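The command reads its training concepts from concepts_list.json, which the report does not show. As a hedged sketch, assuming the JSON layout used by the ShivamShrirao/diffusers fork (a list of objects with instance_prompt, class_prompt, instance_data_dir and class_data_dir keys), the file could be generated as follows. Every prompt and directory here is a hypothetical placeholder, and write_concepts_list is an illustrative helper, not part of the script:

```python
import json
from pathlib import Path

# Hypothetical concept: prompts and directories are placeholders, not values
# from the original report. Key names assume the ShivamShrirao/diffusers format.
concepts = [
    {
        "instance_prompt": "photo of narrow gate",
        "class_prompt": "photo of a gate",
        "instance_data_dir": "/content/data/narrow_gate",
        "class_data_dir": "/content/data/gate",
    },
]


def write_concepts_list(path, concepts):
    """Serialize the concept list to the JSON file passed via --concepts_list."""
    path = Path(path)
    path.write_text(json.dumps(concepts, indent=4))
    return path


written = write_concepts_list("concepts_list.json", concepts)
print(f"Wrote {len(concepts)} concept(s) to {written}")
```

Passing --concepts_list="concepts_list.json" then points train_dreambooth.py at this file.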
Reproduction
Run the training process with the train_dreambooth.py command shown above.
Logs
No response
System Info
I am running this on a Google Colab runtime (Python 3, Google Compute Engine backend) with a Tesla GPU.
Install details:
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py
%pip install -qq git+https://github.com/ShivamShrirao/diffusers
%pip install -q -U --pre triton
%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort safetensors xformers
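For dependency breakage like this, the System Info section is more useful with exact package versions. A minimal sketch using only the standard library's importlib.metadata; the package list is an assumption drawn from the install commands and the errors in this thread, and installed_versions is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages relevant to this report (assumed list; extend as needed).
PACKAGES = [
    "jax", "jaxlib", "torch", "torchvision", "diffusers",
    "transformers", "accelerate", "bitsandbytes", "xformers", "triton",
]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:>14}: {ver or 'not installed'}")
```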
Any update? I'm facing the same issue.
This seems to work:
!pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Google always ends up breaking something with each update; you need to pin a specific version:
!pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html
This solves the problem for now.
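Because this workaround only helps if pip actually installs the pinned pair, it can be worth confirming the installed versions before rerunning train_dreambooth.py. A hedged sketch: pin_mismatches is a hypothetical helper, and it ignores local version suffixes such as +cuda12 on the assumption that only the public version number matters for the pin:

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the workaround in this thread.
PINS = {"jax": "0.4.19", "jaxlib": "0.4.19"}


def pin_mismatches(pins, lookup=version):
    """Return {package: installed_version_or_None} for packages off their pin."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = lookup(name)
        except PackageNotFoundError:
            found = None
        # Compare only the public part, e.g. "0.4.19" from "0.4.19+cuda12".
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = found
    return mismatches


problems = pin_mismatches(PINS)
if problems:
    print("Not at the pinned version:", problems)
else:
    print("jax and jaxlib match the pins")
```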
Indeed,
!pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html
solves this issue, but then another error appears, RuntimeError: operator torchvision::nms does not exist:
Traceback (most recent call last):
File "/content/train_dreambooth.py", line 26, in <module>
from torchvision import transforms
File "/usr/local/lib/python3.10/dist-packages/torchvision/__init__.py", line 6, in <module>
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
File "/usr/local/lib/python3.10/dist-packages/torchvision/_meta_registrations.py", line 164, in <module>
def meta_nms(dets, scores, iou_threshold):
File "/usr/local/lib/python3.10/dist-packages/torch/library.py", line 467, in inner
handle = entry.abstract_impl.register(func_to_register, source)
File "/usr/local/lib/python3.10/dist-packages/torch/_library/abstract_impl.py", line 30, in register
if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
RuntimeError: operator torchvision::nms does not exist
Probably the PyTorch version needs to be pinned too, but which version?
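A common cause of "RuntimeError: operator torchvision::nms does not exist" is a torchvision build that does not match the installed torch, for example when another install pulls in a newer torch (xformers depends on torch) without a matching torchvision. For torch 2.x, matching torchvision releases follow the pattern 0.(minor + 15), so torch 2.1 pairs with torchvision 0.16. The sketch encodes only that rule of thumb; the function names are hypothetical, and the official torchvision compatibility table remains the authority:

```python
import re


def minor_pair(ver):
    """Extract (major, minor) from a version string such as '2.1.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", ver)
    if match is None:
        raise ValueError(f"unrecognized version: {ver!r}")
    return int(match.group(1)), int(match.group(2))


def expected_torchvision_minor(torch_version):
    """Rule of thumb for torch 2.x: torchvision 0.(minor + 15)."""
    major, minor = minor_pair(torch_version)
    if major != 2:
        raise ValueError("this sketch only covers torch 2.x releases")
    return (0, minor + 15)


def looks_compatible(torch_version, torchvision_version):
    """True if torchvision's (major, minor) matches the torch 2.x rule."""
    return minor_pair(torchvision_version) == expected_torchvision_minor(torch_version)


print(looks_compatible("2.1.0+cu121", "0.16.0+cu121"))
print(looks_compatible("2.2.1+cu121", "0.16.0+cu121"))
```

In the notebook, passing torch.__version__ and torchvision.__version__ to looks_compatible checks the installed pair.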
I have the same problem. Did anyone find a solution?
@roman19932024 Try updating Python to 3.10.
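Before changing the runtime, it may help to confirm which interpreter the notebook actually uses; the traceback above already shows /usr/local/lib/python3.10 paths. A small standard-library check, where python_at_least is a hypothetical helper:

```python
import sys


def python_at_least(required=(3, 10), info=sys.version_info):
    """True if the (major, minor) of `info` is at least `required`."""
    return tuple(info[:2]) >= tuple(required)


print(sys.executable, ".".join(map(str, sys.version_info[:3])))
print("Python >= 3.10:", python_at_least())
```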