2024-04-25 10:32:01.418766: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/kohya-trainer/finetune/make_captions.py:14 in <module> │
│ │
│ 11 from torchvision import transforms │
│ 12 from torchvision.transforms.functional import InterpolationMode │
│ 13 from blip.blip import blip_decoder │
│ ❱ 14 import library.train_util as train_util │
│ 15 │
│ 16 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') │
│ 17 │
│ │
│ /content/kohya-trainer/library/train_util.py:37 in <module> │
│ │
│ 34 from torchvision import transforms │
│ 35 from transformers import CLIPTokenizer │
│ 36 import transformers │
│ ❱ 37 import diffusers │
│ 38 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION │
│ 39 from diffusers import ( │
│ 40 │ StableDiffusionPipeline, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/__init__.py:38 in <module> │
│ │
│ 35 │ │ get_polynomial_decay_schedule_with_warmup, │
│ 36 │ │ get_scheduler, │
│ 37 │ ) │
│ ❱ 38 │ from .pipeline_utils import DiffusionPipeline │
│ 39 │ from .pipelines import ( │
│ 40 │ │ DanceDiffusionPipeline, │
│ 41 │ │ DDIMPipeline, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/pipeline_utils.py:38 in <module> │
│ │
│ 35 from .dynamic_modules_utils import get_class_from_dynamic_module │
│ 36 from .hub_utils import http_user_agent, send_telemetry │
│ 37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT │
│ ❱ 38 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME │
│ 39 from .utils import ( │
│ 40 │ CONFIG_NAME, │
│ 41 │ DIFFUSERS_CACHE, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/__init__.py:50 in <module> │
│ │
│ 47 │ from ..utils.dummy_flax_objects import * # noqa F403 │
│ 48 else: │
│ 49 │ from .scheduling_ddim_flax import FlaxDDIMScheduler │
│ ❱ 50 │ from .scheduling_ddpm_flax import FlaxDDPMScheduler │
│ 51 │ from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler │
│ 52 │ from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler │
│ 53 │ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:80 in <module> │
│ │
│ │
│ 77 │ state: DDPMSchedulerState │
│ 78 │
│ 79 │
│ ❱ 80 class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): │
│ 81 │ """ │
│ 82 │ Denoising diffusion probabilistic models (DDPMs) explores the connections between de │
│ 83 │ Langevin dynamics sampling. │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:216 in │
│ FlaxDDPMScheduler │
│ │
│ 213 │ │ model_output: jnp.ndarray, │
│ 214 │ │ timestep: int, │
│ 215 │ │ sample: jnp.ndarray, │
│ ❱ 216 │ │ key: random.KeyArray, │
│ 217 │ │ return_dict: bool = True, │
│ 218 │ │ **kwargs, │
│ 219 │ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: │
│ │
│ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr │
│ │
│ 51 │ │ raise AttributeError(message) │
│ 52 │ warnings.warn(message, DeprecationWarning, stacklevel=2) │
│ 53 │ return fn │
│ ❱ 54 │ raise AttributeError(f"module {module!r} has no attribute {name!r}") │
│ 55 │
│ 56 return getattr │
│ 57 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'jax.random' has no attribute 'KeyArray'
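Not sure if you figured this out, but `jax.random.KeyArray` was removed in JAX 0.4.24, and this older diffusers release still references it in a type annotation, so the import fails. Pinning jax/jaxlib to 0.4.23 or earlier should work (e.g. `pip install "jax==0.4.23" "jaxlib==0.4.23"`); alternatively, upgrade diffusers to a release that no longer uses `KeyArray`.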