Refactor BERTScore tests

Open stancld opened this issue 3 years ago • 7 comments

What does this PR do?

Fixes #1110

Following on this PR, I'd also suggest a refactor of BERTScore itself as its implementation is a bit obscure.

Before submitting

  • [x] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Did you make sure to update the docs?
  • [x] Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

stancld avatar Jul 29 '22 13:07 stancld

@Borda Tests should hopefully be refactored and fixed now. I would also suggest slightly refactoring the functional and class metrics themselves. :]

stancld avatar Aug 12 '22 09:08 stancld

@stancld, how is it going here? :)

Borda avatar Sep 07 '22 08:09 Borda

@Borda Currently trying to resolve why the tests freeze with ddp=True, dist_sync_on_step=True... It doesn't always happen, but definitely most of the time 🤔 Otherwise, it seems okay, so I'll try to dig a bit deeper. Some issues with ddp=True often perplex me 😬

stancld avatar Sep 08 '22 12:09 stancld

@SkafteNicki @justusschock, mind having a look? :rabbit:

Borda avatar Sep 15 '22 07:09 Borda

Currently trying to resolve why the tests freeze with ddp=True, dist_sync_on_step=True... It doesn't always happen, but definitely most of the time

pls @SkafteNicki ^^

Borda avatar Oct 05 '22 10:10 Borda

@Borda @SkafteNicki I found out that the problem is very likely on the side of the reference score (the original bert_score package), where the data are processed within a multiprocessing Pool:

def get_idf_dict(arr, tokenizer, nthreads=4):
    """
    Returns mapping from word piece index to its inverse document frequency.
    Args:
        - :param: `arr` (list of str) : sentences to process.
        - :param: `tokenizer` : a BERT tokenizer corresponds to `model`.
        - :param: `nthreads` (int) : number of CPU threads to use
    """
    idf_count = Counter()
    num_docs = len(arr)

    process_partial = partial(process, tokenizer=tokenizer)

    with Pool(nthreads) as p:
        idf_count.update(chain.from_iterable(p.map(process_partial, arr)))

    idf_dict = defaultdict(lambda: log((num_docs + 1) / (1)))
    idf_dict.update({idx: log((num_docs + 1) / (c + 1)) for (idx, c) in idf_count.items()})
    return idf_dict

which very likely causes the errors we're facing:

multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
  File "/Users/stancld/Documents/metrics/tests/unittests/text/helpers.py", line 120, in _class_test
    sk_batch_result = sk_metric(preds[i], targets[i], **batch_kwargs_update)
  File "/Users/stancld/Documents/metrics/tests/unittests/text/helpers.py", line 449, in run_test
    return function(*args, **kwargs)
  File "/Users/stancld/Documents/metrics/tests/unittests/text/test_bertscore.py", line 54, in _reference_bert_score
    score_tuple = original_bert_score(
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/site-packages/bert_score/score.py", line 124, in score
    idf_dict = get_idf_dict(refs, tokenizer, nthreads=nthreads)
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/site-packages/bert_score/utils.py", line 338, in get_idf_dict
    with Pool(nthreads) as p:
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 212, in __init__
    self._repopulate_pool()
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 303, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/pool.py", line 326, in _repopulate_pool_static
    w.start()
  File "/Users/stancld/miniconda3/envs/metrics/lib/python3.9/multiprocessing/process.py", line 118, in start
    assert not _current_process._config.get('daemon'), \
AssertionError: daemonic processes are not allowed to have children
"""
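
For context: the test helper already runs inside a daemonic pool worker, and Python does not let such a worker start a Pool of its own, hence the AssertionError. One possible way around it would be to compute the IDF dictionary in the calling process. Below is a minimal sketch of that idea, not the bert_score implementation: `sequential_idf_dict` and `whitespace_tokenize` are hypothetical names, the whitespace tokenizer stands in for the real BERT tokenizer, and it assumes each sentence counts a token at most once. The formula is the one from `get_idf_dict` above.

```python
from collections import Counter, defaultdict
from math import log


def whitespace_tokenize(sentence):
    # Hypothetical stand-in for the tokenizer; bert_score uses a BERT tokenizer here.
    return sentence.lower().split()


def sequential_idf_dict(sentences, tokenize=whitespace_tokenize):
    """Map each token to log((N + 1) / (df + 1)) without spawning worker processes."""
    num_docs = len(sentences)
    doc_freq = Counter()
    for sentence in sentences:
        # set() so that a token is counted once per sentence (document frequency).
        doc_freq.update(set(tokenize(sentence)))

    # Tokens never seen in `sentences` fall back to the maximum IDF, log(N + 1).
    idf = defaultdict(lambda: log(num_docs + 1))
    idf.update({tok: log((num_docs + 1) / (count + 1)) for tok, count in doc_freq.items()})
    return idf


idf = sequential_idf_dict(["the cat sat", "the dog ran"])
```

Since nothing is forked here, the function can be called from a daemonic worker; the trade-off is losing the parallel tokenization, which should hardly matter for test-sized inputs.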

stancld avatar Oct 12 '22 21:10 stancld

tm [[0.8306975  0.8739164  0.72068065 0.8306975 ]
 [0.8475447  0.83431923 0.69224256 0.8475447 ]
 [0.80905366 0.81940556 0.6781949  0.80905366]
 [0.7635116  0.81414014 0.650942   0.7635116 ]
 [0.7138042  0.77494    0.63750863 0.7138042 ]
 [0.6709998  0.73885846 0.6477587  0.6709998 ]
 [0.6260605  0.7129422  0.64940983 0.6260605 ]
 [0.5878769  0.6894183  0.64766484 0.5878769 ]
 [0.57530874 0.69059736 0.6506037  0.57530874]
 [0.5705996  0.6988218  0.6483182  0.5705996 ]
 [0.57114387 0.71416587 0.6398848  0.57114387]
 [0.61254483 0.7734972  0.6539489  0.61254483]
 [0.66501355 0.7811058  0.6300469  0.66501355]]
reference tensor([[0.7207, 0.8307, 0.8307, 0.8739],
        [0.6922, 0.8475, 0.8475, 0.8343],
        [0.6782, 0.8091, 0.8091, 0.8194],
        [0.6509, 0.7635, 0.7635, 0.8141],
        [0.6375, 0.7138, 0.7138, 0.7749],
        [0.6478, 0.6710, 0.6710, 0.7389],
        [0.6494, 0.6261, 0.6261, 0.7129],
        [0.6477, 0.5879, 0.5879, 0.6894],
        [0.6506, 0.5753, 0.5753, 0.6906],
        [0.6483, 0.5706, 0.5706, 0.6988],
        [0.6399, 0.5711, 0.5711, 0.7142],
        [0.6539, 0.6125, 0.6125, 0.7735],
        [0.6300, 0.6650, 0.6650, 0.7811]])

It looks like when ddp=True and dist_sync_on_step=False, the order of examples across batches sometimes differs between the TM and the reference metric: the rows above hold the same values, just permuted. Would it be a viable solution to take an average over examples? I believe that, in a real scenario, one would mainly be interested in an aggregated metric. Actually, I don't know which implementation tends to switch the batch order 🤔

Dimension is [num_layers, num_examples].

cc @Borda

stancld avatar Oct 12 '22 22:10 stancld

@stancld I also noticed that the order of the batches is sometimes not the same, and it is due to the gathering over multiple processes on the torchmetrics side. It usually does not matter because we aggregate, but for some metrics it is hard to get the tests right. I would suggest that in this case we just check that the aggregated value matches.
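
To illustrate the suggestion, here is a small sketch using the first row (layer) of the two printouts above. It is not the actual test helper: `allclose` and `TOL` are made-up names, and the tolerance is an assumption chosen because the reference printout is rounded to four decimals.

```python
from math import isclose
from statistics import fmean

# First layer (row) of the two printouts: same values, different example order.
tm_row = [0.8306975, 0.8739164, 0.72068065, 0.8306975]
ref_row = [0.7207, 0.8307, 0.8307, 0.8739]

TOL = 1e-3  # assumed tolerance; the reference values are rounded to 4 decimals


def allclose(a, b, tol=TOL):
    # Position-by-position comparison of two equally long sequences.
    return len(a) == len(b) and all(isclose(x, y, abs_tol=tol) for x, y in zip(a, b))


# Order-sensitive check: fails as soon as the gather reorders examples.
elementwise_match = allclose(tm_row, ref_row)

# Order-invariant checks: compare sorted values, or the mean over examples.
sorted_match = allclose(sorted(tm_row), sorted(ref_row))
mean_match = isclose(fmean(tm_row), fmean(ref_row), abs_tol=TOL)
```

Comparing sorted values is the stricter of the two order-invariant checks, since the mean alone could hide per-example differences that cancel out.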

SkafteNicki avatar Oct 24 '22 09:10 SkafteNicki

Hey @Borda @SkafteNicki -- Do you know if there have been any recent changes to how dist_sync_on_step is handled? Several tests fail with dist_sync_on_step=True, but I don't understand why. They also look a bit flaky, judging by my local runs 🤔

stancld avatar Oct 28 '22 21:10 stancld

@justusschock could you pls have a look? :otter:

Borda avatar Oct 28 '22 22:10 Borda

Those failing tests are also pretty peculiar: judging by the printed output, the results from our metric and the reference one look equivalent to me.

stancld avatar Oct 29 '22 07:10 stancld