torchmetrics
Refactor BERTScore tests
What does this PR do?
Fixes #1110
Following up on this PR, I'd also suggest a refactor of BERTScore itself, as its implementation is a bit obscure.
Before submitting
- [x] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?
PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
@Borda Tests should hopefully be refactored and fixed now. I would also suggest slightly refactoring the functional and class metrics themselves. :]
@stancld, how is it going here? :)
@Borda Currently trying to resolve why the tests freeze on ddp=True, dist_sync_on_step=True... It doesn't always happen, but definitely most of the time 🤔
Otherwise, it seems okay, so I'll try to dig a bit deeper. Some issues with ddp=True often perplex me 😬
@SkafteNicki @justusschock, mind having a look? :rabbit:
> Currently trying to resolve why the tests freeze on ddp=True, dist_sync_on_step=True... It doesn't always happen, but definitely most of the time
pls @SkafteNicki ^^
@Borda @SkafteNicki I found out that the problem is very likely on the side of the reference score, where the data are processed within a Pool:
```python
def get_idf_dict(arr, tokenizer, nthreads=4):
    """
    Returns mapping from word piece index to its inverse document frequency.

    Args:
        - :param: `arr` (list of str) : sentences to process.
        - :param: `tokenizer` : a BERT tokenizer corresponds to `model`.
        - :param: `nthreads` (int) : number of CPU threads to use
    """
    idf_count = Counter()
    num_docs = len(arr)

    process_partial = partial(process, tokenizer=tokenizer)

    with Pool(nthreads) as p:
        idf_count.update(chain.from_iterable(p.map(process_partial, arr)))

    idf_dict = defaultdict(lambda: log((num_docs + 1) / (1)))
    idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()})
    return idf_dict
```
which most likely results in the errors we're facing:
```
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
  File "/Users/stancld/Documents/metrics/tests/unittests/text/helpers.py", line 120, in _class_test
    sk_batch_result = sk_metric(preds[i], targets[i], **batch_kwargs_update)
  File "/Users/stancld/Documents/metrics/tests/unittests/text/helpers.py", line 449, in run_test
    return function(*args, **kwargs)
  File "/Users/stancld/Documents/metrics/tests/unittests/text/test_bertscore.py", line 54, in _reference_bert_score
    score_tuple = original_bert_score(
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/site-packages/bert_score/score.py", line 124, in score
    idf_dict = get_idf_dict(refs, tokenizer, nthreads=nthreads)
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/site-packages/bert_score/utils.py", line 338, in get_idf_dict
    with Pool(nthreads) as p:
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 212, in __init__
    self._repopulate_pool()
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 303, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 326, in _repopulate_pool_static
    w.start()
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/process.py", line 118, in start
    assert not _current_process._config.get('daemon'), \
AssertionError: daemonic processes are not allowed to have children
"""
```
```
tm [[0.8306975  0.8739164  0.72068065 0.8306975 ]
    [0.8475447  0.83431923 0.69224256 0.8475447 ]
    [0.80905366 0.81940556 0.6781949  0.80905366]
    [0.7635116  0.81414014 0.650942   0.7635116 ]
    [0.7138042  0.77494    0.63750863 0.7138042 ]
    [0.6709998  0.73885846 0.6477587  0.6709998 ]
    [0.6260605  0.7129422  0.64940983 0.6260605 ]
    [0.5878769  0.6894183  0.64766484 0.5878769 ]
    [0.57530874 0.69059736 0.6506037  0.57530874]
    [0.5705996  0.6988218  0.6483182  0.5705996 ]
    [0.57114387 0.71416587 0.6398848  0.57114387]
    [0.61254483 0.7734972  0.6539489  0.61254483]
    [0.66501355 0.7811058  0.6300469  0.66501355]]
reference tensor([[0.7207, 0.8307, 0.8307, 0.8739],
                  [0.6922, 0.8475, 0.8475, 0.8343],
                  [0.6782, 0.8091, 0.8091, 0.8194],
                  [0.6509, 0.7635, 0.7635, 0.8141],
                  [0.6375, 0.7138, 0.7138, 0.7749],
                  [0.6478, 0.6710, 0.6710, 0.7389],
                  [0.6494, 0.6261, 0.6261, 0.7129],
                  [0.6477, 0.5879, 0.5879, 0.6894],
                  [0.6506, 0.5753, 0.5753, 0.6906],
                  [0.6483, 0.5706, 0.5706, 0.6988],
                  [0.6399, 0.5711, 0.5711, 0.7142],
                  [0.6539, 0.6125, 0.6125, 0.7735],
                  [0.6300, 0.6650, 0.6650, 0.7811]])
```
It looks like when ddp=True and dist_sync_on_step=False, the order of batches is sometimes different between TM and the reference metric. Would it be a viable solution to take an average over examples? I believe that, in a real scenario, one would mainly be interested in an aggregated metric. Actually, I don't know which implementation tends to switch the batch order 🤔
The dimension is [num_layers, num_examples].
cc @Borda
@stancld I also noticed that sometimes the order of the batches is not the same, and it is due to the gathering over multiple processes on the torchmetrics side. It usually does not matter because we do aggregation, but for some metrics it is hard to get the tests right. I would suggest we also just check that the aggregated value works in this case.
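A rough sketch of what such an aggregated check could look like (the helper name and tolerances are illustrative, not existing utilities from tests/unittests/text/helpers.py): compare only the per-layer mean over examples, so a different gather order across processes cannot fail the test.

```python
# Hedged sketch of an order-insensitive comparison: the [num_layers, num_examples]
# matrices printed above contain the same values but in a different example order,
# so comparing their per-layer means sidesteps the ordering problem.
import torch


def assert_aggregated_close(tm_scores: torch.Tensor, ref_scores: torch.Tensor) -> None:
    """Assert that the mean over examples matches between torchmetrics and the reference."""
    torch.testing.assert_close(tm_scores.mean(dim=-1), ref_scores.mean(dim=-1), atol=1e-4, rtol=1e-4)
```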
Hey @Borda @SkafteNicki -- Do you know if there has recently been any update to dist_sync_on_step? There are several failing tests with dist_sync_on_step=True, but I don't understand why. Also, they look to be a bit flaky according to my local testing 🤔
@justusschock could you pls have a look? :otter:
Those failing tests are also pretty peculiar, as from the printed results it looks to me like the results from our metric and the reference one are equivalent.
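One quick way to back up that impression (a debugging sketch only, not something from the test suite) is to check that the two matrices are equal once the example dimension is sorted, i.e. that they differ only by the order of examples:

```python
# Debugging sketch: two [num_layers, num_examples] score matrices are equivalent up to
# batch order if sorting along the example dimension makes them numerically equal.
import torch


def equal_up_to_example_order(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-4) -> bool:
    a_sorted = torch.sort(a, dim=-1).values
    b_sorted = torch.sort(b, dim=-1).values
    return torch.allclose(a_sorted, b_sorted, atol=atol)
```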