Versioning issue between jax and diffusion packages
Hi,
I have been trying to run the colab connected to a Python3 T4 GPU runtime and I am seeing incompatibility issues between the different package versions. If I run the colab as it is:
# Installing PyTorch with CUDA support (matching Colab's CUDA version, usually the latest supported by PyTorch)
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# Installing Diffusers and Transformers
!pip install diffusers==0.8.0 transformers
# Installing essential and commonly used libraries
!pip install numpy pandas scipy scikit-learn matplotlib opencv-python-headless
# Additional utilities that might be useful
!pip install tqdm requests pillow
!pip install "jax[cuda12_local]==0.5.3" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install wandb
!pip install accelerate
then, I see the following during installation:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.89 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
chex 0.1.89 requires jaxlib>=0.4.27, but you have jaxlib 0.4.23+cuda12.cudnn89 which is incompatible.
optax 0.2.4 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
optax 0.2.4 requires jaxlib>=0.4.27, but you have jaxlib 0.4.23+cuda12.cudnn89 which is incompatible.
flax 0.10.4 requires jax>=0.4.27, but you have jax 0.4.23 which is incompatible.
orbax-checkpoint 0.11.10 requires jax>=0.5.0, but you have jax 0.4.23 which is incompatible.
and when trying to import some of the diffusion modules:
from diffusers import StableDiffusionPipeline, DDIMScheduler
I see the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
[<ipython-input-5-5f3e3eccac57>](https://localhost:8080/#) in <cell line: 0>()
----> 1 from diffusers import StableDiffusionPipeline, DDIMScheduler
12 frames
[/usr/local/lib/python3.11/dist-packages/diffusers/__init__.py](https://localhost:8080/#) in <module>
19 if is_torch_available():
20 from .modeling_utils import ModelMixin
---> 21 from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
22 from .optimization import (
23 get_constant_schedule,
[/usr/local/lib/python3.11/dist-packages/diffusers/models/__init__.py](https://localhost:8080/#) in <module>
24
25 if is_flax_available():
---> 26 from .unet_2d_condition_flax import FlaxUNet2DConditionModel
27 from .vae_flax import FlaxAutoencoderKL
[/usr/local/lib/python3.11/dist-packages/diffusers/models/unet_2d_condition_flax.py](https://localhost:8080/#) in <module>
14 from typing import Tuple, Union
15
---> 16 import flax
17 import flax.linen as nn
18 import jax
[/usr/local/lib/python3.11/dist-packages/flax/__init__.py](https://localhost:8080/#) in <module>
22 del configurations
23
---> 24 from flax import core
25 from flax import jax_utils
26 from flax import linen
[/usr/local/lib/python3.11/dist-packages/flax/core/__init__.py](https://localhost:8080/#) in <module>
22 unfreeze as unfreeze,
23 )
---> 24 from .lift import (
25 custom_vjp as custom_vjp,
26 jit as jit,
[/usr/local/lib/python3.11/dist-packages/flax/core/lift.py](https://localhost:8080/#) in <module>
25
26 from flax import traceback_util
---> 27 from flax import traverse_util
28 from flax.typing import (
29 In,
[/usr/local/lib/python3.11/dist-packages/flax/traverse_util.py](https://localhost:8080/#) in <module>
64
65 import flax
---> 66 from flax.core.scope import VariableDict
67 from flax.typing import PathParts
68
[/usr/local/lib/python3.11/dist-packages/flax/core/scope.py](https://localhost:8080/#) in <module>
53 )
54
---> 55 from . import meta, partial_eval, tracers
56 from .frozen_dict import FrozenDict, freeze, unfreeze
57
[/usr/local/lib/python3.11/dist-packages/flax/core/meta.py](https://localhost:8080/#) in <module>
186
187
--> 188 class Partitioned(struct.PyTreeNode, AxisMetadata[A]):
189 """Wrapper for partitioning metadata.
190
/usr/lib/python3.11/abc.py in __new__(mcls, name, bases, namespace, **kwargs)
[/usr/local/lib/python3.11/dist-packages/flax/struct.py](https://localhost:8080/#) in __init_subclass__(cls, **kwargs)
233
234 def __init_subclass__(cls, **kwargs):
--> 235 dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types
236
237 def __init__(self, *args, **kwargs):
[/usr/local/lib/python3.11/dist-packages/flax/struct.py](https://localhost:8080/#) in dataclass(clz, **kwargs)
148 data_clz.replace = replace
149
--> 150 jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
151
152 def to_state_dict(x):
[/usr/local/lib/python3.11/dist-packages/jax/_src/deprecations.py](https://localhost:8080/#) in getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
54
55 return getattr
AttributeError: module 'jax.tree_util' has no attribute 'register_dataclass'
I have tried upgrading jax to jax-0.5.3 and jaxlib-0.5.3, but then I see this error instead:
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
[<ipython-input-2-5f3e3eccac57>](https://localhost:8080/#) in <cell line: 0>()
----> 1 from diffusers import StableDiffusionPipeline, DDIMScheduler
2 frames
[/usr/local/lib/python3.11/dist-packages/diffusers/__init__.py](https://localhost:8080/#) in <module>
29 get_scheduler,
30 )
---> 31 from .pipeline_utils import DiffusionPipeline
32 from .pipelines import (
33 DanceDiffusionPipeline,
[/usr/local/lib/python3.11/dist-packages/diffusers/pipeline_utils.py](https://localhost:8080/#) in <module>
33
34 from .configuration_utils import ConfigMixin
---> 35 from .dynamic_modules_utils import get_class_from_dynamic_module
36 from .hub_utils import http_user_agent
37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
[/usr/local/lib/python3.11/dist-packages/diffusers/dynamic_modules_utils.py](https://localhost:8080/#) in <module>
24 from typing import Dict, Optional, Union
25
---> 26 from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
27
28 from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
ImportError: cannot import name 'cached_download' from 'huggingface_hub' (/usr/local/lib/python3.11/dist-packages/huggingface_hub/__init__.py)
I also tried upgrading both jax and diffusers (to diffusers-0.32.2), but then I encounter the following error in the last cells:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
[<ipython-input-12-1772ee7a7d70>](https://localhost:8080/#) in <cell line: 0>()
----> 1 embedding = optimize_embedding(
2 ldm,
3 num_steps=num_optimization_steps,
4 batch_size=batch_size,
5 top_k = num_keypoints,
10 frames
[<ipython-input-11-aead50517aea>](https://localhost:8080/#) in optimize_embedding(ldm, context, device, num_steps, from_where, upsample_res, layers, lr, noise_level, num_tokens, top_k, augment_degrees, augment_scale, augment_translate, dataset_loc, sigma, sharpening_loss_weight, equivariance_attn_loss_weight, batch_size, num_gpus, max_len, min_dist, furthest_point_num_samples, controllers, validation, num_subjects)
76 image = mini_batch["img"]
77
---> 78 attn_maps = run_and_find_attn(
79 ldm,
80 image,
[<ipython-input-10-d782888e078d>](https://localhost:8080/#) in run_and_find_attn(ldm, image, context, noise_level, device, from_where, layers, upsample_res, indices, controllers)
113 controllers=None,
114 ):
--> 115 _, _ = find_pred_noise(
116 ldm,
117 image,
[<ipython-input-10-d782888e078d>](https://localhost:8080/#) in find_pred_noise(ldm, image, context, noise_level, device)
41 # import ipdb; ipdb.set_trace()
42
---> 43 pred_noise = ldm.unet(noisy_image,
44 ldm.scheduler.timesteps[noise_level].repeat(noisy_image.shape[0]),
45 context.repeat(noisy_image.shape[0], 1, 1))["sample"]
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None
[/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/data_parallel.py](https://localhost:8080/#) in forward(self, *inputs, **kwargs)
189
190 if len(self.device_ids) == 1:
--> 191 return self.module(*inputs[0], **module_kwargs[0])
192 replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
193 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1843
1844 try:
-> 1845 return inner()
1846 except Exception:
1847 # run always called hooks if they have not already been run
[/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in inner()
1780 )
1781 else:
-> 1782 args_result = hook(self, args)
1783 if args_result is not None:
1784 if not isinstance(args_result, tuple):
[<ipython-input-6-72ab3767eb92>](https://localhost:8080/#) in hook_fn(module, input)
200 _device = input[0].device
201 # if device not in patched_devices:
--> 202 register_attention_control(module, controllers[_device], feature_upsample_res=feature_upsample_res)
203 # patched_devices.add(device)
204
[<ipython-input-6-72ab3767eb92>](https://localhost:8080/#) in register_attention_control(model, controller, feature_upsample_res)
160
161 # create assertion with message
--> 162 assert cross_att_count != 0, "No cross attention layers found in the model. Please check to make sure you're using diffusers==0.8.0."
163
164 def load_ldm(device, type="CompVis/stable-diffusion-v1-4", feature_upsample_res=256):
AssertionError: No cross attention layers found in the model. Please check to make sure you're using diffusers==0.8.0.
Can you please recommend which versions to use for each package? Thanks!!!!
I have fixed the installs, but it seems Stable Diffusion isn't hosted publicly on Hugging Face anymore. It looks like the model will likely need to be downloaded locally, and the local path will need to be passed in instead of "runwayml/stable-diffusion-v1-5".