mesh-transformer-jax
Colab Demo Notebook Not Working
I launched the Colab notebook to try and demo this model.
There's a section that starts with: Sometimes the next step errors for some reason, just run it again ¯\_(ツ)_/¯
That's fine and all, except running it multiple times didn't help. To try to resolve some of the errors, I ran:
!pip install optax transformers ray
That got me closer, but it errors out with:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-18-a22d9a83aa66> in <module>()
7 import transformers
8
----> 9 from mesh_transformer.checkpoint import read_ckpt_lowmem
10 from mesh_transformer.sampling import nucleaus_sample
11 from mesh_transformer.transformer_shard import CausalTransformer
2 frames
/usr/lib/python3.7/typing.py in __new__(cls, *args, **kwds)
308 isinstance(args[1], tuple)):
309 # Close enough.
--> 310 raise TypeError(f"Cannot subclass {cls!r}")
311 return super().__new__(cls)
312
TypeError: Cannot subclass <class 'typing._SpecialForm'>
I'm seeing the same error. I'd appreciate it if somebody could help.
Yes, I'm also getting this exact error. Any feedback appreciated!
I was able to solve this. Let the first cell run (the one that installs everything from requirements.txt), then run the following two pip installs:
!pip install optax==0.0.9 transformers dm-haiku einops
and
!pip install ray
After installing these, I was able to run the remaining cells in the notebook without any problems.
See #161 (comment by Aspie96).
I'm getting this exact error and nothing has resolved it, including @Aspie96's helpful installation info.
I get two errors at this stage.
Sometimes the next step errors for some reason, just run it again ¯\_(ツ)_/¯
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-16-a22d9a83aa66> in <module>()
4 from jax.experimental import maps
5 import numpy as np
----> 6 import optax
7 import transformers
8
---------------------------------------------------------------------------
/usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py in <module>()
34 Scalar = Union[float, int]
35 Numeric = Union[Array, Scalar]
---> 36 PRNGKey = jax.random.KeyArray
37 PyTreeDef = type(jax.tree_structure(None))
38 Shape = jax.core.Shape
AttributeError: module 'jax.random' has no attribute 'KeyArray'
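The traceback shows the installed `chex` referring to `jax.random.KeyArray`, which the installed `jax` does not provide, so the two packages are out of step. A quick way to see which versions actually ended up in the runtime is to read their metadata and compare it with the pins suggested in this thread (`jax==0.2.12`, `jaxlib==0.1.68`, `chex==0.1.2`, `optax==0.0.9`, `dm-haiku==0.0.5`). This is a minimal sketch: the `PINS` table is copied from the comments here, not from an official compatibility matrix, and the helper names are illustrative. It only reads package metadata, so it still works when `import optax` itself fails.

```python
"""Report installed versions of the packages pinned in this thread (sketch)."""
import sys

try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python < 3.8, e.g. the Python 3.7 Colab runtime
    from importlib_metadata import PackageNotFoundError, version

# Distribution name -> version suggested in this thread (assumption, not official).
PINS = {
    "jax": "0.2.12",
    "jaxlib": "0.1.68",
    "chex": "0.1.2",
    "optax": "0.0.9",
    "dm-haiku": "0.0.5",
}


def installed_version(dist):
    """Return the installed version string, or None if the dist is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def check_pins(pins):
    """Return {dist: (pinned, installed)} for every dist that does not match."""
    mismatches = {}
    for dist, pinned in pins.items():
        found = installed_version(dist)
        if found != pinned:
            mismatches[dist] = (pinned, found)
    return mismatches


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    for dist, (pinned, found) in check_pins(PINS).items():
        print(f"{dist}: pinned {pinned}, installed {found or 'not installed'}")
```

An empty report means the runtime matches the pins; anything listed is a candidate for the `KeyArray`-style mismatch above.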
I'm getting this error:
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-2-a22d9a83aa66> in <module>()
7 import transformers
8
----> 9 from mesh_transformer.checkpoint import read_ckpt_lowmem
10 from mesh_transformer.sampling import nucleaus_sample
11 from mesh_transformer.transformer_shard import CausalTransformer
ModuleNotFoundError: No module named 'mesh_transformer'
---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.
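`mesh_transformer` is not on PyPI under that name; it lives in the repository checkout, which is why the install command later in this thread points pip at a local `mesh-transformer-jax/` directory. A minimal sketch for diagnosing this, assuming the repository was cloned into `./mesh-transformer-jax` in the notebook's working directory: check which top-level modules cannot be found, and if `mesh_transformer` is among them, put the checkout on `sys.path`. The helper names are illustrative, not part of the project.

```python
import importlib
import importlib.util
import os
import sys

# Assumption: the repo was cloned into ./mesh-transformer-jax.
REPO_DIR = os.path.abspath("mesh-transformer-jax")


def missing_modules(names):
    """Return the top-level module names that cannot be located (no import)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def ensure_on_path(repo_dir):
    """Prepend repo_dir to sys.path if it exists; return whether it exists."""
    if not os.path.isdir(repo_dir):
        return False
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    importlib.invalidate_caches()  # let the import system see the new entry
    return True


if __name__ == "__main__":
    # "haiku" is the import name of the dm-haiku distribution.
    needed = ["mesh_transformer", "jax", "optax", "haiku", "transformers"]
    if "mesh_transformer" in missing_modules(needed):
        print("repo checkout found:", ensure_on_path(REPO_DIR))
    print("still missing:", missing_modules(needed))
```

Adding the checkout to `sys.path` only fixes the import itself; the other dependencies still need compatible versions.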
This whole project has been quite the waste of time.
I beg to differ. It takes some effort to get the right mix of dependencies, but when it works it works quite well.
Here is what I use in my Colab Pro to bootstrap the model.
These are the packages that work and bypass requirements.txt (which is outdated and quirky about dependencies):
!pip install numpy~=1.21.0
!pip install typing-extensions~=3.7.4
!pip install "tqdm>=4.45.0"
!pip install "wandb>=0.11.2"
!pip install einops~=0.3.0
!pip install requests~=2.25.1
!pip install fabric~=2.6.0
!pip install optax==0.0.9
!pip install dm-haiku==0.0.5
!pip install git+https://github.com/EleutherAI/lm-evaluation-harness/
!pip install "ray[default]==1.4.1"
!pip install jaxlib~=0.1.68
!pip install jax~=0.2.12
!pip install Flask~=1.1.2
!pip install cloudpickle~=1.3.0
!pip install tensorflow-cpu~=2.7.0
!pip install google-cloud-storage~=1.36.2
!pip install transformers
!pip install "smart_open[gcs]"
!pip install func_timeout
!pip install ftfy
!pip install fastapi
!pip install uvicorn
!pip install lm_dataformat
!pip install pathy
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.7.0 chex==0.1.2 jaxlib==0.1.68
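Typed as `!pip` lines, unquoted specifiers such as `tqdm>=4.45.0` are parsed by the shell, where `>` acts as an output redirection rather than part of the version constraint. One way to sidestep shell parsing entirely is to call pip from Python with each requirement as its own argument. The sketch below does that for the core JAX-related subset of the pins in this comment (copied from here, not an official lock file) and installs everything in a single `pip install` call, so pip considers all the constraints together instead of letting a later call silently upgrade or downgrade an earlier one. `pip_install_command` and `install` are hypothetical helper names.

```python
import subprocess
import sys

# Subset of the pins from the comment above (assumption: extend as needed).
REQUIREMENTS = [
    "numpy~=1.21.0",
    "jax==0.2.12",
    "jaxlib==0.1.68",
    "chex==0.1.2",
    "optax==0.0.9",
    "dm-haiku==0.0.5",
    "einops~=0.3.0",
    "tqdm>=4.45.0",
    "wandb>=0.11.2",
    "ray[default]==1.4.1",
    "transformers",
]


def pip_install_command(requirements, local_paths=()):
    """Build one `python -m pip install` argv list.

    Each specifier is a separate argument, so characters such as '>' and '['
    are never interpreted by a shell.
    """
    return [sys.executable, "-m", "pip", "install", *requirements, *local_paths]


def install(requirements, local_paths=()):
    """Run pip once; raises subprocess.CalledProcessError if pip fails."""
    subprocess.run(pip_install_command(requirements, local_paths), check=True)
```

In a notebook cell, `install(REQUIREMENTS, ["mesh-transformer-jax/"])` would then stand in for the chain of `!pip` lines, again assuming the repository was cloned into `./mesh-transformer-jax`.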
I have successfully installed and run the application using the recommended dependencies, although two packages required a higher version.
Thank you so much @JohnnyOpcode! I was going nuts trying to figure out all the errors. Using your list, I was able to proceed and run the notebook to completion!
There is another recent issue where the wheel versions are tweaked further; have a look at that, and maybe it will get things working even better. GPT-J on JAX is brilliant work and just needs some love and attention when it comes to dependencies.
Bravo to @kingoflolz on this outstanding bit (pun) of work.
Well, it's obvious that it needs attention. It's all anyone needs.