mesh-transformer-jax

TypeError: Cannot subclass <class 'typing._SpecialForm'> while fine-tuning

Open · samyakai opened this issue 2 years ago · 9 comments

I am trying to fine-tune GPT-J on custom data using a TPU. When I run "device_train.py" with the command "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/", I get this error:

Traceback (most recent call last):
  File "device_train.py", line 13, in <module>
    from mesh_transformer import util
  File "/home/shreyjain/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
  File "/usr/lib/python3.8/typing.py", line 317, in __new__
    raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>

OS: Ubuntu 20.04
TPU: v3-8
Python: 3.8 and 3.7 both give the error

I have no idea what this error means. Any help would be appreciated! Thank you.
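A likely reading of the error (an assumption; nobody in this thread confirms it): line 36 of mesh_transformer/util.py defines `class ClipByGlobalNormState(OptState)`. That works while optax's `OptState` is an alias for `typing.NamedTuple`, but newer optax releases appear to define `OptState` as a `Union`-based tree type instead, and a `typing.Union` is not a class, so it cannot be subclassed. That would also explain why downgrading optax helps. The sketch below reproduces the failure with plain `typing` stand-ins; the names `NamedTupleOptState`, `UnionOptState`, `OldStyleState` and `NewStyleState` are hypothetical. On Python 3.8 the caught error should match the traceback above; newer Python versions phrase the `TypeError` differently.

```python
from typing import Any, Iterable, Mapping, NamedTuple, Union

# Hypothetical stand-ins for two ways a library might define OptState:
# as typing.NamedTuple itself, or as a Union-based "tree" alias.
NamedTupleOptState = NamedTuple
UnionOptState = Union[int, Iterable[Any], Mapping[Any, Any]]


# Subclassing typing.NamedTuple is ordinary NamedTuple class syntax.
class OldStyleState(NamedTupleOptState):
    count: int


print(OldStyleState(count=3))  # OldStyleState(count=3)

# Subclassing a Union alias fails, because a Union is not a class.
try:
    class NewStyleState(UnionOptState):  # mirrors ClipByGlobalNormState(OptState)
        count: int
    error = None
except TypeError as err:
    error = err

print(type(error).__name__, error)
```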

samyakai, Apr 23 '22

Getting the same issue and not able to solve it. Please help. Thank you.

jagruti-samyak, Apr 23 '22

Maybe this works.

Run pip install dm-haiku==0.0.5 and put optax back to its default version.

mosmos6, May 05 '22

@mosmos6 Nope, it doesn't work.

shrey10926, May 08 '22

https://github.com/kingoflolz/mesh-transformer-jax/issues/202#issuecomment-1050887576 Please follow this solution! It works.

anon-mouse-1, May 09 '22

@anon-mouse-1 Which v2 version of the TPU should I use? There are two options: the TPU VM architecture and the TPU Node architecture.

jagruti-samyak, May 09 '22

After the fifth error, I just gave up on this notebook.

Tylersuard, May 14 '22

@Tylersuard Yes. Also, no one is providing a solution to these errors, which is a shame, as I really want to train on a TPU rather than a GPU.

samyakai, May 16 '22

Downgrading optax worked for me to get rid of this error.

pip install optax==0.0.9
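Before rerunning device_train.py after a downgrade, one way to check whether the installed optax still defines `OptState` in a form that `class ClipByGlobalNormState(OptState)` can subclass is to attempt the same subclassing in isolation. This is a sketch: the helper `is_subclassable` is hypothetical and is not part of optax or mesh-transformer-jax.

```python
from typing import NamedTuple, Union


def is_subclassable(base) -> bool:
    """Return True if `class _Probe(base): pass` succeeds, False on TypeError."""
    try:
        class _Probe(base):  # same pattern as ClipByGlobalNormState(OptState)
            pass
    except TypeError:
        return False
    return True


if __name__ == "__main__":
    # Reference points from the standard library (no optax needed).
    print("typing.NamedTuple:", is_subclassable(NamedTuple))  # True
    print("typing.Union[int, str]:", is_subclassable(Union[int, str]))  # False

    try:
        import optax
    except ImportError:
        print("optax is not installed in this environment")
    else:
        print(f"optax {optax.__version__} OptState:",
              is_subclassable(optax.OptState))
```

Run it with the same interpreter that runs device_train.py. If it reports False for `optax.OptState`, util.py should fail at line 36 in the same way.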

dhruv2601, Jun 13 '22

#202 (comment) Please follow this solution! It works.

In addition to this:
pip install chex==0.1.2
pip install jaxlib==0.1.74
pip install dm-haiku==0.0.5

and it worked for me.
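To confirm that an environment actually matches the pins from this comment before retrying, a small check like the following may help. It is a sketch under stated assumptions: it compares exact version strings only (no compatibility ranges), and the pin list is copied from this comment, not from an official requirements file. The helper names `installed_version` and `mismatches` are hypothetical. It needs Python 3.8+ for `importlib.metadata`; on 3.7 the `importlib_metadata` backport provides the same `version` and `PackageNotFoundError`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List

# Pins reported in this comment. They are what one commenter says worked,
# not an official compatibility matrix. optax is not pinned here; an earlier
# comment in the thread suggests optax==0.0.9.
REPORTED_PINS: Dict[str, str] = {
    "chex": "0.1.2",
    "jaxlib": "0.1.74",
    "dm-haiku": "0.0.5",
}


def installed_version(dist: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return "not installed"


def mismatches(pins: Dict[str, str],
               get_version: Callable[[str], str] = installed_version) -> List[str]:
    """Describe every pin whose installed version string differs exactly."""
    problems = []
    for dist, wanted in pins.items():
        found = get_version(dist)
        if found != wanted:
            problems.append(f"{dist}: installed {found}, pinned {wanted}")
    return problems


if __name__ == "__main__":
    for line in mismatches(REPORTED_PINS) or ["all reported pins match"]:
        print(line)
```

Passing a custom `get_version` makes the comparison easy to test without touching the real environment.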

mosmos6, Jul 18 '22