Training process fails with a Jax library related issue

Open randheerDas opened this issue 10 months ago • 6 comments

Describe the bug

The training process fails with a JAX-related error.

This is the command in the notebook cell that fails:

!python3 train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=50 \
  --sample_batch_size=4 \
  --max_train_steps=800 \
  --save_interval=10000 \
  --save_sample_prompt="photo of narrow gate" \
  --concepts_list="concepts_list.json"

Attached is a screenshot of the error:

[Screenshot: Error]

Reproduction

Run the training process by issuing the following command:

!python3 train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=50 \
  --sample_batch_size=4 \
  --max_train_steps=800 \
  --save_interval=10000 \
  --save_sample_prompt="photo of narrow gate" \
  --concepts_list="concepts_list.json"

Logs

No response

System Info

I am running this on a Google Colab runtime (Python 3) backed by Google Compute Engine with a Tesla GPU.

Install details:

!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py
%pip install -qq git+https://github.com/ShivamShrirao/diffusers
%pip install -q -U --pre triton
%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort safetensors xformers

randheerDas • Apr 12 '24 01:04

Any update? I'm facing the same issue.

mahaboobkhan29 • Apr 16 '24 13:04

This seems to work:

!pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

The-Ramosian • Apr 21 '24 19:04

Google always ends up breaking something with each update, so you need to pin a specific version:

!pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html

This solves the problem for now.
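
After re-running the install cell, it may help to confirm that the runtime actually picked up the pinned versions. The following is only a sketch, not something posted in this thread: the helper names (installed_version, find_mismatches) are made up for illustration, and it assumes that an installed build may carry a local tag such as "+cuda12", which is stripped before comparing.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Return {package: (expected, found)} for every pin that is not satisfied."""
    mismatches = {}
    for package, expected in pins.items():
        found = lookup(package)
        # Drop a local build tag such as "+cuda12" before comparing (assumption).
        public = found.split("+", 1)[0] if found else None
        if public != expected:
            mismatches[package] = (expected, found)
    return mismatches


# Versions taken from the pip command above.
PINS = {"jax": "0.4.19", "jaxlib": "0.4.19"}

for package, (expected, found) in find_mismatches(PINS).items():
    print(f"{package}: expected {expected}, found {found or 'not installed'}")
```

If nothing is printed, both packages match the pin. In Colab the runtime usually needs a restart before a newly installed version is the one that gets imported.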

JossCamp • Apr 22 '24 21:04

Indeed, !pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html solves this issue, but then another error appears: RuntimeError: operator torchvision::nms does not exist:

Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 26, in <module>
    from torchvision import transforms
  File "/usr/local/lib/python3.10/dist-packages/torchvision/__init__.py", line 6, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
  File "/usr/local/lib/python3.10/dist-packages/torchvision/_meta_registrations.py", line 164, in <module>
    def meta_nms(dets, scores, iou_threshold):
  File "/usr/local/lib/python3.10/dist-packages/torch/library.py", line 467, in inner
    handle = entry.abstract_impl.register(func_to_register, source)
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/abstract_impl.py", line 30, in register
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
RuntimeError: operator torchvision::nms does not exist

Probably the PyTorch version should be pinned too, but which version?
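
A common cause of this kind of error, though not confirmed in this thread, is a torchvision build that does not match the installed torch (for example, when another package pulls in a newer torch and leaves torchvision behind). The sketch below only reports whether the two installed versions form a known pairing. The pairing table is an assumption copied from the compatibility list in the torchvision README and should be checked against the current README. The helper names are hypothetical.

```python
from importlib import metadata

# torch -> torchvision release pairings (assumption: verify against the torchvision README).
TORCH_TO_TORCHVISION = {
    "2.0": "0.15",
    "2.1": "0.16",
    "2.2": "0.17",
    "2.3": "0.18",
}


def major_minor(version):
    """Reduce a version such as '2.2.1+cu121' to '2.2'."""
    public = version.split("+", 1)[0]
    return ".".join(public.split(".")[:2])


def pair_status(torch_version, torchvision_version):
    """Return 'ok', 'mismatch', or 'unknown' (torch release not in the table)."""
    expected = TORCH_TO_TORCHVISION.get(major_minor(torch_version))
    if expected is None:
        return "unknown"
    return "ok" if major_minor(torchvision_version) == expected else "mismatch"


def report():
    """Describe the torch/torchvision pair installed in this environment."""
    try:
        torch_v = metadata.version("torch")
        vision_v = metadata.version("torchvision")
    except metadata.PackageNotFoundError as missing:
        return f"not installed: {missing}"
    return f"torch {torch_v}, torchvision {vision_v}: {pair_status(torch_v, vision_v)}"


print(report())
```

A "mismatch" result would suggest reinstalling torch and torchvision together so that pip resolves a matching pair. It does not by itself say which version the training script needs.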

mirodil-ml • May 09 '24 07:05

I have the same problem. Did anyone find a solution?

roman19932024 • May 18 '24 11:05

@roman19932024 try updating your Python version to 3.10.

mirodil-ml • May 20 '24 16:05