
Investigate use of `join` context for distributed sync

Open SkafteNicki opened this issue 3 years ago • 1 comment

🚀 Feature

Motivation

Based on the problems reported in this issue: https://github.com/Lightning-AI/metrics/issues/1297. By implementing the join context (https://pytorch.org/tutorials/advanced/generic_join.html) for our distributed synchronization, we would remove the limitation that the number of samples must be divisible by num_gpus * batch_size for metrics to be computed correctly (by default, PyTorch's DistributedSampler pads the dataset with extra samples so every rank receives the same workload).
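For context, the tutorial's pattern for handling uneven inputs with DDP looks roughly like the following. This is a minimal sketch assuming the script is launched with torchrun (which sets the env vars the process group needs); batch counts are illustrative:

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    model = DDP(torch.nn.Linear(1, 1))
    # Uneven inputs: rank 0 sees one more batch than every other rank.
    # Without the Join context, the collectives in DDP's backward pass
    # would hang once the other ranks run out of data.
    num_batches = 6 if rank == 0 else 5
    with Join([model]):
        for _ in range(num_batches):
            out = model(torch.randn(4, 1))
            out.sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```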

Pitch

The base class should derive from the `Joinable` class and implement the appropriate methods. Hopefully this will not be too much trouble, since all of the sync logic is already encapsulated in a single function. A sketch of what the protocol requires is below.
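The following is a minimal sketch of deriving from `Joinable`, based on the protocol described in the tutorial linked above. It is not the actual TorchMetrics implementation: `JoinableMetric`, its `total` state, and `_sync` are hypothetical stand-ins for the real base class and its sync function, and the per-step sync mimics `dist_sync_on_step=True`:

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms import Join, Joinable, JoinHook


class _MetricJoinHook(JoinHook):
    """Shadows the metric's per-step collective on ranks that joined early."""

    def __init__(self, metric: "JoinableMetric") -> None:
        self.metric = metric

    def main_hook(self) -> None:
        # Mirror the all-reduce issued by ranks that still have data,
        # contributing this rank's (now frozen) local state so the
        # global sum stays correct and no rank blocks on a missing peer.
        self.metric._sync(self.metric.total)


class JoinableMetric(Joinable):
    """Hypothetical base class; ``total`` stands in for real metric state."""

    def __init__(self) -> None:
        super().__init__()  # required: initializes the join config
        self.total = torch.tensor(0.0)
        self.synced_total = torch.tensor(0.0)

    def update(self, x: torch.Tensor) -> None:
        # Must come first in each iteration: it tells the Join context
        # that this rank is still active.
        Join.notify_join_context(self)
        self.total += x.sum()
        self.synced_total = self._sync(self.total)

    def _sync(self, state: torch.Tensor) -> torch.Tensor:
        # All-reduce a copy so the local state is left untouched,
        # mirroring TorchMetrics' sync/unsync around compute.
        buf = state.clone()
        dist.all_reduce(buf)
        return buf

    def join_hook(self, **kwargs) -> JoinHook:
        return _MetricJoinHook(self)

    @property
    def join_device(self) -> torch.device:
        return self.total.device

    @property
    def join_process_group(self):
        return dist.group.WORLD
```

Usage would then be `with Join([metric]): ...` around the evaluation loop: ranks that exhaust their data early keep issuing the shadow all-reduce from `main_hook`, so the ranks that still have data never deadlock.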

Alternatives

Additional context

SkafteNicki · Nov 15 '22

Could using samplers like the one here be a temporary hack to obtain correct results?
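The linked sampler is not reproduced here, but the general idea is presumably a `DistributedSampler` variant that shards the dataset without padding or dropping samples (all names below are hypothetical):

```python
from typing import Iterator, Optional

import torch.distributed as dist
from torch.utils.data import Dataset, Sampler


class UnpaddedDistributedSampler(Sampler):
    """Shards a dataset across ranks without the padding that
    DistributedSampler adds, so shard sizes may be uneven."""

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None) -> None:
        num_replicas = num_replicas or dist.get_world_size()
        rank = rank if rank is not None else dist.get_rank()
        # Strided split: every sample appears exactly once across ranks,
        # and some ranks may receive one sample fewer than others.
        self.indices = list(range(rank, len(dataset), num_replicas))

    def __iter__(self) -> Iterator[int]:
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.indices)
```

Note that without padding, per-rank batch counts can differ, which is exactly the uneven-input situation the join context above is meant to handle.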

tabmoo · Nov 29 '22