
Bug in t5x to PyTorch weights conversion script

Open rinapch opened this issue 2 years ago

System Info

transformers version: 4.26.1
Platform: Ubuntu 20.04.5 LTS (Focal Fossa)
Python version: 3.8
Huggingface_hub version: 0.12.1
PyTorch version (GPU?): 1.13.1 (True)
Tensorflow version (GPU?): not installed (NA)
Flax version (CPU?/GPU?/TPU?): 0.6.6 (GPU)
Jax version: 0.4.5
JaxLib version: 0.4.4

Who can help?

@sgugger

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

This uses the official example script transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py:

  1. gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/
  2. python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json --pytorch_dump_path=$HOME/t5_1_1_small_pt

Where config.json is a config for t5-small (https://huggingface.co/t5-small/blob/main/config.json)

When running this, I get an error:

Traceback (most recent call last):
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 231, in <module>
    convert_t5x_checkpoint_to_pytorch(
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 200, in convert_t5x_checkpoint_to_pytorch
    load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 181, in load_t5x_weights_in_t5
    state_dict = make_state_dict(converted, is_encoder_only)
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 160, in make_state_dict
    state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 160, in <listcomp>
    state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
TypeError: expected np.ndarray (got Array)

This can be fixed easily by importing numpy and changing line 160 to:

    state_dict = collections.OrderedDict([(k, torch.from_numpy(np.array(v.copy()))) for (k, v) in converted_params.items()])
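
For context, a minimal self-contained sketch of this workaround (converted_params and the dummy value below are hypothetical stand-ins for the dict built by the script, only there to make the snippet runnable):

    import collections

    import jax.numpy as jnp
    import numpy as np
    import torch

    # Hypothetical stand-in for the parameters produced by the conversion script.
    converted_params = {"shared.weight": jnp.ones((8, 4), dtype=jnp.float32)}

    # torch.from_numpy only accepts np.ndarray, not jax.Array,
    # so materialize each value as a plain NumPy array first.
    state_dict = collections.OrderedDict(
        [(k, torch.from_numpy(np.array(v.copy()))) for (k, v) in converted_params.items()]
    )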

Expected behavior

After converting v to np.array(v), the script executes fine and prints:

All model checkpoint weights were used when initializing T5ForConditionalGeneration.

All the weights of T5ForConditionalGeneration were initialized from the model checkpoint at /root/t5_1_1_small_pt.
If your task is similar to the task the model of the checkpoint was trained on, you can already use T5ForConditionalGeneration for predictions without further training.
loading configuration file /root/t5_1_1_small_pt/generation_config.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}

Done

rinapch avatar Mar 07 '23 19:03 rinapch

cc @ArthurZucker and @younesbelkada

sgugger avatar Mar 07 '23 19:03 sgugger

Hello @rinapch, thanks for the issue. We used the same script to convert flan-ul2 and did not face any issues. Can you share with us the t5x version you used?

younesbelkada avatar Mar 08 '23 08:03 younesbelkada

This is related to an update of jax and jax.numpy: as reported, torch.FloatTensor(weights["token_embedder"]["embedding"]) does not work anymore. Will have a look at the broader impact this has on our codebase. Thanks for reporting!
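
To illustrate the type mismatch, a small self-contained sketch (the weights structure and shape below are hypothetical stand-ins for the t5x checkpoint keys mentioned above, and this is one possible workaround, not necessarily the fix that will land in the script):

    import jax.numpy as jnp
    import numpy as np
    import torch

    # Hypothetical stand-in for the t5x checkpoint structure.
    weights = {"token_embedder": {"embedding": jnp.ones((8, 4), dtype=jnp.float32)}}

    # Passing the jax.Array directly to torch.FloatTensor now fails,
    # but converting it to a NumPy array first still works.
    embedding = torch.FloatTensor(np.asarray(weights["token_embedder"]["embedding"]))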

ArthurZucker avatar Mar 08 '23 09:03 ArthurZucker

Hey @younesbelkada! As far as I know, t5x does not really release versions (their version.py still states "0.0.0" - https://github.com/google-research/t5x/blob/main/t5x/version.py). I built the t5x module from a clone of their repo, which I made on Monday, so the code is up to date.

rinapch avatar Mar 08 '23 10:03 rinapch

Hi @rinapch, can you try:

pip install git+https://github.com/google-research/t5x@45c1a9d02321afeadb43f496de83c52421f52d66

This is the version of t5x that worked fine on my setup.

younesbelkada avatar Mar 15 '23 19:03 younesbelkada

I repeated the steps with this version and get the following error:

File "convert_t5x_checkpoint_to_pytorch.py", line 36, in from t5x import checkpoints File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/init.py", line 17, in import t5x.adafactor File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/adafactor.py", line 63, in from t5x import utils File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/utils.py", line 46, in from t5x import checkpoints File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/checkpoints.py", line 160, in orbax.checkpoint.utils.register_ts_spec_for_serialization() AttributeError: module 'orbax.checkpoint.utils' has no attribute 'register_ts_spec_for_serialization'

rinapch avatar Mar 21 '23 11:03 rinapch

@rinapch Can you try installing orbax from this commit?

pip install git+https://github.com/google/orbax@4ca7a3b61081e91323c89cf09f8c1a53c06cccda

younesbelkada avatar Mar 21 '23 12:03 younesbelkada

This worked, yep!

rinapch avatar Mar 21 '23 12:03 rinapch

Awesome, feel free to close the issue. To summarize, the fix was to:

pip install git+https://github.com/google-research/t5x@45c1a9d02321afeadb43f496de83c52421f52d66
pip install git+https://github.com/google/orbax@4ca7a3b61081e91323c89cf09f8c1a53c06cccda

younesbelkada avatar Mar 21 '23 12:03 younesbelkada