Issues running CORD inference
I'm running into two errors while trying to run the provided pre-trained Donut model on CORD:
RuntimeError: Error(s) in loading state_dict for DonutModel: size mismatch
NotImplementedError: Make sure `_init_weights` is implemented for <class 'donut.model.DonutModel'>
Possibly related: #184 and #29
For me, the issues appeared regardless of whether I used the Hugging Face download (--pretrained_model_name_or_path naver-clova-ix/donut-base-finetuned-cord-v2) or cloned the official branch and specified the path.
I was using the PyPI package (pip install donut-python), not installing from source.
For me, it came down to the timm version. By default, timm==0.9.2 is currently installed, but timm==0.6.13 appears to be the last non-prerelease version on PyPI that works. timm==0.9.0 gives ImportError: cannot import name 'Final' from 'typing', and the size mismatch error appears starting with timm==0.9.1.
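The version findings above can be written down as a small lookup, which may be handy when triaging a broken environment. This is only a sketch of what was observed in this report (Python 3.7, donut-python from PyPI); the function names are hypothetical, and versions outside the ones mentioned here are reported as untested rather than guessed.

```python
import re


def parse_version(version):
    """Turn '0.6.13' into (0, 6, 13); any pre-release suffix is ignored."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def expected_timm_behavior(version):
    """Map a timm version to the behavior observed in this report."""
    v = parse_version(version)
    if (0, 5, 4) <= v <= (0, 6, 13):
        return "works"  # checkpoint loads without error
    if v == (0, 9, 0):
        return "ImportError"  # cannot import name 'Final' from 'typing'
    if v >= (0, 9, 1):
        return "size mismatch"  # state_dict fails to load
    return "untested"  # not covered by this report


print(expected_timm_behavior("0.9.2"))
```

Note that "works" here only means the checkpoint loads; it says nothing about accuracy.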
So the minimal steps to fix it were:
conda create -n donutv3 python=3.7 pip
conda activate donutv3
pip install donut-python
pip install timm==0.5.4 # up to 0.6.13
python test.py --dataset_name_or_path naver-clova-ix/cord-v2 --pretrained_model_name_or_path naver-clova-ix/donut-base-finetuned-cord-v2 --save_path ./result/output.json
# works
git clone -b official --single-branch https://huggingface.co/naver-clova-ix/donut-base-finetuned-cord-v2
python test.py --dataset_name_or_path naver-clova-ix/cord-v2 --pretrained_model_name_or_path ./donut-base-finetuned-cord-v2 --save_path ./result/output.json
# works
To resolve these issues, I think the timm version should be restricted in setup.py (https://github.com/clovaai/donut/blob/master/setup.py#L54). However, I haven't tested anything else in the codebase (other downstream tasks or training).
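As a rough illustration of what such a restriction might look like, here is a hypothetical fragment for setup.py. Only the timm entry reflects the findings in this report; the upper bound is the last version observed to work, no lower bound is set because older versions were not tested, and the project's other requirements are deliberately left out rather than guessed.

```python
# Hypothetical fragment, not the project's actual setup.py.
install_requires = [
    "timm<=0.6.13",  # last non-prerelease version observed to load the CORD checkpoint
    # the project's other requirements would stay as they are
]

print(install_requires)
```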
At first, I tried changing the version of transformers down from the current default of 4.29.1 to 4.25.1 as suggested in this comment. That changed the error from the size mismatch to the _init_weights error.
Current output from conda list (EDIT: this environment gives poor performance, see below):
# packages in environment at /home/dflaute/miniconda3/envs/donutv3:
#
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
aiohttp 3.8.4 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
async-timeout 4.0.2 pypi_0 pypi
asynctest 0.13.0 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
ca-certificates 2023.01.10 h06a4308_0
certifi 2022.12.7 py37h06a4308_0
charset-normalizer 3.1.0 pypi_0 pypi
click 8.1.3 pypi_0 pypi
datasets 2.12.0 pypi_0 pypi
dill 0.3.6 pypi_0 pypi
donut-python 1.0.9 pypi_0 pypi
filelock 3.12.0 pypi_0 pypi
frozenlist 1.3.3 pypi_0 pypi
fsspec 2023.1.0 pypi_0 pypi
huggingface-hub 0.14.1 pypi_0 pypi
idna 3.4 pypi_0 pypi
importlib-metadata 6.6.0 pypi_0 pypi
joblib 1.2.0 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
lightning-utilities 0.8.0 pypi_0 pypi
multidict 6.0.4 pypi_0 pypi
multiprocess 0.70.14 pypi_0 pypi
munch 3.0.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nltk 3.8.1 pypi_0 pypi
numpy 1.21.6 pypi_0 pypi
nvidia-cublas-cu11 11.10.3.66 pypi_0 pypi
nvidia-cuda-nvrtc-cu11 11.7.99 pypi_0 pypi
nvidia-cuda-runtime-cu11 11.7.99 pypi_0 pypi
nvidia-cudnn-cu11 8.5.0.96 pypi_0 pypi
openssl 1.1.1t h7f8727e_0
packaging 23.1 pypi_0 pypi
pandas 1.3.5 pypi_0 pypi
pillow 9.5.0 pypi_0 pypi
pip 22.3.1 py37h06a4308_0
pyarrow 12.0.0 pypi_0 pypi
python 3.7.16 h7a1cb2a_0
python-dateutil 2.8.2 pypi_0 pypi
pytorch-lightning 1.9.5 pypi_0 pypi
pytz 2023.3 pypi_0 pypi
pyyaml 6.0 pypi_0 pypi
readline 8.2 h5eee18b_0
regex 2023.5.5 pypi_0 pypi
requests 2.30.0 pypi_0 pypi
responses 0.18.0 pypi_0 pypi
ruamel-yaml 0.17.26 pypi_0 pypi
ruamel-yaml-clib 0.2.7 pypi_0 pypi
safetensors 0.3.1 pypi_0 pypi
sconf 0.2.5 pypi_0 pypi
sentencepiece 0.1.99 pypi_0 pypi
setuptools 65.6.3 py37h06a4308_0
six 1.16.0 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0
timm 0.5.4 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
tokenizers 0.13.3 pypi_0 pypi
torch 1.13.1 pypi_0 pypi
torchmetrics 0.11.4 pypi_0 pypi
torchvision 0.14.1 pypi_0 pypi
tqdm 4.65.0 pypi_0 pypi
transformers 4.29.1 pypi_0 pypi
typing-extensions 4.5.0 pypi_0 pypi
urllib3 2.0.2 pypi_0 pypi
wheel 0.38.4 py37h06a4308_0
xxhash 3.2.0 pypi_0 pypi
xz 5.4.2 h5eee18b_0
yarl 1.9.2 pypi_0 pypi
zipp 3.15.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
zss 1.2.0 pypi_0 pypi
Update: In fact, the test results are terrible with that install:
Total number of samples: 100, Tree Edit Distance (TED) based accuracy score: 0.17447007822672653, F1 accuracy score: 0.12529832935560858
Downgrading to transformers==4.25.1 (while keeping timm==0.5.4) fixes this:
Total number of samples: 100, Tree Edit Distance (TED) based accuracy score: 0.910367621155524, F1 accuracy score: 0.8373353989155693
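Since the wrong transformers version fails silently (no error, just poor scores), it may be worth sanity-checking the summary line that test.py prints. The sketch below parses the two summary lines quoted above; the function names and the 0.8 threshold are assumptions chosen for illustration, not part of donut.

```python
import re

# Matches the score summary line exactly as it appears in the output above.
SCORE_PATTERN = re.compile(
    r"Tree Edit Distance \(TED\) based accuracy score: (?P<ted>[0-9.]+), "
    r"F1 accuracy score: (?P<f1>[0-9.]+)"
)


def parse_scores(line):
    """Return (ted, f1) as floats from a test.py summary line."""
    match = SCORE_PATTERN.search(line)
    if match is None:
        raise ValueError("no score summary found in line")
    return float(match.group("ted")), float(match.group("f1"))


def looks_degraded(line, min_ted=0.8):
    """Flag a run whose TED score is below a hypothetical threshold."""
    ted, _ = parse_scores(line)
    return ted < min_ted


bad_run = ("Total number of samples: 100, Tree Edit Distance (TED) based accuracy "
           "score: 0.17447007822672653, F1 accuracy score: 0.12529832935560858")
good_run = ("Total number of samples: 100, Tree Edit Distance (TED) based accuracy "
            "score: 0.910367621155524, F1 accuracy score: 0.8373353989155693")

print(looks_degraded(bad_run))   # True
print(looks_degraded(good_run))  # False
```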
Thank you. You saved my life!
It's hard to believe that in all of the documentation I searched through, I could not find this fix anywhere else, not even in the Google Colabs, which are the most recently updated docs. Thank you for this.