
Model Loading Failed in Colab

Open · Catoverflow opened this issue 1 year ago · 5 comments

I tried to run mt3 in Colab, but it failed. I'm not familiar with the DNN libraries, so I'm only posting the steps to reproduce here.

Steps to Reproduce

  1. Choose the T4 GPU runtime type.
  2. Run the Setup Environment cell. It errors with:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.17.0 which is incompatible.
  3. Run the Import and Definitions cell.
  4. Run the Load Model cell with either mt3 or ismir2021; the notebook errors with:
---------------------------------------------------------------------------

XlaRuntimeError                           Traceback (most recent call last)

[<ipython-input-12-e1bcd991ed4d>](https://localhost:8080/#) in <cell line: 13>()
     11 
     12 log_event('loadModelStart', {'event_category': MODEL})
---> 13 inference_model = InferenceModel(checkpoint_path, MODEL)
     14 log_event('loadModelComplete', {'event_category': MODEL})

13 frames

[<ipython-input-11-30d8629039fb>](https://localhost:8080/#) in __init__(self, checkpoint_path, model_type)
     85 
     86     # Restore from checkpoint.
---> 87     self.restore_from_checkpoint(checkpoint_path)
     88 
     89   @property

[<ipython-input-11-30d8629039fb>](https://localhost:8080/#) in restore_from_checkpoint(self, checkpoint_path)
    120   def restore_from_checkpoint(self, checkpoint_path):
    121     """Restore training state from checkpoint, resets self._predict_fn()."""
--> 122     train_state_initializer = t5x.utils.TrainStateInitializer(
    123       optimizer_def=self.model.optimizer_def,
    124       init_fn=self.model.get_initial_variables,

[/usr/local/lib/python3.10/dist-packages/t5x/utils.py](https://localhost:8080/#) in __init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
   1057     self._partitioner = partitioner
   1058     self.global_train_state_shape = jax.eval_shape(
-> 1059         initialize_train_state, rng=jax.random.PRNGKey(0)
   1060     )
   1061     self.train_state_axes = partitioner.get_mesh_axes(

[/usr/local/lib/python3.10/dist-packages/jax/_src/random.py](https://localhost:8080/#) in PRNGKey(seed, impl)
    231     and ``fold_in``.
    232   """
--> 233   return _return_prng_keys(True, _key('PRNGKey', seed, impl))
    234 
    235 

[/usr/local/lib/python3.10/dist-packages/jax/_src/random.py](https://localhost:8080/#) in _key(ctor_name, seed, impl_spec)
    193         f"{ctor_name} accepts a scalar seed, but was given an array of "
    194         f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 195   return prng.random_seed(seed, impl=impl)
    196 
    197 def key(seed: int | ArrayLike, *,

[/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py](https://localhost:8080/#) in random_seed(seeds, impl)
    531   # use-case of instantiating with Python hashes in X32 mode.
    532   if isinstance(seeds, int):
--> 533     seeds_arr = jnp.asarray(np.int64(seeds))
    534   else:
    535     seeds_arr = jnp.asarray(seeds)

[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in asarray(a, dtype, order, copy)
   3287   if dtype is not None:
   3288     dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 3289   return array(a, dtype=dtype, copy=bool(copy), order=order)
   3290 
   3291 

[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in array(object, dtype, copy, order, ndmin)
   3212     raise TypeError(f"Unexpected input type for array: {type(object)}")
   3213 
-> 3214   out_array: Array = lax_internal._convert_element_type(
   3215       out, dtype, weak_type=weak_type)
   3216   if ndmin > ndim(out_array):

[/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py](https://localhost:8080/#) in _convert_element_type(operand, new_dtype, weak_type)
    557     return type_cast(Array, operand)
    558   else:
--> 559     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    560                                        weak_type=bool(weak_type))
    561 

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in bind(self, *args, **params)
    414     assert (not config.enable_checks.value or
    415             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 416     return self.bind_with_trace(find_top_trace(args), args, params)
    417 
    418   def bind_with_trace(self, trace, args, params):

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in bind_with_trace(self, trace, args, params)
    418   def bind_with_trace(self, trace, args, params):
    419     with pop_level(trace.level):
--> 420       out = trace.process_primitive(self, map(trace.full_raise, args), params)
    421     return map(full_lower, out) if self.multiple_results else full_lower(out)
    422 

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in process_primitive(self, primitive, tracers, params)
    919       return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    920     else:
--> 921       return primitive.impl(*tracers, **params)
    922 
    923   def process_call(self, primitive, f, tracers, params):

[/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in apply_primitive(prim, *args, **params)
     85   prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86   try:
---> 87     outs = fun(*args)
     88   finally:
     89     lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 15 frame]

[/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py](https://localhost:8080/#) in backend_compile(backend, module, options, host_callbacks)
    236   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237   # to take in `host_callbacks`
--> 238   return backend.compile(built_c, compile_options=options)
    239 
    240 def compile_or_get_cached(

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Catoverflow · Jul 08 '24 12:07
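For context, FAILED_PRECONDITION: DNN library initialization failed is raised while XLA tries to initialize cuDNN, and together with the pip conflict above it usually points to a mismatch between the installed jax/jaxlib build and the runtime's CUDA/cuDNN libraries. A minimal diagnostic cell (a sketch, assuming it is run in the same Colab runtime before the Load Model cell) to capture the versions involved when reporting this:

# Hypothetical diagnostic cell: record the library versions involved in the conflict.
import jax, jaxlib
import tensorflow as tf

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("tensorflow:", tf.__version__)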

same issue

ntamotsu · Jul 13 '24 17:07

same problem here

goel-raghav · Jul 16 '24 16:07

It seems like either switching the runtime to CPU or changing:

!python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

to

!python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .

fixes the problem. Not sure which one exactly, because I ran out of Colab GPU hours.

goel-raghav · Jul 17 '24 19:07
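For anyone else hitting this, the change described above amounts to dropping the jax[cuda12_local] extra and the CUDA wheel index from the Setup Environment cell, so pip no longer pulls in the locally-linked CUDA jaxlib that appears to clash with the Colab runtime's libraries. A sketch of the modified install line, assuming the rest of the cell stays unchanged:

# Original line from the Setup Environment cell, kept here as a comment for reference:
# !python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Modified line: install only the remaining dependencies and mt3 itself.
!python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .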

@goel-raghav Yes, after changing that install line it prints WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu., and the model loads successfully.

Catoverflow · Jul 19 '24 12:07
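That warning is JAX confirming it found no CUDA-enabled jaxlib and will fall back to CPU, so model loading succeeds but inference runs on the CPU. A quick sanity check of which backend is actually in use (a sketch, assuming the same runtime):

import jax

# After the modified install this should print 'cpu'; with a working
# CUDA-enabled jaxlib it would report a GPU backend instead.
print(jax.default_backend())
print(jax.devices())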

https://github.com/magenta/mt3/pull/160 works for GPU, which is much faster than CPU.

laqieer · Jul 27 '24 15:07