
Support metrics reduction outside of DDP

Open ytang137 opened this issue 10 months ago • 4 comments

🚀 Feature

Support metrics reduction outside of DDP.

Motivation

When used within DDP, torchmetrics objects support automatic syncing and reduction across ranks. However, there doesn't seem to be support for reduction outside DDP. This would be a good feature to have because it would allow using torchmetrics for distributed evaluation with frameworks other than DDP.

Pitch

Let's say we are computing metrics on a large dataset. Each worker receives a shard of the dataset and computes the metric for its shard, and we collect the metric objects from all workers: metrics = [metric_0, metric_1, metric_2, metric_3, ...]. To compute the final metric across the entire dataset, we need a mechanism to reduce the metrics. Looking at src/torchmetrics/metric.py, I see one potential solution:

# metrics: list of Metric objects of the same type, one per worker
metric_reduced = MetricType()  # same type as the metrics in the metrics list
for metric in metrics:
    # fold each worker's state into the accumulator via the private hook
    metric_reduced._reduce_states(metric.metric_state)

# Compute the final metric over the entire dataset
final_metric = metric_reduced.compute()

However, this approach relies on the private method _reduce_states. It would be great if torchmetrics could offer an all_reduce(metrics: Iterable[MetricType]) -> MetricType function that achieves the same functionality.
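For concreteness, here is a minimal sketch of what such a helper might look like, built on the same private _reduce_states hook and the metric_state property used above; the name all_reduce_metrics and the deepcopy of the first metric are just illustrative choices, not existing torchmetrics API:

import copy
from typing import Iterable

from torchmetrics import Metric


def all_reduce_metrics(metrics: Iterable[Metric]) -> Metric:
    """Merge the states of several same-type metric objects into a single metric."""
    metrics = list(metrics)
    if not metrics:
        raise ValueError("Expected at least one metric to reduce")
    # Start from a copy of the first metric so its state is preserved,
    # then fold the remaining workers' states into it one by one.
    reduced = copy.deepcopy(metrics[0])
    for metric in metrics[1:]:
        reduced._reduce_states(metric.metric_state)
    return reduced


# Usage: final_metric = all_reduce_metrics([metric_0, metric_1, metric_2]).compute()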

Alternatives

Additional context

ytang137 avatar Sep 07 '23 02:09 ytang137

Hi! Thanks for your contribution, great first issue!

github-actions[bot] avatar Sep 07 '23 02:09 github-actions[bot]

cc: @justusschock

SkafteNicki avatar Sep 07 '23 05:09 SkafteNicki

Hey @ytang137, and sorry for the late reply.

A metric takes the following input kwargs on initialization:

  • process_group: The process group on which the synchronization is called. Defaults to the entire world (all processes).
  • dist_sync_fn: Function that performs the allgather operation on the metric state. Defaults to a custom implementation that calls torch.distributed.all_gather internally.
  • distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

These are used for all the syncs we run internally. Can you give an example where these functions wouldn't be sufficient? In general, a reduction is only applied after an all_gather, once all states have already been synced. A rough sketch of how these hooks are plugged in follows below.
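Here is a sketch of passing custom hooks to a metric at initialization; my_all_gather is a made-up placeholder for whatever communication layer you use, and it only needs to return a list with one tensor per worker:

from typing import List, Optional

import torch
from torchmetrics.classification import BinaryAccuracy


def my_all_gather(tensor: torch.Tensor, group: Optional[object] = None) -> List[torch.Tensor]:
    # Placeholder: gather `tensor` from every worker using your own backend
    # and return a list with one tensor per worker.
    raise NotImplementedError


metric = BinaryAccuracy(
    dist_sync_fn=my_all_gather,             # custom allgather on the metric state
    distributed_available_fn=lambda: True,  # report that a distributed backend exists
)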

justusschock avatar Sep 22 '23 13:09 justusschock

Hi @justusschock, thanks for getting back to me. The use case I had in mind is using TorchMetrics outside the context of DDP, or even outside of PyTorch altogether, where concepts such as process_group don't apply.

Consider this example: we have inference results from a large dataset saved in a database, and we want to compute metrics grouped by date and also report metrics aggregated over the entire dataset at the end. One approach would be to use Dask or Modin to calculate the metrics in a distributed fashion: dataset.groupby("date").apply(my_metric_function). Here my_metric_function returns a TorchMetrics Metric object whose state has been updated with the data from a specific date. The groupby operation therefore returns a series of Metric objects. It would be very nice to have an all_reduce(metrics: Iterable[Metric]) -> Metric function that performs a reduction to compute the metric over the entire dataset, as sketched below.
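For concreteness, a minimal sketch of the workflow, with plain pandas standing in for Dask/Modin, BinaryAccuracy as an arbitrary example metric, made-up column names, and the proposed all_reduce as a hypothetical helper:

import pandas as pd
import torch
from torchmetrics.classification import BinaryAccuracy

# Pre-computed inference results: one row per sample (illustrative column names).
df = pd.DataFrame(
    {
        "date": ["2023-09-01", "2023-09-01", "2023-09-02", "2023-09-02"],
        "pred": [0.9, 0.2, 0.7, 0.4],
        "target": [1, 0, 1, 1],
    }
)


def my_metric_function(group: pd.DataFrame) -> BinaryAccuracy:
    # Build a metric whose state reflects only this date's rows.
    metric = BinaryAccuracy()
    metric.update(torch.tensor(group["pred"].values), torch.tensor(group["target"].values))
    return metric


# One Metric object per date; with Dask/Modin the same groupby-apply runs distributed.
per_date_metrics = df.groupby("date").apply(my_metric_function)

# Hypothetical reduction over all groups to get the dataset-wide metric
# (all_reduce does not exist in torchmetrics today; it is the function proposed above):
# overall = all_reduce(per_date_metrics.tolist()).compute()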

Is it currently possible to reduce metrics this way? Correct me if I'm wrong, but it seems that the concepts of process_group, dist_sync_fn, and distributed_available_fn don't apply in this example, or maybe these parameters can be used in a certain way to support this use case? Thanks.

ytang137 avatar Sep 23 '23 18:09 ytang137