
AttributeError: module 'jax.random' has no attribute 'KeyArray' while running DreamBooth_Stable_Diffusion.ipynb

Open · mirodil-ml opened this issue 9 months ago · 1 comment

Describe the bug

Running train_dreambooth.py on Google Colab raises the following error:

AttributeError: module 'jax.random' has no attribute 'KeyArray'
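The error indicates that the installed JAX no longer provides the jax.random.KeyArray alias. As far as we know, JAX deprecated it in 0.4.16 (recommending jax.Array for key annotations) and removed it in 0.4.24. Those release numbers are an assumption here, so check them against the JAX changelog. This sketch, whose helpers parse_version and keyarray_removed are hypothetical, tests whether a JAX version string falls on the removed side of that assumed cutoff:

```python
import re

# Assumption: first JAX release without jax.random.KeyArray (verify in the JAX changelog).
KEYARRAY_REMOVED_IN = (0, 4, 24)


def parse_version(version: str) -> tuple[int, ...]:
    """Return the leading numeric release segment, e.g. "0.4.26" -> (0, 4, 26).

    Pre-release or dev suffixes such as "rc1" or ".dev2024" are ignored (a simplification).
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def keyarray_removed(jax_version: str) -> bool:
    """True if jax_version is at or past the assumed removal release."""
    return parse_version(jax_version) >= KEYARRAY_REMOVED_IN
```

For example, keyarray_removed("0.4.23") returns False and keyarray_removed("0.4.26") returns True. Comparing tuples of integers avoids the string-comparison trap where "0.4.9" sorts after "0.4.24".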

Reproduction

  • Navigate to https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth
  • Open the Google Colab notebook linked there and save a copy
  • Run it; the error appears at the !python3 train_dreambooth.py .... step
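Because the failure depends on which JAX release the runtime provides, it can help to record exact package versions before running the training step. This minimal sketch uses only importlib.metadata. The helper name installed_versions and the package list are hypothetical. It reads installed metadata without importing anything, so it still works when import diffusers fails:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages whose versions matter for this error (the list is a suggestion).
DEFAULT_PACKAGES = ("jax", "jaxlib", "flax", "diffusers")


def installed_versions(packages=DEFAULT_PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if missing.

    Only package metadata is read; none of the packages is imported.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```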

Logs

Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 21, in <module>
    from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 36, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 33, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 45, in <module>
    class FlaxModelMixin:
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 192, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
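The traceback shows that the error is raised while Python executes the class body of FlaxModelMixin. On Python 3.10, the annotation rng: jax.random.KeyArray on init_weights is evaluated as soon as the def statement runs. diffusers/models/__init__.py imports controlnet_flax, apparently because Flax is installed in the runtime. Importing diffusers therefore fails even though train_dreambooth.py only imports the PyTorch classes.

This sketch reproduces the mechanism without JAX. It uses fake_random, a hypothetical stand-in module that simply lacks KeyArray. Quoting the annotation keeps it unevaluated. Python 3.14 defers annotation evaluation by default, so the eager case may not raise there.

```python
import types

# Hypothetical stand-in for a jax.random module that no longer defines KeyArray.
fake_random = types.ModuleType("fake_random")


def define_with_eager_annotation():
    """Mirrors modeling_flax_utils.py: before Python 3.14 the annotation is evaluated when `def` runs."""
    class FlaxModelMixin:
        def init_weights(self, rng: fake_random.KeyArray) -> dict:
            return {}
    return FlaxModelMixin


def define_with_string_annotation():
    """A quoted annotation is stored as a string and not evaluated at definition time."""
    class FlaxModelMixin:
        def init_weights(self, rng: "fake_random.KeyArray") -> dict:
            return {}
    return FlaxModelMixin


if __name__ == "__main__":
    try:
        define_with_eager_annotation()
        print("eager: no error (this Python defers annotation evaluation)")
    except AttributeError as err:
        print("eager:", err)
    print("quoted:", define_with_string_annotation().__name__, "defined")
```

Quoting only postpones the lookup. Annotating with an attribute that actually exists, such as jax.Array, removes the problem.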


System Info

Latest Google Colab
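Colab ships a recent JAX, so one workaround is to patch the installed diffusers sources to annotate with jax.Array. That is the type JAX's deprecation guidance suggests for key arrays. This is a sketch and not an upstream fix. patch_keyarray_tree is a hypothetical helper that rewrites every .py file under a directory, because other Flax modules in the same package may reference jax.random.KeyArray too. Two alternatives are pinning a JAX release older than the removal, or installing a diffusers release that no longer uses the alias. Verify either option against the respective release notes.

```python
from pathlib import Path

OLD = "jax.random.KeyArray"
NEW = "jax.Array"  # JAX's suggested annotation in place of KeyArray


def patch_keyarray_tree(root) -> dict[str, int]:
    """Replace jax.random.KeyArray with jax.Array in every .py file under root.

    Returns a mapping of patched file path -> number of replacements.
    Files that do not mention the old alias are left untouched.
    """
    patched = {}
    for path in sorted(Path(root).rglob("*.py")):
        source = path.read_text(encoding="utf-8")
        count = source.count(OLD)
        if count:
            path.write_text(source.replace(OLD, NEW), encoding="utf-8")
            patched[str(path)] = count
    return patched
```

In the notebook, importlib.util.find_spec("diffusers").submodule_search_locations gives the package directory without importing diffusers. Running the patch on that directory in a cell before the !python3 train_dreambooth.py step should let the training subprocess import the patched files.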

mirodil-ml, May 08 '24 00:05