Model Loading Failed in Colab
I tried to run MT3 in Colab, but it failed. I'm not familiar with the DNN libraries involved, so I'm only posting the steps to reproduce here.
Steps to Reproduce
- Choose `T4 GPU` as the runtime type
- Run the Setup Environment cell. It errors with:
  ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
  tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.17.0 which is incompatible.
- Run the Import and Definitions cell
- Run the Load Model cell with either `mt3` or `ismir2021`; the notebook errors with:
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
<ipython-input-12-e1bcd991ed4d> in <cell line: 13>()
11
12 log_event('loadModelStart', {'event_category': MODEL})
---> 13 inference_model = InferenceModel(checkpoint_path, MODEL)
14 log_event('loadModelComplete', {'event_category': MODEL})
13 frames
<ipython-input-11-30d8629039fb> in __init__(self, checkpoint_path, model_type)
85
86 # Restore from checkpoint.
---> 87 self.restore_from_checkpoint(checkpoint_path)
88
89 @property
<ipython-input-11-30d8629039fb> in restore_from_checkpoint(self, checkpoint_path)
120 def restore_from_checkpoint(self, checkpoint_path):
121 """Restore training state from checkpoint, resets self._predict_fn()."""
--> 122 train_state_initializer = t5x.utils.TrainStateInitializer(
123 optimizer_def=self.model.optimizer_def,
124 init_fn=self.model.get_initial_variables,
/usr/local/lib/python3.10/dist-packages/t5x/utils.py in __init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
1057 self._partitioner = partitioner
1058 self.global_train_state_shape = jax.eval_shape(
-> 1059 initialize_train_state, rng=jax.random.PRNGKey(0)
1060 )
1061 self.train_state_axes = partitioner.get_mesh_axes(
/usr/local/lib/python3.10/dist-packages/jax/_src/random.py in PRNGKey(seed, impl)
231 and ``fold_in``.
232 """
--> 233 return _return_prng_keys(True, _key('PRNGKey', seed, impl))
234
235
/usr/local/lib/python3.10/dist-packages/jax/_src/random.py in _key(ctor_name, seed, impl_spec)
193 f"{ctor_name} accepts a scalar seed, but was given an array of "
194 f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 195 return prng.random_seed(seed, impl=impl)
196
197 def key(seed: int | ArrayLike, *,
/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py in random_seed(seeds, impl)
531 # use-case of instantiating with Python hashes in X32 mode.
532 if isinstance(seeds, int):
--> 533 seeds_arr = jnp.asarray(np.int64(seeds))
534 else:
535 seeds_arr = jnp.asarray(seeds)
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in asarray(a, dtype, order, copy)
3287 if dtype is not None:
3288 dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
-> 3289 return array(a, dtype=dtype, copy=bool(copy), order=order)
3290
3291
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
3212 raise TypeError(f"Unexpected input type for array: {type(object)}")
3213
-> 3214 out_array: Array = lax_internal._convert_element_type(
3215 out, dtype, weak_type=weak_type)
3216 if ndmin > ndim(out_array):
/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in _convert_element_type(operand, new_dtype, weak_type)
557 return type_cast(Array, operand)
558 else:
--> 559 return convert_element_type_p.bind(operand, new_dtype=new_dtype,
560 weak_type=bool(weak_type))
561
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in bind(self, *args, **params)
414 assert (not config.enable_checks.value or
415 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 416 return self.bind_with_trace(find_top_trace(args), args, params)
417
418 def bind_with_trace(self, trace, args, params):
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in bind_with_trace(self, trace, args, params)
418 def bind_with_trace(self, trace, args, params):
419 with pop_level(trace.level):
--> 420 out = trace.process_primitive(self, map(trace.full_raise, args), params)
421 return map(full_lower, out) if self.multiple_results else full_lower(out)
422
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in process_primitive(self, primitive, tracers, params)
919 return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
920 else:
--> 921 return primitive.impl(*tracers, **params)
922
923 def process_call(self, primitive, f, tracers, params):
/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in apply_primitive(prim, *args, **params)
85 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
86 try:
---> 87 outs = fun(*args)
88 finally:
89 lib.jax_jit.swap_thread_local_state_disable_jit(prev)
[... skipping hidden 15 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in backend_compile(backend, module, options, host_callbacks)
236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)
239
240 def compile_or_get_cached(
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
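For reference, a quick way to see which backend jax actually initialized on the runtime (plain jax API, nothing MT3-specific; just a diagnostic sketch):

```python
import jax

# Shows which backend jaxlib picked up. If the CUDA/DNN libraries fail to
# initialize, jax either raises an error like the one above or falls back
# to CPU, typically reporting 'cpu' here.
print(jax.__version__)
print(jax.default_backend())
print(jax.devices())
```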
same issue
same problem here
It seems that either switching the runtime to CPU or changing
`!python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`
to
`!python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .`
fixes the problem. I'm not sure which one exactly, because I ran out of Colab GPU hours.
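To spell out the second option: the change is only to the install line in the Setup Environment cell (a sketch of that one line; the rest of the cell is unchanged):

```python
# Original install line (pulls a CUDA-enabled jaxlib, which fails on this
# runtime with the DNN initialization error above):
# !python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Workaround: install without the CUDA extra; jax then falls back to the CPU jaxlib.
!python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .
```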
@goel-raghav Yes, after changing the install command it prints `WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.` and the model loads successfully.
https://github.com/magenta/mt3/pull/160 works for GPU, which is much faster than CPU.
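If you can't apply the PR, an alternative that may work (a guess on my part, not what the PR does) is reinstalling jax with the PyPI-hosted CUDA wheels so jaxlib matches the Colab CUDA runtime:

```python
# Hypothetical alternative (not taken from PR 160): reinstall jax together
# with the PyPI-hosted CUDA 12 plugin wheels.
!pip install -U "jax[cuda12]"

# Restart the runtime if jax was already imported, then check that a CUDA
# device shows up:
import jax
print(jax.devices())
```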