Investigate use of `join` context for distributed sync
🚀 Feature
Motivation
Based on the problems described in this issue: https://github.com/Lightning-AI/metrics/issues/1297
By implementing the join context (https://pytorch.org/tutorials/advanced/generic_join.html) for our distributed synchronization, we would remove the limitation that the number of samples must be divisible by num_gpus * batch_size for a metric to be calculated correctly (by default, PyTorch adds extra samples to balance the load across processes).
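For reference, the basic pattern from the linked tutorial looks roughly like the sketch below (a minimal CPU/gloo toy example with deliberately uneven per-rank inputs, not torchmetrics code; the backend, port, and batch counts are illustrative assumptions): the `Join` context manager shadows the collectives of ranks that have exhausted their inputs, so the remaining ranks do not hang.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def worker(rank: int, world_size: int) -> None:
    # Minimal single-node process-group setup (gloo so it also runs on CPU).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(nn.Linear(1, 1))
    # Deliberately uneven inputs: rank 0 gets 5 batches, rank 1 gets 6.
    inputs = [torch.randn(4, 1) for _ in range(5 + rank)]

    # Ranks that finish early "join" and shadow the gradient all-reduces
    # issued by the ranks that are still iterating.
    with Join([model]):
        for x in inputs:
            model(x).sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```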
Pitch
The base class should derive from the Joinable class and implement the appropriate methods, as sketched below. Hopefully this will not be too much trouble, since all the sync logic is already encapsulated in a single function.
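A very rough, hypothetical sketch of what that could look like (the `Joinable`/`JoinHook`/`Join` API comes from `torch.distributed.algorithms.join`; the metric class, its single `total` state, and the zero-contribution shadow hook are illustrative assumptions, not the actual torchmetrics implementation):

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook


class MetricJoinHook(JoinHook):
    """Shadow hook: ranks that ran out of data keep matching the busy ranks'
    sync all_reduce with a zero contribution, so nothing deadlocks."""

    def __init__(self, metric: "JoinableSumMetric") -> None:
        self.metric = metric

    def main_hook(self) -> None:
        dist.all_reduce(torch.zeros_like(self.metric.total),
                        group=self.metric.process_group)


class JoinableSumMetric(Joinable):
    """Hypothetical metric whose distributed sync participates in a Join context."""

    def __init__(self, device: torch.device, process_group) -> None:
        super().__init__()
        self.device = device
        self.process_group = process_group
        self.total = torch.zeros(1, device=device)  # the metric state

    def update(self, batch: torch.Tensor) -> None:
        self.total += batch.sum().to(self.device)

    def sync(self) -> None:
        # The existing sync logic lives in one place; notifying the Join
        # context here lets already-joined ranks shadow this all_reduce.
        Join.notify_join_context(self)
        dist.all_reduce(self.total, group=self.process_group)

    # --- Joinable interface ---------------------------------------------
    def join_hook(self, **kwargs) -> JoinHook:
        return MetricJoinHook(self)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group
```

The metric would then be passed to the same context manager as the model, e.g. `with Join([model, metric]): ...`, so that ranks with fewer samples stop cleanly while the others finish.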
Alternatives
Additional context
Could using samplers like the one here be a temporary hack to obtain correct results? A rough sketch of such a sampler follows.
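If the idea is a sampler that simply skips the padding during evaluation, a minimal hypothetical sketch (the class name and behaviour are assumptions, not an existing torch or torchmetrics class) could look like this:

```python
import torch.distributed as dist
from torch.utils.data import Sampler


class UnpaddedDistributedSampler(Sampler):
    """Hypothetical evaluation sampler: shards indices across ranks WITHOUT the
    extra padding samples that DistributedSampler adds, so no sample is counted
    twice. The trade-off is that ranks may receive uneven numbers of samples."""

    def __init__(self, dataset, num_replicas=None, rank=None):
        self.dataset = dataset
        self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size()
        self.rank = rank if rank is not None else dist.get_rank()

    def __iter__(self):
        # Strided shard of the full index list; nothing is repeated to equalize
        # per-rank counts, unlike the default DistributedSampler behaviour.
        return iter(range(self.rank, len(self.dataset), self.num_replicas))

    def __len__(self):
        return len(range(self.rank, len(self.dataset), self.num_replicas))
```

Used as `DataLoader(dataset, sampler=UnpaddedDistributedSampler(dataset), ...)` during evaluation it avoids double-counting, but because per-rank counts now differ it only works if nothing downstream assumes every rank runs the same number of steps, which is exactly the gap the join context would close.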