AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map
Full error message:
from underthesea.pipeline.say import say
say("xin chào")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/underthesea/pipeline/say/__init__.py", line 31, in say
    y = text_to_speech(text)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/underthesea/pipeline/say/__init__.py", line 19, in text_to_speech
    mel = text2mel(
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/underthesea/pipeline/say/viettts_/nat/text2mel.py", line 96, in text2mel
    durations = predict_duration(tokens, duration_ckpt)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/underthesea/pipeline/say/viettts_/nat/text2mel.py", line 36, in predict_duration
    return forward_fn(dic["params"], dic["aux"], dic["rng"], x)[0]
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/underthesea/pipeline/say/viettts_/nat/text2mel.py", line 26, in fwd_
    return DurationModel(is_training=False)(x)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/underthesea/pipeline/say/viettts_/nat/model.py", line 67, in __call__
    x = self.encoder(inputs.phonemes, inputs.lengths)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/underthesea/pipeline/say/viettts_/nat/model.py", line 41, in __call__
    x_bwd, mask_bwd = jax.tree_map(lambda x: jnp.flip(x, axis=1), (x, mask))
  File "/opt/anaconda3/envs/undersea/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in __getattr__
    raise AttributeError(message)
AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
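The error message itself names the replacements. In the vendored file, the direct fix would be changing `jax.tree_map(` to `jax.tree_util.tree_map(` (or `jax.tree.map(`) on line 41 of `viettts_/nat/model.py`. Below is a minimal sketch of that change plus a runtime alias as a stopgap, assuming nothing else in the call path relies on the removed name. The helper `flip_time_axis`, the example shapes, and the alias trick are illustrative assumptions, not underthesea's code or an official JAX/underthesea fix.

```python
import jax
import jax.numpy as jnp

# Stopgap (an assumption, not an official JAX or underthesea API): on JAX
# v0.6.0+ jax.tree_map no longer exists, so alias it to jax.tree_util.tree_map,
# which the error message says works in any JAX version. Run this before
# calling underthesea's say(); it only changes the jax module in this process.
if not hasattr(jax, "tree_map"):
    jax.tree_map = jax.tree_util.tree_map


def flip_time_axis(x, mask):
    """Hypothetical helper mirroring line 41 of viettts_/nat/model.py.

    Reverses every leaf of the (x, mask) tuple along axis 1, using
    jax.tree_util.tree_map instead of the removed jax.tree_map.
    """
    return jax.tree_util.tree_map(lambda a: jnp.flip(a, axis=1), (x, mask))


# Illustrative inputs; the (batch, time, features) layout is assumed.
x = jnp.arange(6.0).reshape(1, 3, 2)
mask = jnp.array([[1.0, 1.0, 0.0]])
x_bwd, mask_bwd = flip_time_axis(x, mask)
print(x_bwd.shape, mask_bwd)
```

With the alias in place, `say("xin chào")` should at least get past line 41; this traceback does not show whether the rest of the vendored Haiku code runs on the installed JAX version. Alternatively, installing a JAX release older than 0.6.0 should keep the old name available, since the message says it was removed in v0.6.0.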