Add `EvaluationDistributedSampler` and examples on distributed evaluation
What does this PR do?
Fixes https://github.com/Lightning-AI/torchmetrics/issues/1338
The original issue asks whether we should implement a join context so that metrics can be evaluated on an uneven number of samples in distributed settings. As a reminder, we normally discourage users from evaluating in distributed mode because PyTorch's default distributed sampler adds extra samples to make all processes do equal work, which skews the results.
After investigating the issue, it seems we do not need a join context at all, thanks to the custom synchronization we already have for metrics. To understand why, we need to look at the two kinds of states a metric can have: tensor states and lists of tensor states.
- For tensor states the logic is fairly simple: even if rank 0 is evaluated on more samples or batches than rank 1, each rank still performs exactly one all-gather operation, regardless of how many samples/batches it has seen.
- For list states, we are saved by the custom logic we already have. Imagine that the rank 0 state is a list of two tensors `[t_01, t_02]` and the rank 1 state is a list of one tensor `[t_11]` (rank 0 has seen one more batch than rank 1). When list states are encountered internally, we make sure to concatenate them into one tensor so that we do not need to call `allgather` for each tensor in the list https://github.com/Lightning-AI/torchmetrics/blob/879595de67ab35d891a227e3254e0b5a26a050f0/src/torchmetrics/metric.py#L418-L419 such that after this step each state is a single tensor `t_0` and `t_1`, but clearly `t_0.shape != t_1.shape`. Again, internally we deal with this by padding to the same size and then doing an all-gather: https://github.com/Lightning-AI/torchmetrics/blob/879595de67ab35d891a227e3254e0b5a26a050f0/src/torchmetrics/utilities/distributed.py#L136-L148
Thus, in both cases, even if one rank sees more samples/batches than another, every rank still performs the same number of distributed operations, which means everything works.
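A minimal sketch (not part of the PR itself) that illustrates this: two processes evaluate metrics on an uneven number of samples, exercising both a tensor-state metric (`MeanMetric`) and a list-state metric (`CatMetric`). It uses the gloo backend so it runs on CPU; the port number is arbitrary.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchmetrics import CatMetric, MeanMetric


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    mean = MeanMetric()  # metric with tensor states
    cat = CatMetric()  # metric with a list state

    data = torch.arange(1.0, 6.0)  # global data: [1, 2, 3, 4, 5]
    local = data[rank::world_size]  # rank 0 gets 3 samples, rank 1 gets 2 -> uneven
    for x in local:
        mean.update(x)
        cat.update(x)

    # compute() triggers the all-gather based synchronization; despite the uneven
    # workload, both ranks agree on the global result (mean=3.0, 5 gathered values)
    print(f"rank {rank}: mean={mean.compute()}, n_gathered={cat.compute().numel()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```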
To highlight this feature of TM, this PR does a couple of things:
- Introduce a new `EvaluationDistributedSampler` that does not add extra samples. Users can use it as a drop-in replacement for any `DistributedSampler` if they want to do proper distributed evaluation (otherwise they just need to ensure that the number of samples is evenly divisible by the number of processes). See the sketch after this list for the underlying idea.
- Add unit tests that cover the above
- Add examples of how to do this distributed evaluation in both Lightning and standard PyTorch
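To make the idea concrete, here is a hedged sketch of what such a sampler can look like. It is not the actual `EvaluationDistributedSampler` from this PR; the class name and arguments below are illustrative only. The key difference from `torch.utils.data.DistributedSampler` is that no indices are repeated to pad every rank's split to equal length.

```python
from typing import Iterator, Optional

import torch
from torch.utils.data import Dataset, DistributedSampler


class NoPaddingDistributedSampler(DistributedSampler):
    """Illustrative sampler: each rank gets a strided slice of the indices, without padding."""

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = False, seed: int = 0) -> None:
        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, seed=seed, drop_last=False)
        # override the padded sizes computed by the parent class
        self.total_size = len(self.dataset)
        self.num_samples = len(range(self.rank, self.total_size, self.num_replicas))

    def __iter__(self) -> Iterator[int]:
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))
        # strided split; no duplicated indices are appended to even out the ranks
        return iter(indices[self.rank : self.total_size : self.num_replicas])

    def __len__(self) -> int:
        return self.num_samples
```

Passing such a sampler to a `DataLoader` (`DataLoader(dataset, sampler=sampler, ...)`) gives each rank a non-overlapping, possibly uneven slice of the evaluation set, which is exactly the situation the synchronization logic described above handles.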
Before submitting
- [ ] Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
:books: Documentation preview :books:: https://torchmetrics--1886.org.readthedocs.build/en/1886/
Codecov Report
Merging #1886 (311bce3) into master (29f3289) will decrease coverage by 0%. The diff coverage is 50%.
Additional details and impacted files
```diff
@@           Coverage Diff           @@
##           master   #1886   +/-   ##
=======================================
- Coverage      87%     87%    -0%
=======================================
  Files         270     270
  Lines       15581   15592   +11
=======================================
+ Hits        13483   13488    +5
- Misses       2098    2104    +6
```
@SkafteNicki, how is it going here? Do you think we could land it for the next 1.1 release...
calling @awaelchli for distributed review :)
Can we please add tests for validation and training as well? And maybe an FSDP test? Also, some notes on caveats might be good to add to the sampler docs.
You are right that we need to test this feature better and clearly state its limitations. I am going to move it from the 1.1 milestone to a future one, because it is not essential to get done right now.
Converted to draft until better tested.