
[train] Pin jax for Dreambooth Fine-Tuning template

matthewdeng opened this issue

Why are these changes needed?

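More recent versions of jax (e.g. 0.4.28) break the template. Importing diffusers in generate_utils.py also imports flax, because diffusers loads its Flax models whenever is_flax_available() is true. At import time, flax's configurations.py calls jax_config.define_bool_state, which the newer jax Config object no longer provides. As a result, generate.py fails with the following traceback: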

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ray/default/ray/doc/source/templates/05_dreambooth_finetuning/dreambooth/generate.py:9 in  │
│ <module>                                                                                         │
│                                                                                                  │
│    6 import ray                                                                                  │
│    7                                                                                             │
│    8 from flags import run_model_flags                                                           │
│ ❱  9 from generate_utils import get_pipeline                                                     │
│   10                                                                                             │
│   11                                                                                             │
│   12 def run(args):                                                                              │
│                                                                                                  │
│ /home/ray/default/ray/doc/source/templates/05_dreambooth_finetuning/dreambooth/generate_utils.py │
│ :1 in <module>                                                                                   │
│                                                                                                  │
│ ❱  1 from diffusers import DiffusionPipeline                                                     │
│    2 from diffusers.loaders import LoraLoaderMixin                                               │
│    3 import torch                                                                                │
│    4                                                                                             │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/__init__.py:38 in <module>            │
│                                                                                                  │
│    35 except OptionalDependencyNotAvailable:                                                     │
│    36 │   from .utils.dummy_pt_objects import *  # noqa F403                                     │
│    37 else:                                                                                      │
│ ❱  38 │   from .models import (                                                                  │
│    39 │   │   AsymmetricAutoencoderKL,                                                           │
│    40 │   │   AutoencoderKL,                                                                     │
│    41 │   │   ControlNetModel,                                                                   │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/models/__init__.py:35 in <module>     │
│                                                                                                  │
│   32 │   from .vq_model import VQModel                                                           │
│   33                                                                                             │
│   34 if is_flax_available():                                                                     │
│ ❱ 35 │   from .controlnet_flax import FlaxControlNetModel                                        │
│   36 │   from .unet_2d_condition_flax import FlaxUNet2DConditionModel                            │
│   37 │   from .vae_flax import FlaxAutoencoderKL                                                 │
│   38                                                                                             │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/models/controlnet_flax.py:16 in       │
│ <module>                                                                                         │
│                                                                                                  │
│    13 # limitations under the License.                                                           │
│    14 from typing import Optional, Tuple, Union                                                  │
│    15                                                                                            │
│ ❱  16 import flax                                                                                │
│    17 import flax.linen as nn                                                                    │
│    18 import jax                                                                                 │
│    19 import jax.numpy as jnp                                                                    │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/__init__.py:18 in <module>                 │
│                                                                                                  │
│   15                                                                                             │
│   16 """Flax API."""                                                                             │
│   17                                                                                             │
│ ❱ 18 from .configurations import (                                                               │
│   19 │   config as config,                                                                       │
│   20 )                                                                                           │
│   21                                                                                             │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/configurations.py:92 in <module>           │
│                                                                                                  │
│    89 # Whether to use the lazy rng implementation.                                              │
│    90 flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True)                                     │
│    91                                                                                            │
│ ❱  92 flax_filter_frames = define_bool_state(                                                    │
│    93 │   name='filter_frames',                                                                  │
│    94 │   default=True,                                                                          │
│    95 │   help=('Whether to hide flax-internal stack frames from tracebacks.'))                  │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/configurations.py:42 in define_bool_state  │
│                                                                                                  │
│    39   'FLAX_<UPPERCASE_NAME>'. JAX config ensures that the flag can be overwritten             │
│    40   on runtime with `flax.config.update('flax_<config_name>', <value>)`.                     │
│    41   """                                                                                      │
│ ❱  42   return jax_config.define_bool_state('flax_' + name, default, help)                       │
│    43                                                                                            │
│    44                                                                                            │
│    45 def static_bool_env(varname: str, default: bool) -> bool:                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'Config' object has no attribute 'define_bool_state'
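The traceback shows that the failure comes from flax's import-time call to jax.config.define_bool_state. The following is a minimal sketch of a pre-flight check that catches this mismatch before generate.py runs. It assumes, as the traceback suggests, that diffusers only touches flax when both jax and flax are installed. The helper names config_supports_flax and check_jax_flax_compat are hypothetical and are not part of Ray, the template, or diffusers.

```python
import importlib
import importlib.util


def config_supports_flax(config: object) -> bool:
    """Return True if `config` exposes the define_bool_state hook that the
    installed flax (flax/configurations.py in the traceback) calls on import."""
    return callable(getattr(config, "define_bool_state", None))


def check_jax_flax_compat() -> str:
    """Hypothetical helper: report whether importing diffusers would hit the
    AttributeError above, without importing flax or diffusers themselves."""
    # Assumption: diffusers' is_flax_available() needs both jax and flax;
    # if either is missing, the Flax model imports are skipped entirely.
    if importlib.util.find_spec("jax") is None or importlib.util.find_spec("flax") is None:
        return "ok: jax or flax not installed, diffusers skips its Flax models"

    jax = importlib.import_module("jax")
    if config_supports_flax(jax.config):
        return f"ok: jax {jax.__version__} exposes config.define_bool_state"
    return (
        f"incompatible: jax {jax.__version__} has no config.define_bool_state; "
        "pin jax to an older release, as this PR does"
    )


if __name__ == "__main__":
    print(check_jax_flax_compat())
```

Running this check before generate.py replaces the long traceback with a one-line message. The check inspects the attribute directly rather than comparing version numbers, because this PR does not state the exact jax release where define_bool_state disappeared.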

Related issue number

Checks

  • [ ] I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • [ ] I've run scripts/format.sh to lint the changes in this PR.
  • [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
    • [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.
  • [ ] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • [ ] Unit tests
    • [ ] Release tests
    • [ ] This PR is not tested :(

matthewdeng, May 16 '24 17:05