torchmetrics
Support metrics reduction outside of DDP
🚀 Feature
Support metrics reduction outside of DDP.
Motivation
When used within DDP, torchmetrics objects support automatic syncing and reduction across ranks. However, there doesn't seem to be support for reduction outside DDP. This will be a good feature to have because it allows using torchmetrics for distributed evaluation using frameworks other than DDP.
Pitch
Let's say we are computing metrics on a large dataset. Each worker receives a shard of the dataset and computes the metric for its shard, and we collect the metric objects from all workers: `metrics = [metric_0, metric_1, metric_2, metric_3, ...]`. To compute the final metric across the entire dataset, we need a mechanism to reduce these metrics. Looking at `torchmetrics/src/torchmetrics/metric.py`, I see one potential solution:
```python
metric_reduced = MetricType()  # Same type as the metrics in the `metrics` list
for metric in metrics:
    metric_reduced._reduce_states(metric.metric_state)

# Compute the final metric
final_metric = metric_reduced.compute()
```
However, this approach relies on the private method `_reduce_states`. It would be great if torchmetrics could offer an `all_reduce(metrics: Iterable[MetricType]) -> MetricType` function that achieves the same result.
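A public helper along these lines might look like the following sketch. Note that `all_reduce` here is the *proposed* function, not an existing API, and `SimpleSum` is a toy stand-in for a real torchmetrics Metric (the actual class keeps tensor states with registered reduction functions):

```python
from typing import Iterable, TypeVar

M = TypeVar("M")

# Toy stand-in for a torchmetrics Metric: a real metric keeps tensor states
# and reduces them according to each state's registered reduction function.
class SimpleSum:
    def __init__(self):
        self.metric_state = {"total": 0}

    def update(self, value):
        self.metric_state["total"] += value

    def _reduce_states(self, incoming_state):
        # Sum-style reduction; other metrics may concatenate or average.
        self.metric_state["total"] += incoming_state["total"]

    def compute(self):
        return self.metric_state["total"]

def all_reduce(metrics: Iterable[M]) -> M:
    """Proposed helper: fold per-shard metric objects into a fresh one."""
    metrics = list(metrics)
    reduced = type(metrics[0])()  # fresh instance of the same metric type
    for metric in metrics:
        reduced._reduce_states(metric.metric_state)
    return reduced

# Each worker computed its shard's metric; reduce them into one.
shards = [SimpleSum() for _ in range(3)]
for shard, value in zip(shards, [1, 2, 3]):
    shard.update(value)
print(all_reduce(shards).compute())  # -> 6
```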
Alternatives
Additional context
Hi! Thanks for your contribution, great first issue!
cc: @justusschock
Hey @ytang137 and sorry for the late reply.
A metric takes the following input kwargs on initialization:

- `process_group`: The process group on which the synchronization is called. Default is the world.
- `dist_sync_fn`: Function that performs the all_gather operation on the metric state. Default is a custom implementation that calls `torch.distributed.all_gather` internally.
- `distributed_available_fn`: Function that checks if the distributed backend is available. Defaults to a check of `torch.distributed.is_available()` and `torch.distributed.is_initialized()`.
These are used for all the syncs we run internally. Can you give an example where these functions wouldn't be sufficient? In general, a reduction is always applied after an all_gather when all states have been synced already.
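A plain-Python illustration of that ordering (the names below are invented for this sketch; lists stand in for tensors and for `torch.distributed`):

```python
# Sketch of the sync flow the kwargs above control: first an all_gather
# (the dist_sync_fn step), then a reduction over the gathered states.
# `fake_all_gather` is invented for illustration and works on plain lists.

def fake_all_gather(per_rank_states):
    # In torch.distributed, every rank ends up with the full list of states.
    return list(per_rank_states)

def sync_and_reduce(per_rank_states):
    gathered = fake_all_gather(per_rank_states)  # sync step
    return sum(gathered)                         # reduction step

partial_sums = [3, 5, 7]  # one partial result per rank
print(sync_and_reduce(partial_sums))  # -> 15
```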
Hi @justusschock, thanks for getting back to me. The use case I had in mind is using TorchMetrics outside the context of DDP, or even outside PyTorch altogether, where concepts such as `process_group` don't apply.
Consider this example: we have inference results from a large dataset saved in a database, and we want to compute metrics grouped by date and also report metrics aggregated over the entire dataset at the end. One approach would be to use Dask or Modin to calculate the metrics in a distributed fashion: `dataset.groupby("date").apply(my_metric_function)`. The `my_metric_function` here returns a TorchMetrics object whose states have been updated with the data from a specific date. The groupby operation therefore returns a series of TorchMetrics objects. It would be very nice to have an `all_reduce(metrics: Iterable[Metric]) -> Metric` function that performs the reduction to compute the metric over the entire dataset.
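The groupby-then-aggregate pattern could then be a simple fold over the per-group metrics. This sketch simulates it with a plain dict in place of Dask, and a toy count-based `ToyAccuracy` (with an invented `merge` method playing the role of the proposed reduction) standing in for a real torchmetrics object:

```python
from collections import defaultdict
from functools import reduce

# Toy stand-in for a torchmetrics object: tracks correct/total counts.
class ToyAccuracy:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, pred, target):
        self.correct += int(pred == target)
        self.total += 1

    def merge(self, other):
        # Plays the role of the proposed all_reduce on a pair of metrics.
        self.correct += other.correct
        self.total += other.total
        return self

    def compute(self):
        return self.correct / self.total

# Inference results as (date, prediction, label) rows, as if read from a DB.
rows = [
    ("2023-01-01", 1, 1), ("2023-01-01", 0, 1),
    ("2023-01-02", 1, 1), ("2023-01-02", 1, 1),
]

# Per-date metrics, mimicking dataset.groupby("date").apply(my_metric_function).
per_date = defaultdict(ToyAccuracy)
for date, pred, label in rows:
    per_date[date].update(pred, label)

# Aggregate over the whole dataset by folding the per-group metrics.
overall = reduce(lambda a, b: a.merge(b), per_date.values())
print(overall.compute())  # 3 correct out of 4 -> 0.75
```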
Is it currently possible to reduce metrics this way? Correct me if I'm wrong, but it seems that the concepts of `process_group`, `dist_sync_fn`, and `distributed_available_fn` don't apply in this example. Or can these parameters be used in a certain way to support this use case? Thanks.