
Issues running CORD inference

Open · flauted opened this issue 1 year ago · 3 comments

I'm running into `RuntimeError: Error(s) in loading state_dict for DonutModel: size mismatch` and ``NotImplementedError: Make sure `_init_weights` is implemented for <class 'donut.model.DonutModel'>`` errors while trying to run the provided pre-trained Donut model on CORD. Possibly related: #184 and #29.

For me, the issues appeared regardless of whether I used the Hugging Face download (`--pretrained_model_name_or_path naver-clova-ix/donut-base-finetuned-cord-v2`) or cloned the `official` branch and passed the local path.

I was using the PyPI package (`pip install donut-python`), not installing from source.

For me, it came down to the timm version. By default, `timm==0.9.2` currently gets installed, but `timm==0.6.13` appears to be the last non-prerelease version of timm on PyPI that works. `timm==0.9.0` fails with `ImportError: cannot import name 'Final' from 'typing'` (`typing.Final` only exists from Python 3.8 onward, and this environment uses Python 3.7), and the size-mismatch error appears starting with `timm==0.9.1`.
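To make the version boundaries above easier to check, here is a minimal Python sketch that classifies a timm version string using only what was observed in this thread (donut-python 1.0.9 on Python 3.7). The helper names `parse_version` and `diagnose_timm` are hypothetical, not part of donut or timm, and the ranges are not an official compatibility table; versions nobody reported on come back as "untested".

```python
# Sketch: classify a timm version using only what was reported in this issue
# (donut-python 1.0.9, Python 3.7). Not an official compatibility table.
import re


def parse_version(ver):
    """Hypothetical helper: keep leading numeric parts, padded to 3 ('0.6.13' -> (0, 6, 13))."""
    parts = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)  # '0rc1' -> '0', 'dev0' -> no match
        if match is None:
            break
        parts.append(int(match.group()))
    parts += [0] * (3 - len(parts))
    return tuple(parts[:3])


def diagnose_timm(ver):
    """Hypothetical helper: map a timm version to the outcome reported in this thread."""
    v = parse_version(ver)
    if (0, 5, 4) <= v <= (0, 6, 16):
        return "reported to load"
    if v == (0, 9, 0):
        return "ImportError: 'Final' missing from typing (Python 3.7)"
    if v >= (0, 9, 1):
        return "state_dict size mismatch"
    return "untested in this thread"


for candidate in ["0.5.4", "0.6.13", "0.9.0", "0.9.2", "0.8.13.dev0"]:
    print(candidate, "->", diagnose_timm(candidate))
```

Note that "reported to load" only means the checkpoint loads; as the update further down shows, accuracy also depends on the transformers version.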

So the minimal steps to fix it were:

```shell
conda create -n donutv3 python=3.7 pip
conda activate donutv3
pip install donut-python
pip install timm==0.5.4  # up to 0.6.16
python --dataset_name_or_path naver-clova-ix/cord-v2 --pretrained_model_name_or_path naver-clova-ix/donut-base-finetuned-cord-v2 --save_path ./result/output.json
# works
git clone -b official --single-branch
python --dataset_name_or_path naver-clova-ix/cord-v2 --pretrained_model_name_or_path ./donut-base-finetuned-cord-v2 --save_path ./result/output.json
# works
```

To resolve these issues, I think the timm version should be pinned in the package requirements. However, I haven't tested anything else in the codebase (other downstream tasks or training).

At first, I tried downgrading transformers from the current default of 4.29.1 to 4.25.1, as suggested in this comment. That changed the error from the size mismatch to the `_init_weights` error.

Current output from `conda list` (EDIT: this environment gives poor performance; see the update below):
```
# packages in environment at /home/dflaute/miniconda3/envs/donutv3:
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
aiohttp                   3.8.4                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
async-timeout             4.0.2                    pypi_0    pypi
asynctest                 0.13.0                   pypi_0    pypi
attrs                     23.1.0                   pypi_0    pypi
ca-certificates           2023.01.10           h06a4308_0
certifi                   2022.12.7        py37h06a4308_0
charset-normalizer        3.1.0                    pypi_0    pypi
click                     8.1.3                    pypi_0    pypi
datasets                  2.12.0                   pypi_0    pypi
dill                      0.3.6                    pypi_0    pypi
donut-python              1.0.9                    pypi_0    pypi
filelock                  3.12.0                   pypi_0    pypi
frozenlist                1.3.3                    pypi_0    pypi
fsspec                    2023.1.0                 pypi_0    pypi
huggingface-hub           0.14.1                   pypi_0    pypi
idna                      3.4                      pypi_0    pypi
importlib-metadata        6.6.0                    pypi_0    pypi
joblib                    1.2.0                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1
libffi                    3.4.4                h6a678d5_0
libgcc-ng                 11.2.0               h1234567_1
libgomp                   11.2.0               h1234567_1
libstdcxx-ng              11.2.0               h1234567_1
lightning-utilities       0.8.0                    pypi_0    pypi
multidict                 6.0.4                    pypi_0    pypi
multiprocess              0.70.14                  pypi_0    pypi
munch                     3.0.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
nltk                      3.8.1                    pypi_0    pypi
numpy                     1.21.6                   pypi_0    pypi
nvidia-cublas-cu11               pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.7.99                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
nvidia-cudnn-cu11                 pypi_0    pypi
openssl                   1.1.1t               h7f8727e_0
packaging                 23.1                     pypi_0    pypi
pandas                    1.3.5                    pypi_0    pypi
pillow                    9.5.0                    pypi_0    pypi
pip                       22.3.1           py37h06a4308_0
pyarrow                   12.0.0                   pypi_0    pypi
python                    3.7.16               h7a1cb2a_0
python-dateutil           2.8.2                    pypi_0    pypi
pytorch-lightning         1.9.5                    pypi_0    pypi
pytz                      2023.3                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
readline                  8.2                  h5eee18b_0
regex                     2023.5.5                 pypi_0    pypi
requests                  2.30.0                   pypi_0    pypi
responses                 0.18.0                   pypi_0    pypi
ruamel-yaml               0.17.26                  pypi_0    pypi
ruamel-yaml-clib          0.2.7                    pypi_0    pypi
safetensors               0.3.1                    pypi_0    pypi
sconf                     0.2.5                    pypi_0    pypi
sentencepiece             0.1.99                   pypi_0    pypi
setuptools                65.6.3           py37h06a4308_0
six                       1.16.0                   pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0
timm                      0.5.4                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0
tokenizers                0.13.3                   pypi_0    pypi
torch                     1.13.1                   pypi_0    pypi
torchmetrics              0.11.4                   pypi_0    pypi
torchvision               0.14.1                   pypi_0    pypi
tqdm                      4.65.0                   pypi_0    pypi
transformers              4.29.1                   pypi_0    pypi
typing-extensions         4.5.0                    pypi_0    pypi
urllib3                   2.0.2                    pypi_0    pypi
wheel                     0.38.4           py37h06a4308_0
xxhash                    3.2.0                    pypi_0    pypi
xz                        5.4.2                h5eee18b_0
yarl                      1.9.2                    pypi_0    pypi
zipp                      3.15.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0
zss                       1.2.0                    pypi_0    pypi
```

Update: In fact, the test results are terrible with that install:

Total number of samples: 100, Tree Edit Distance (TED) based accuracy score: 0.17447007822672653, F1 accuracy score: 0.12529832935560858

Installing `transformers==4.25.1` on top of that environment fixes this:

Total number of samples: 100, Tree Edit Distance (TED) based accuracy score: 0.910367621155524, F1 accuracy score: 0.8373353989155693

flauted, May 16 '23 15:05

Thank you. You saved my life!

YuffieHuang, May 16 '23 17:05

Thank you. You saved my life!

PirateX0, May 17 '23 02:05

It's hard to believe that in all of the documentation I searched through, I couldn't find this fix anywhere else, not even in the Google Colabs, which are the most recently updated docs. Thank you for this.

cocoa004, Jun 25 '24 22:06