transformers
Bug in t5x to PyTorch weights conversion script
System Info
transformers version: 4.26.1
Platform: Ubuntu 20.04.5 LTS (Focal Fossa)
Python version: 3.8
Huggingface_hub version: 0.12.1
PyTorch version (GPU?): 1.13.1 (True)
Tensorflow version (GPU?): not installed (NA)
Flax version (CPU?/GPU?/TPU?): 0.6.6 (GPU)
Jax version: 0.4.5
JaxLib version: 0.4.4
Who can help?
@sgugger
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
This is the official example usage of the script transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py:
gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json --pytorch_dump_path=$HOME/t5_1_1_small_pt
Where config.json is a config for t5-small (https://huggingface.co/t5-small/blob/main/config.json)
When running this, I get an error:
Traceback (most recent call last):
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 231, in <module>
    convert_t5x_checkpoint_to_pytorch(
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 200, in convert_t5x_checkpoint_to_pytorch
    load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 181, in load_t5x_weights_in_t5
    state_dict = make_state_dict(converted, is_encoder_only)
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 160, in make_state_dict
    state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
  File "/root/transformers/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py", line 160, in <listcomp>
    state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
TypeError: expected np.ndarray (got Array)
This can be fixed easily by importing numpy and changing line 160 to:
state_dict = collections.OrderedDict([(k, torch.from_numpy(np.array(v.copy()))) for (k, v) in converted_params.items()])
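To illustrate why this fixes the error: torch.from_numpy() only accepts a real np.ndarray, and newer jax versions return jax.Array objects, which are not ndarray subclasses. Materializing each value with np.array() (or np.asarray()) before handing it to torch.from_numpy() resolves the TypeError. A minimal self-contained sketch of the pattern, with plain nested lists standing in for the jax.Array values (so neither jax nor torch is needed to run it):

```python
import collections

import numpy as np

# Stand-ins for the script's `converted_params`; in the real script the
# values are jax.Array objects, which torch.from_numpy() rejects with
# "TypeError: expected np.ndarray (got Array)". The key names here are
# illustrative, not the actual T5 parameter names.
converted_params = {
    "encoder.weight": [[0.1, 0.2], [0.3, 0.4]],
    "decoder.weight": [[0.5, 0.6], [0.7, 0.8]],
}

# np.asarray() converts any array-like (jax.Array included) into a real
# np.ndarray, which torch.from_numpy() then accepts without complaint.
state_dict = collections.OrderedDict(
    (k, np.asarray(v)) for k, v in converted_params.items()
)

for v in state_dict.values():
    assert isinstance(v, np.ndarray)  # now safe for torch.from_numpy(v)
```

The same idea is what the proposed one-line change does: wrapping v.copy() in np.array() inside the comprehension, so every value reaching torch.from_numpy() is a genuine ndarray.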
Expected behavior
After converting v to np.array(v), the script executes fine and returns
All model checkpoint weights were used when initializing T5ForConditionalGeneration.
All the weights of T5ForConditionalGeneration were initialized from the model checkpoint at /root/t5_1_1_small_pt.
If your task is similar to the task the model of the checkpoint was trained on, you can already use T5ForConditionalGeneration for predictions without further training.
loading configuration file /root/t5_1_1_small_pt/generation_config.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}
Done
cc @ArthurZucker and @younesbelkada
hello @rinapch
Thanks for the issue,
we used the same script to convert flan-ul2 and did not run into any issues. Can you share with us the t5x version you used?
This is related to an update of jax and jax.numpy. As reported, torch.FloatTensor(weights["token_embedder"]["embedding"]) does not work anymore. Will have a look at the broader impact this has on our codebase. Thanks for reporting!
Hey @younesbelkada! As far as I know, t5x does not really release versions (their version.py still states "0.0.0" - https://github.com/google-research/t5x/blob/main/t5x/version.py). I used a clone of their repo to build the t5x module, and I cloned it on Monday, so the code is up to date.
hi @rinapch can you try:
pip install git+https://github.com/google-research/t5x@45c1a9d02321afeadb43f496de83c52421f52d66
this is the version of t5x that worked fine on my setup
Repeated the steps with this version and I get the following error:
File "convert_t5x_checkpoint_to_pytorch.py", line 36, in <module>
  from t5x import checkpoints
File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/__init__.py", line 17, in <module>
  import t5x.adafactor
File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/adafactor.py", line 63, in <module>
  from t5x import utils
File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/utils.py", line 46, in <module>
  from t5x import checkpoints
File "/root/.cache/pypoetry/virtualenvs/chatbot-JrwxGvoq-py3.8/lib/python3.8/site-packages/t5x/checkpoints.py", line 160, in <module>
  orbax.checkpoint.utils.register_ts_spec_for_serialization()
AttributeError: module 'orbax.checkpoint.utils' has no attribute 'register_ts_spec_for_serialization'
@rinapch
Can you try with: orbax @ git+https://github.com/google/orbax@4ca7a3b61081e91323c89cf09f8c1a53c06cccda ?
pip install git+https://github.com/google/orbax@4ca7a3b61081e91323c89cf09f8c1a53c06cccda
This worked, yep!
Awesome, feel free to close the issue. So the fix was to:
pip install git+https://github.com/google-research/t5x@45c1a9d02321afeadb43f496de83c52421f52d66
pip install git+https://github.com/google/orbax@4ca7a3b61081e91323c89cf09f8c1a53c06cccda