mesh-transformer-jax
TypeError: Cannot subclass <class 'typing._SpecialForm'> while fine tuning
I am trying to fine-tune GPT-J on custom data using a TPU. When I run "device_train.py" with the suggested command, "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/", I get this error:
Traceback (most recent call last):
File "device_train.py", line 13, in
OS = Ubuntu 20.04
TPU = v3-8
Python version = 3.8 and 3.7 (both give the error)
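Every fix suggested in this thread is a package version pin, so a report like this is more useful with the installed versions of the JAX-related packages alongside the OS and Python version. A minimal sketch for collecting them, assuming the PyPI names jax, jaxlib, optax, dm-haiku and chex (the packages mentioned in this thread); the function name is hypothetical, not part of mesh-transformer-jax:

```python
import platform

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7: needs the importlib-metadata backport
    from importlib_metadata import PackageNotFoundError, version

# Packages whose versions come up in this thread (assumed PyPI names).
PACKAGES = ["jax", "jaxlib", "optax", "dm-haiku", "chex"]


def environment_report(packages=PACKAGES):
    """Return a dict with the Python version, OS string and package versions."""
    report = {
        "python": platform.python_version(),
        "os": platform.platform(),
    }
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    # Prints one "key = value" line per entry, matching the style above.
    for key, value in environment_report().items():
        print(f"{key} = {value}")
```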
I have no idea what this error means. Any help would be appreciated! Thank you.
Getting the same issue and not able to solve it. Please help us. Thank you.
Maybe this works.
pip install dm-haiku==0.0.5
and set optax back to its default version.
@mosmos6 Nope, it doesn't work.
https://github.com/kingoflolz/mesh-transformer-jax/issues/202#issuecomment-1050887576 Please follow this solution! It works.
@anon-mouse-1 Which v2 TPU should I use? There are two options: the TPU VM architecture and the TPU Node architecture.
After the 5th error I just gave up on this notebook.
@Tylersuard Yes. Also, no one is providing a solution to these errors, which is a shame, as I really want to train on a TPU rather than a GPU.
Downgrading optax worked for me to get rid of this error.
pip install optax==0.0.9
#202 (comment) Please follow this solution! It works.
In addition to this,
pip install chex==0.1.2
pip install jaxlib==0.1.74
pip install dm-haiku==0.0.5
and it worked for me.
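Before rerunning device_train.py, it can help to confirm that the environment actually matches the combination reported to work here (optax 0.0.9, chex 0.1.2, jaxlib 0.1.74, dm-haiku 0.0.5). A minimal sketch, assuming exact version-string matching is what you want; these pins come from user reports in this thread, not an official compatibility matrix, and the helper names are hypothetical:

```python
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7: needs the importlib-metadata backport
    from importlib_metadata import PackageNotFoundError, version

# Versions reported to work in this thread (not an official pin list).
PINNED = {
    "optax": "0.0.9",
    "chex": "0.1.2",
    "jaxlib": "0.1.74",
    "dm-haiku": "0.0.5",
}


def find_mismatches(pins=PINNED, get_version=version):
    """Return {package: (installed_or_None, expected)} for every pin not met."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None  # package is missing entirely
        if installed != expected:
            mismatches[name] = (installed, expected)
    return mismatches


def pip_command(mismatches):
    """Build one pip command that installs the expected versions, or None."""
    if not mismatches:
        return None
    pins = " ".join(
        f"{name}=={expected}" for name, (_, expected) in sorted(mismatches.items())
    )
    return f"pip install {pins}"


if __name__ == "__main__":
    cmd = pip_command(find_mismatches())
    print(cmd or "All pinned versions already installed.")
```

The get_version parameter is only there so the check can be exercised without touching the real environment; in normal use the default from importlib.metadata is used.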