[train] Pin jax for Dreambooth Fine-Tuning template
Why are these changes needed?
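More recent versions of jax (e.g. 0.4.28) break generate.py at import time: diffusers imports flax, and the installed flax calls jax_config.define_bool_state, which the jax Config object in these versions does not have. The full traceback: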
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ray/default/ray/doc/source/templates/05_dreambooth_finetuning/dreambooth/generate.py:9 in │
│ <module> │
│ │
│ 6 import ray │
│ 7 │
│ 8 from flags import run_model_flags │
│ ❱ 9 from generate_utils import get_pipeline │
│ 10 │
│ 11 │
│ 12 def run(args): │
│ │
│ /home/ray/default/ray/doc/source/templates/05_dreambooth_finetuning/dreambooth/generate_utils.py │
│ :1 in <module> │
│ │
│ ❱ 1 from diffusers import DiffusionPipeline │
│ 2 from diffusers.loaders import LoraLoaderMixin │
│ 3 import torch │
│ 4 │
│ │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/__init__.py:38 in <module> │
│ │
│ 35 except OptionalDependencyNotAvailable: │
│ 36 │ from .utils.dummy_pt_objects import * # noqa F403 │
│ 37 else: │
│ ❱ 38 │ from .models import ( │
│ 39 │ │ AsymmetricAutoencoderKL, │
│ 40 │ │ AutoencoderKL, │
│ 41 │ │ ControlNetModel, │
│ │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/models/__init__.py:35 in <module> │
│ │
│ 32 │ from .vq_model import VQModel │
│ 33 │
│ 34 if is_flax_available(): │
│ ❱ 35 │ from .controlnet_flax import FlaxControlNetModel │
│ 36 │ from .unet_2d_condition_flax import FlaxUNet2DConditionModel │
│ 37 │ from .vae_flax import FlaxAutoencoderKL │
│ 38 │
│ │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/models/controlnet_flax.py:16 in │
│ <module> │
│ │
│ 13 # limitations under the License. │
│ 14 from typing import Optional, Tuple, Union │
│ 15 │
│ ❱ 16 import flax │
│ 17 import flax.linen as nn │
│ 18 import jax │
│ 19 import jax.numpy as jnp │
│ │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/__init__.py:18 in <module> │
│ │
│ 15 │
│ 16 """Flax API.""" │
│ 17 │
│ ❱ 18 from .configurations import ( │
│ 19 │ config as config, │
│ 20 ) │
│ 21 │
│ │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/configurations.py:92 in <module> │
│ │
│ 89 # Whether to use the lazy rng implementation. │
│ 90 flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True) │
│ 91 │
│ ❱ 92 flax_filter_frames = define_bool_state( │
│ 93 │ name='filter_frames', │
│ 94 │ default=True, │
│ 95 │ help=('Whether to hide flax-internal stack frames from tracebacks.')) │
│ │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/configurations.py:42 in define_bool_state │
│ │
│ 39 'FLAX_<UPPERCASE_NAME>'. JAX config ensures that the flag can be overwritten │
│ 40 on runtime with `flax.config.update('flax_<config_name>', <value>)`. │
│ 41 """ │
│ ❱ 42 return jax_config.define_bool_state('flax_' + name, default, help) │
│ 43 │
│ 44 │
│ 45 def static_bool_env(varname: str, default: bool) -> bool: │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'Config' object has no attribute 'define_bool_state'
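For anyone who wants to confirm whether their environment is affected before running the template, a minimal sketch of a check is below. It only tests for the attribute named in the traceback above. The helper names `config_supports_define_bool_state` and `check_installed_jax` are hypothetical and not part of this PR, and the sketch assumes `jax.config` is the `Config` object that flax reaches as `jax_config`.

```python
import importlib.util


def config_supports_define_bool_state(config_obj) -> bool:
    """Return True if the given config object exposes a callable define_bool_state."""
    return callable(getattr(config_obj, "define_bool_state", None))


def check_installed_jax() -> str:
    """Describe whether the installed jax has the attribute flax calls in the traceback."""
    # Avoid an ImportError when jax is not installed at all.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"

    import jax

    # Assumption: jax.config is the Config object that flax imports as jax_config.
    if config_supports_define_bool_state(jax.config):
        return f"jax {jax.__version__}: define_bool_state is available"
    return f"jax {jax.__version__}: define_bool_state is missing, pin an older jax"


if __name__ == "__main__":
    print(check_installed_jax())
```

If the script reports that the attribute is missing, importing flax through diffusers is expected to fail with the AttributeError shown above.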
Related issue number
Checks
- [ ] I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- [ ] I've run scripts/format.sh to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
    - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
- [ ] Unit tests
- [ ] Release tests
- [ ] This PR is not tested :(