mesh-transformer-jax

Colab Demo Notebook Not Working

Open CircuitGuy opened this issue 3 years ago • 11 comments

I launched the Colab notebook to try and demo this model.

There's a section that starts with: Sometimes the next step errors for some reason, just run it again ¯\_(ツ)_/¯

That's fine, except running it multiple times didn't help. To try to resolve some of the errors, I ran:

!pip install optax transformers ray

That got me closer, but it errors out with:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-18-a22d9a83aa66> in <module>()
      7 import transformers
      8 
----> 9 from mesh_transformer.checkpoint import read_ckpt_lowmem
     10 from mesh_transformer.sampling import nucleaus_sample
     11 from mesh_transformer.transformer_shard import CausalTransformer

2 frames
/usr/lib/python3.7/typing.py in __new__(cls, *args, **kwds)
    308                 isinstance(args[1], tuple)):
    309             # Close enough.
--> 310             raise TypeError(f"Cannot subclass {cls!r}")
    311         return super().__new__(cls)
    312 

TypeError: Cannot subclass <class 'typing._SpecialForm'>

CircuitGuy, Dec 07 '21 04:12

I am seeing the same error. I would appreciate it if somebody could help.

sharaku17, Dec 08 '21 19:12

Yes, I'm also getting this exact error. Any feedback appreciated!

texturejc, Dec 08 '21 22:12

I was able to solve this. Let the first cell run (it installs everything in requirements.txt), then run the following two pip installs:

!pip install optax==0.0.9 transformers dm-haiku einops

and

!pip install ray

After installing these, I was able to run the remaining cells in the notebook without any problems.

See #161 (comment by Aspie96).

sharaku17, Dec 09 '21 14:12
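The fix above appears to work by overriding whatever requirements.txt installed with an explicit `optax==0.0.9`. A small check like the following can confirm that the override actually took effect before you run the import cell. It is a hypothetical helper, not part of the notebook, and it assumes Python 3.8+ for the standard-library `importlib.metadata` (the Python 3.7 runtime in these tracebacks would need the `importlib_metadata` backport). It compares only exact `==` pins; range specifiers such as `~=` would need a real version parser.

```python
# Hypothetical helper (not part of the notebook): verify that exact "==" pins
# match what is actually installed. importlib.metadata requires Python 3.8+.
from importlib import metadata

# Exact pin taken from the comment above; add more "==" pins as needed.
PINS = {"optax": "0.0.9"}


def current_versions(names):
    """Return {name: installed version, or None if the distribution is absent}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def find_mismatches(pins, installed):
    """Return {name: (pinned, installed)} for every pin that is missing or different."""
    return {
        name: (pinned, installed.get(name))
        for name, pinned in pins.items()
        if installed.get(name) != pinned
    }


if __name__ == "__main__":
    problems = find_mismatches(PINS, current_versions(PINS))
    for name, (pinned, have) in problems.items():
        print(f"{name}: pinned {pinned}, installed {have or 'nothing'}")
    if not problems:
        print("All pins satisfied.")
```

If a later `pip install` (for example an unpinned `transformers`) upgrades optax again, this check will report the mismatch before the import cell fails.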

I'm getting this exact error and nothing has resolved it, including @Aspie96's helpful installation info.

I get two errors at this stage.

Sometimes the next step errors for some reason, just run it again ¯\_(ツ)_/¯

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-16-a22d9a83aa66> in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8
---------------------------------------------------------------------------
/usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py in <module>()
     34 Scalar = Union[float, int]
     35 Numeric = Union[Array, Scalar]
---> 36 PRNGKey = jax.random.KeyArray
     37 PyTreeDef = type(jax.tree_structure(None))
     38 Shape = jax.core.Shape

AttributeError: module 'jax.random' has no attribute 'KeyArray'

joan0fsnark, May 12 '22 03:05
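In this traceback, `import optax` pulls in chex, and chex's `pytypes.py` reads `jax.random.KeyArray` at import time, an attribute the installed jax does not define. That suggests a chex release that expects a newer jax than the one installed, which is presumably why the working recipe further down pins `chex==0.1.2` alongside `jax==0.2.12`. One way to check the pairing before importing optax is to probe for the attribute directly. The helper below is hypothetical. Note also that later jax releases deprecated and eventually removed `KeyArray`, so `False` only tells you the attribute is absent, not which package is out of date.

```python
# Hypothetical check (not part of the notebook): does the installed jax expose
# jax.random.KeyArray, the attribute the failing chex module reads on import?
import importlib
import importlib.util


def jax_has_keyarray():
    """Return True/False if jax is installed, or None if jax cannot be found."""
    if importlib.util.find_spec("jax") is None:
        return None
    jax_random = importlib.import_module("jax.random")
    return hasattr(jax_random, "KeyArray")


if __name__ == "__main__":
    result = jax_has_keyarray()
    if result is None:
        print("jax is not installed")
    elif result:
        print("jax.random.KeyArray exists; chex's import-time reference should resolve")
    else:
        print("jax.random.KeyArray is missing; chex and jax versions are likely mismatched")
```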

I'm getting this error:

ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-2-a22d9a83aa66> in <module>()
      7 import transformers
      8 
----> 9 from mesh_transformer.checkpoint import read_ckpt_lowmem
     10 from mesh_transformer.sampling import nucleaus_sample
     11 from mesh_transformer.transformer_shard import CausalTransformer

ModuleNotFoundError: No module named 'mesh_transformer'

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

Raval-Arth, May 14 '22 14:05
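`ModuleNotFoundError: No module named 'mesh_transformer'` means Python could not find the package on its import path (`sys.path`). For example, the repository may not have been cloned, or the working directory may have changed. The working recipe below installs the local checkout with `!pip install mesh-transformer-jax/`. An alternative is to put the checkout directory on `sys.path` before importing. The sketch assumes the repository was cloned into a `mesh-transformer-jax` directory relative to the notebook's working directory, and the helper name `ensure_importable` is hypothetical.

```python
# Hypothetical helper (not part of the notebook): make a cloned repository
# importable without pip-installing it.
import importlib
import importlib.util
import os
import sys

# Assumed clone location, relative to the current working directory.
REPO_DIR = "mesh-transformer-jax"


def ensure_importable(module_name, repo_dir):
    """Return True if module_name can be found, prepending repo_dir to sys.path if needed."""
    if importlib.util.find_spec(module_name) is not None:
        return True
    repo_path = os.path.abspath(repo_dir)
    if os.path.isdir(repo_path) and repo_path not in sys.path:
        sys.path.insert(0, repo_path)
        importlib.invalidate_caches()  # let the import system see the new path entry
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    if not ensure_importable("mesh_transformer", REPO_DIR):
        print(f"mesh_transformer not found; is the repository cloned into {REPO_DIR!r}?")
```

This only fixes the import path. The dependency pins discussed in this thread still have to be right for the subsequent imports to succeed.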

This whole project has been quite the waste of time.

ruze00, Aug 09 '22 14:08

I beg to differ. It takes some effort to get the right mix of dependencies, but when it works it works quite well.

Here is what I use in my Colab Pro to bootstrap the model.

These are the packages that work. They bypass requirements.txt, which is outdated and quirky about dependencies:

!pip install numpy~=1.21.0
!pip install typing-extensions~=3.7.4
!pip install "tqdm>=4.45.0"
!pip install "wandb>=0.11.2"
!pip install einops~=0.3.0
!pip install requests~=2.25.1
!pip install fabric~=2.6.0
!pip install optax==0.0.9
!pip install dm-haiku==0.0.5
!pip install git+https://github.com/EleutherAI/lm-evaluation-harness/
!pip install "ray[default]==1.4.1"
!pip install jaxlib~=0.1.68
!pip install jax~=0.2.12
!pip install Flask~=1.1.2
!pip install cloudpickle~=1.3.0
!pip install tensorflow-cpu~=2.7.0
!pip install google-cloud-storage~=1.36.2
!pip install transformers
!pip install "smart_open[gcs]"
!pip install func_timeout
!pip install ftfy
!pip install fastapi
!pip install uvicorn
!pip install lm_dataformat
!pip install pathy

!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.7.0 chex==0.1.2 jaxlib==0.1.68

JohnnyOpcode, Aug 09 '22 14:08
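A note on `!pip` lines like the ones above. In a notebook, `!` hands the line to a shell. An unquoted `tqdm>=4.45.0` is therefore parsed as `tqdm` plus an output redirect into a file named `=4.45.0`, and the version bound is silently dropped. Brackets such as `ray[default]` can also be treated as glob patterns. Quoting the specifier avoids both problems. Another option, sketched below as a hypothetical helper, is to call pip through `subprocess` with an argument list. This bypasses the shell entirely, and `sys.executable -m pip` makes pip target the same interpreter that is running the notebook.

```python
# Hypothetical helper (not part of the notebook): install requirements without
# going through a shell, so ">" and "[" in specifiers reach pip verbatim.
import subprocess
import sys

# A few of the pins from the list above (extend as needed).
REQUIREMENTS = [
    "tqdm>=4.45.0",
    "optax==0.0.9",
    "dm-haiku==0.0.5",
    "ray[default]==1.4.1",
]


def pip_install_command(requirements):
    """Build the argv list for installing requirements with this interpreter's pip."""
    return [sys.executable, "-m", "pip", "install", *requirements]


def pip_install(requirements):
    """Run pip; raises subprocess.CalledProcessError if the install fails."""
    subprocess.run(pip_install_command(requirements), check=True)


# Usage in a notebook cell (not executed here):
# pip_install(REQUIREMENTS)
```

Passing several requirements in one call lets pip resolve them together, which is not identical to running them one at a time as in the list above. Calling `pip_install([req])` in a loop keeps the original one-by-one order.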

I have successfully installed and run the application using the recommended dependencies, although two packages required a higher version.

jweber00, Oct 29 '22 14:10

Thank you so much @JohnnyOpcode! I was going nuts trying to figure out all the errors. Using your list, I was able to proceed and run the notebook to completion!

hovanesgasparian, Nov 01 '22 00:11

There is another recent issue where the wheel versions are tweaked further. GPT-J on JAX is brilliant work and just needs some love and attention when it comes to dependencies. Have a look at that issue; it may help further.

Bravo to @kingoflolz on this outstanding bit (pun) of work.

JohnnyOpcode, Nov 01 '22 05:11

Well, it's obvious that it needs attention. It's all anyone needs.

Aspie96, Nov 01 '22 05:11