Can't calculate combined metric (WER + CER)
Hello. Reproduction code:
import evaluate
asr_metrics = evaluate.combine(["wer","cer"])
predictions = ["this is the prediction", "there is an other sample"]
references = ["this is the reference", "there is another one"]
asr_metrics.compute(predictions=predictions, references=references)
Code output:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[15], line 4
1 predictions = ["this is the prediction", "there is an other sample"]
2 references = ["this is the reference", "there is another one"]
----> 4 asr_metrics.compute(predictions=predictions, references=references)
File ~/.local/lib/python3.10/site-packages/evaluate/module.py:976, in CombinedEvaluations.compute(self, predictions, references, **kwargs)
973 batch = {"predictions": predictions, "references": references, **kwargs}
974 results.append(evaluation_module.compute(**batch))
--> 976 return self._merge_results(results)
File ~/.local/lib/python3.10/site-packages/evaluate/module.py:980, in CombinedEvaluations._merge_results(self, results)
978 def _merge_results(self, results):
979 merged_results = {}
--> 980 results_keys = list(itertools.chain.from_iterable([r.keys() for r in results]))
981 duplicate_keys = {item for item, count in collections.Counter(results_keys).items() if count > 1}
983 duplicate_names = [
984 item for item, count in collections.Counter(self.evaluation_module_names).items() if count > 1
985 ]
File ~/.local/lib/python3.10/site-packages/evaluate/module.py:980, in <listcomp>(.0)
978 def _merge_results(self, results):
979 merged_results = {}
--> 980 results_keys = list(itertools.chain.from_iterable([r.keys() for r in results]))
981 duplicate_keys = {item for item, count in collections.Counter(results_keys).items() if count > 1}
983 duplicate_names = [
984 item for item, count in collections.Counter(self.evaluation_module_names).items() if count > 1
985 ]
AttributeError: 'float' object has no attribute 'keys'
Python version:
Python 3.10.12
Library versions:
Package Version
------------------------- -------------
aiofiles 23.2.1
aiohttp 3.9.0
aiosignal 1.3.1
altair 5.1.2
annotated-types 0.6.0
anyio 3.7.1
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asttokens 2.4.1
async-lru 2.0.4
async-timeout 4.0.3
attrs 23.1.0
audioread 3.0.1
Babel 2.13.1
beautifulsoup4 4.12.2
bleach 6.1.0
blinker 1.4
certifi 2023.11.17
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
colorama 0.4.6
comm 0.2.0
command-not-found 0.3
contourpy 1.2.0
cryptography 3.4.8
cycler 0.12.1
datasets 2.15.0
dbus-python 1.2.18
debugpy 1.8.0
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.7
distro 1.7.0
distro-info 1.1+ubuntu0.1
et-xmlfile 1.1.0
evaluate 0.4.1
exceptiongroup 1.1.3
executing 2.0.1
fastapi 0.104.1
fastjsonschema 2.19.0
ffmpy 0.3.1
filelock 3.13.1
fonttools 4.44.3
fqdn 1.5.1
frozenlist 1.4.0
fsspec 2023.10.0
gradio 4.4.1
gradio_client 0.7.0
h11 0.14.0
httpcore 1.0.2
httplib2 0.20.2
httpx 0.25.1
huggingface-hub 0.19.4
idna 3.4
importlib-metadata 4.6.4
importlib-resources 6.1.1
ipykernel 6.26.0
ipython 8.17.2
ipywidgets 8.1.1
isoduration 20.11.0
jedi 0.19.1
jeepney 0.7.1
Jinja2 3.1.2
jiwer 3.0.3
joblib 1.3.2
json5 0.9.14
jsonpointer 2.4
jsonschema 4.20.0
jsonschema-specifications 2023.11.1
jupyter_client 8.6.0
jupyter_core 5.5.0
jupyter-events 0.9.0
jupyter-lsp 2.2.0
jupyter_server 2.10.1
jupyter_server_terminals 0.4.4
jupyterlab 4.0.9
jupyterlab-pygments 0.2.2
jupyterlab_server 2.25.2
jupyterlab-widgets 3.0.9
keyring 23.5.0
kiwisolver 1.4.5
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
lazy_loader 0.3
librosa 0.10.1
llvmlite 0.41.1
markdown-it-py 3.0.0
MarkupSafe 2.1.3
matplotlib 3.8.2
matplotlib-inline 0.1.6
mdurl 0.1.2
mistune 3.0.2
more-itertools 8.10.0
mpmath 1.3.0
msgpack 1.0.7
multidict 6.0.4
multiprocess 0.70.15
nbclient 0.9.0
nbconvert 7.11.0
nbformat 5.9.2
nest-asyncio 1.5.8
netifaces 0.11.0
networkx 3.2.1
notebook_shim 0.2.3
numba 0.58.1
numpy 1.26.2
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.0
openpyxl 3.1.2
orjson 3.9.10
overrides 7.4.0
packaging 23.2
pandas 2.1.3
pandocfilters 1.5.0
parso 0.8.3
pexpect 4.8.0
Pillow 10.1.0
pip 23.3.1
platformdirs 4.0.0
plotly 5.18.0
pooch 1.8.0
prometheus-client 0.18.0
prompt-toolkit 3.0.41
psutil 5.9.6
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 14.0.1
pyarrow-hotfix 0.5
pycparser 2.21
pydantic 2.5.1
pydantic_core 2.14.3
pydub 0.25.1
Pygments 2.17.1
PyGObject 3.42.1
PyJWT 2.3.0
pyparsing 2.4.7
python-apt 2.4.0+ubuntu2
python-dateutil 2.8.2
python-json-logger 2.0.7
python-multipart 0.0.6
pytz 2023.3.post1
PyYAML 5.4.1
pyzmq 25.1.1
rapidfuzz 3.5.2
referencing 0.31.0
regex 2023.10.3
requests 2.31.0
responses 0.18.0
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.0
rpds-py 0.13.0
safetensors 0.4.0
scikit-learn 1.3.2
scipy 1.11.4
SecretStorage 3.3.1
semantic-version 2.10.0
Send2Trash 1.8.2
setuptools 59.6.0
shellingham 1.5.4
six 1.16.0
sniffio 1.3.0
soundfile 0.12.1
soupsieve 2.5
soxr 0.3.7
stack-data 0.6.3
starlette 0.27.0
sympy 1.12
systemd-python 234
tenacity 8.2.3
terminado 0.18.0
threadpoolctl 3.2.0
tinycss2 1.2.1
tokenizers 0.15.0
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.0
torch 2.1.1
torchaudio 2.1.1
torchvision 0.16.1
tornado 6.3.3
tqdm 4.66.1
traitlets 5.13.0
transformers 4.35.2
triton 2.1.0
typer 0.9.0
types-python-dateutil 2.8.19.14
typing_extensions 4.8.0
tzdata 2023.3
ubuntu-advantage-tools 8001
ufw 0.36.1
unattended-upgrades 0.1
uri-template 1.3.0
urllib3 2.1.0
uvicorn 0.24.0.post1
wadllib 1.3.6
wcwidth 0.2.10
webcolors 1.13
webencodings 0.5.1
websocket-client 1.6.4
websockets 11.0.3
wheel 0.37.1
widgetsnbextension 4.0.9
xxhash 3.4.1
yarl 1.9.2
zipp 1.0.0
My current hypothesis is that the problem lies in the data type returned by the metrics. Metrics that combine without issues return a dictionary:
m_accuracy = evaluate.load("accuracy")
m_recall = evaluate.load("recall")
m_precision = evaluate.load("precision")
m_f1 = evaluate.load("f1")
pred = [1,0,1]
ref = [1,1,1]
result_accuracy = m_accuracy.compute(predictions=pred, references=ref)
result_recall = m_recall.compute(predictions=pred, references=ref)
result_precision = m_precision.compute(predictions=pred, references=ref)
result_f1 = m_f1.compute(predictions=pred, references=ref)
print("Accuracy:", type(result_accuracy), result_accuracy)
print("Recall:", type(result_recall), result_recall)
print("precision:", type(result_precision), result_precision)
print("F1:", type(result_f1), result_f1)
Output:
Accuracy: <class 'dict'> {'accuracy': 0.6666666666666666}
Recall: <class 'dict'> {'recall': 0.6666666666666666}
Precision: <class 'dict'> {'precision': 1.0}
F1: <class 'dict'> {'f1': 0.8}
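For comparison, combining these four dict-returning metrics works without errors, since _merge_results can simply merge the individual dictionaries (values taken from the runs above):

clf_metrics = evaluate.combine(["accuracy", "recall", "precision", "f1"])
print(clf_metrics.compute(predictions=pred, references=ref))
# {'accuracy': 0.6666666666666666, 'recall': 0.6666666666666666, 'precision': 1.0, 'f1': 0.8}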
Metrics that fail to combine (such as WER and CER) return a bare float instead:
m_wer = evaluate.load("wer")
m_cer = evaluate.load("cer")
pred = ["test", "Test"]
ref = ["TEST", "Test"]
result_wer = m_wer.compute(predictions=pred, references=ref)
result_cer = m_cer.compute(predictions=pred, references=ref)
print("WER:", type(result_wer), result_wer)
print("CER:", type(result_cer), result_cer)
Output:
WER: <class 'float'> 0.5
CER: <class 'float'> 0.5
Naturally, a float has no keys() attribute, hence the AttributeError above.
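The merging logic shown in the traceback assumes every module's result is a dict. A simplified reconstruction of CombinedEvaluations._merge_results (based on the traceback above; the real implementation does more) demonstrates the failure:

import itertools

def merge_results(results):
    # The failing line from the traceback: every result must be a dict,
    # otherwise .keys() raises AttributeError
    results_keys = list(itertools.chain.from_iterable([r.keys() for r in results]))
    # (the real implementation uses results_keys to resolve duplicate keys)
    merged_results = {}
    for result in results:
        merged_results.update(result)
    return merged_results

print(merge_results([{"accuracy": 0.67}, {"recall": 0.67}]))
# {'accuracy': 0.67, 'recall': 0.67}
merge_results([0.5, 0.5])
# AttributeError: 'float' object has no attribute 'keys'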
Changing the return statements in these files may solve the problem:
https://github.com/huggingface/evaluate/blob/main/metrics/wer/wer.py#L98
https://github.com/huggingface/evaluate/blob/main/metrics/wer/wer.py#L106
https://github.com/huggingface/evaluate/blob/main/metrics/cer/cer.py#L140
https://github.com/huggingface/evaluate/blob/main/metrics/cer/cer.py#L159
Hypothetically, you just need to wrap the return value in a dictionary with the appropriate key.
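A sketch of the hypothetical change, using variable names assumed from context (they may differ in the actual files):

# wer.py, at the linked return statements:
return {"wer": incorrect / total}   # instead of: return incorrect / total

# cer.py, analogously:
return {"cer": incorrect / total}   # instead of: return incorrect / total

Note that this would change the public return type of compute() for these metrics, so existing code that expects a bare float would need updating. Until a fix lands, a workaround is to compute each metric separately and build the combined dictionary by hand:

import evaluate

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

predictions = ["this is the prediction", "there is an other sample"]
references = ["this is the reference", "there is another one"]

# Wrap each bare float in a dict keyed by the metric name,
# mimicking what combine() would return once fixed
asr_results = {
    "wer": wer_metric.compute(predictions=predictions, references=references),
    "cer": cer_metric.compute(predictions=predictions, references=references),
}
print(asr_results)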