diffusers
AttributeError: module 'jax.random' has no attribute 'KeyArray' while running DreamBooth_Stable_Diffusion.ipynb
### Describe the bug
Running `train_dreambooth.py` on Google Colab throws the following error:

```
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
### Reproduction
- Navigate to https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth
- Open the Google Colab notebook (`DreamBooth_Stable_Diffusion.ipynb`) from that page and make a copy
- Run the notebook; the error appears at the `!python3 train_dreambooth.py ....` step
### Logs
```
Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 21, in <module>
    from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 36, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 33, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 45, in <module>
    class FlaxModelMixin:
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 192, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
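The traceback shows the failure happens at import time, not during training: `from diffusers import ...` pulls in `controlnet_flax.py`, whose base class `FlaxModelMixin` annotates `init_weights` with `jax.random.KeyArray`, and the installed JAX no longer provides that attribute. So the script fails even though it only imports the non-Flax classes. Below is a minimal, stdlib-only sketch of that mechanism; `define_mixin`, `try_import`, and the `SimpleNamespace` stand-ins for an older and a newer `jax.random` are hypothetical, and JAX is not needed to run it.

```python
import types


def define_mixin(random_module):
    """Define a class whose method annotation references random_module.KeyArray.

    Mirrors the pattern seen in the traceback:
    `def init_weights(self, rng: jax.random.KeyArray) -> Dict`.
    """

    class FlaxModelMixinSketch:
        def init_weights(self, rng: random_module.KeyArray) -> dict:
            return {}

    # On Python 3.10 (the version in the traceback) the annotation is evaluated
    # as soon as `def init_weights` runs. Reading __annotations__ forces the
    # evaluation on interpreters that defer it, so the failure surfaces either way.
    _ = FlaxModelMixinSketch.init_weights.__annotations__
    return FlaxModelMixinSketch


def try_import(random_module):
    """Return None if the class can be defined, else the AttributeError message."""
    try:
        define_mixin(random_module)
        return None
    except AttributeError as exc:
        return str(exc)


# Hypothetical stand-ins: an "old" namespace that still has KeyArray, a "new" one without it.
old_random = types.SimpleNamespace(KeyArray=object)
new_random = types.SimpleNamespace()

print(try_import(old_random))  # None
print(try_import(new_random))  # an AttributeError message mentioning 'KeyArray'
```

The same thing happens inside `modeling_flax_utils.py`: merely defining the class evaluates `jax.random.KeyArray`, so any `import diffusers` that reaches the Flax models fails before training code runs.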
### System Info
Latest Google Colab
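One possible workaround, assuming the installed JAX provides `jax.Array` and that diffusers uses `jax.random.KeyArray` only as a type annotation, is to restore the alias before diffusers is imported. The helper `ensure_keyarray_alias` below is hypothetical, and the demo uses a `SimpleNamespace` and a stand-in class in place of `jax.random` and `jax.Array` so it runs without JAX.

```python
import types


def ensure_keyarray_alias(random_module, array_type):
    """Add a `KeyArray` alias to `random_module` if it is missing (hypothetical helper).

    Intended use: ensure_keyarray_alias(jax.random, jax.Array), called before
    any `import diffusers`. Returns True if the alias was added, False if an
    attribute named KeyArray already existed.
    """
    if hasattr(random_module, "KeyArray"):
        return False
    setattr(random_module, "KeyArray", array_type)
    return True


class ArrayStandIn:
    """Hypothetical stand-in for jax.Array in this demo."""


fake_random = types.SimpleNamespace()  # stand-in for a jax.random without KeyArray
print(ensure_keyarray_alias(fake_random, ArrayStandIn))  # True
print(fake_random.KeyArray is ArrayStandIn)  # True
print(ensure_keyarray_alias(fake_random, ArrayStandIn))  # False (already present)
```

Because `!python3 train_dreambooth.py` starts a separate Python process, setting the alias in a notebook cell would not reach the script; it would have to go at the top of `train_dreambooth.py`, above the `from diffusers import ...` on line 21 (for example `import jax` followed by `jax.random.KeyArray = jax.Array`). Options that avoid patching are installing a diffusers version whose Flax modules no longer reference `jax.random.KeyArray`, or an older JAX release that still exposes it.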