
No backend type associated with device type cpu

Open rballeba opened this issue 1 year ago • 22 comments

🐛 Bug

Metrics (both predefined in the library and custom implementations) that use concatenation (dist_reduce_fx="cat") together with CPU computation (compute_on_cpu=True) raise an error when training on multiple GPUs (DDP). The concrete error is RuntimeError: No backend type associated with device type cpu.

To Reproduce

Code sample:

import torch
from lightning import Trainer, LightningModule
from torch.utils.data import DataLoader
from torchmetrics import AUROC


class LitModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)
        self.auroc = AUROC(task="binary", compute_on_cpu=True)

    def training_step(self, x):
        preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]).cuda()
        target = torch.tensor([0, 0, 1, 1, 1]).cuda()
        self.auroc(preds, target)
        self.log("train_auroc", self.auroc, on_step=True, on_epoch=True)
        loss = self.layer(x).mean()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(torch.randn(32, 1), batch_size=1)
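
For completeness, a minimal launcher consistent with the traceback (the exact devices count is an assumption; any multi-GPU DDP run triggers the error) might look like:

def main():
    model = LitModel()
    # Assumed Trainer settings: any multi-GPU DDP run with the default nccl backend hits the error.
    trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(model)


if __name__ == "__main__":
    main()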

Stacktrace

Traceback (most recent call last):
  File "/home/ruben/Documents/PhD/Research/Topological Deep Learning/lightning/pythonProject/main.py", line 35, in <module>
    main()
  File "/home/ruben/Documents/PhD/Research/Topological Deep Learning/lightning/pythonProject/main.py", line 31, in main
    trainer.fit(model)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
    self.fit_loop.run()
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 206, in run
    self.on_advance_end()
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 376, in on_advance_end
    call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=False)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 208, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/callbacks/progress/tqdm_progress.py", line 281, in on_train_epoch_end
    self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/callbacks/progress/progress_bar.py", line 198, in get_metrics
    pbar_metrics = trainer.progress_bar_metrics
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1651, in progress_bar_metrics
    return self._logger_connector.progress_bar_metrics
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 253, in progress_bar_metrics
    metrics = self.metrics["pbar"]
              ^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 234, in metrics
    return self.trainer._results.metrics(on_step)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 483, in metrics
    value = self._get_cache(result_metric, on_step)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 447, in _get_cache
    result_metric.compute()
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 289, in wrapped_func
    self._computed = compute(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 254, in compute
    return self.value.compute()
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 611, in wrapped_func
    with self.sync_context(
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
           ^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 582, in sync_context
    self.sync(
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 531, in sync
    self._sync_dist(dist_sync_fn, process_group=process_group)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/metric.py", line 435, in _sync_dist
    output_dict = apply_to_collection(
                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 72, in apply_to_collection
    return _apply_to_collection_slow(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 104, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 125, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 96, in _apply_to_collection_slow
    return function(data, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torchmetrics/utilities/distributed.py", line 127, in gather_all_tensors
    torch.distributed.all_gather(local_sizes, local_size, group=group)
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ruben/miniconda3/envs/sct/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2808, in all_gather
    work = group.allgather([tensor_list], [tensor])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: No backend type associated with device type cpu

Expected behavior

The metric is computed properly, merging the lists from the different processes in multi-GPU training scenarios.

Environment

  • TorchMetrics version: 1.3.2, installed via pip
  • Python & PyTorch versions: 3.11 and 2.1.2, respectively
  • OS: Ubuntu 23.10

Additional context

Related bug in PyTorch Lightning

https://github.com/Lightning-AI/pytorch-lightning/issues/18803

rballeba avatar Mar 27 '24 15:03 rballeba

Hi! Thanks for your contribution, great first issue!

github-actions[bot] avatar Mar 27 '24 15:03 github-actions[bot]

I'm hitting the same bug.

HGGshiwo avatar May 14 '24 03:05 HGGshiwo

I also hit the same bug.

SangbumChoi avatar Jun 21 '24 05:06 SangbumChoi

Hello there, any update on this issue?

Rbrq03 avatar Jul 09 '24 13:07 Rbrq03

Hi all, thanks for reporting this issue. I am currently looking into what can be done to solve it. The compute_on_cpu argument was sadly never tested for multi-GPU setups, only single GPU.

SkafteNicki avatar Jul 22 '24 07:07 SkafteNicki

Hi @SkafteNicki, I also hit the same bug when using MeanAveragePrecision; the error occurs in the default dist_sync_fn, gather_all_tensors, which cannot gather tensors successfully when evaluating in multi-GPU setups. I hope the following function may be helpful:

import pickle

import torch
import torch.distributed as dist
# torch.distributed.get_world_size; torchvision's reference utils wrap it with
# an "is distributed initialized" check that falls back to a world size of 1.
from torch.distributed import get_world_size


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size, group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor, group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list

The function comes from the detection training reference in torchvision: https://github.com/pytorch/vision/blob/main/references/detection/utils.py

When I added the following code to my script (utils here refers to the torchvision reference utils.py linked above), it worked for multi-GPU evaluation with compute_on_cpu=True. However, compute_on_cpu=False failed. Maybe gather_all_tensors could be extended with all_gather so it applies to both single-GPU and multi-GPU runs, whether computing on CPU or not.

coco_evaluator = MeanAveragePrecision(iou_type=args.iou_type, backend=args.backend)
coco_evaluator.dist_sync_fn = utils.all_gather if args.evaluate_on_cpu else None

xiuqhou avatar Aug 19 '24 13:08 xiuqhou

Hi all, thanks for reporting this issue. I am currently looking into what can be done to solve it. The compute_on_cpu argument was sadly never tested for multi-GPU setups, only single GPU.

@SkafteNicki let's add a first multi-GPU test for this so we can reproduce it and prevent it in the future?

Borda avatar Aug 21 '24 12:08 Borda

I've looked into the problem and found that the main reason for this error is that the default distributed backend for Lightning is nccl. With compute_on_cpu=True it gives the error because the all_gather operation is not supported on CPU tensors with nccl. One way to resolve this is to use the gloo backend, which allows all_gather on CPU.
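
For example, with Lightning the backend can be switched via DDPStrategy (a sketch; note that gloo will then also handle GPU gradient syncing, which is slower than nccl):

from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy

# Run DDP with the gloo backend so all_gather on the CPU metric states is supported.
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPStrategy(process_group_backend="gloo"),
)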

davidgill97 avatar Aug 21 '24 13:08 davidgill97

I have also encountered the same issue when trying to use compute_on_cpu=True in a DDP setup. I tried to initialize the metric with compute_on_cpu and process_group inside the on_fit_start hook in Lightning:

def on_fit_start(self):
    cpu_comm = torch.distributed.new_group(backend="gloo") 
    self.metric = SomeMetric(..., compute_on_cpu=True, process_group=cpu_comm)

But this did not resolve the problem. The root cause seems to be that a duplicated instance of the metric, wrapped in the _ResultMetric class, is initialized during the evaluation loop, but its metadata does not include the compute_on_cpu or process_group arguments. So inside _ResultMetric the gloo process group is not passed through, and it defaults to the world group (torch.distributed.group.WORLD), whose backend is nccl.

This seems to be where the _ResultMetric is initialized: https://github.com/Lightning-AI/pytorch-lightning/blob/32e7d32956e1685d36f2ab0ca3770baa2f76ce10/pytorch_lightning/trainer/connectors/logger_connector/result.py#L503
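
One way around this (a sketch, assuming the gloo group from on_fit_start above; self.metric and the log key are placeholder names) is to compute and log the scalar manually so the _ResultMetric sync path is never used:

def on_train_epoch_end(self):
    # compute() syncs the CPU states over the gloo process_group passed to the
    # metric; logging the resulting plain scalar means Lightning's _ResultMetric
    # never tries to all_gather CPU tensors over nccl.
    value = self.metric.compute()
    self.log("metric_epoch", value, sync_dist=False)
    self.metric.reset()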

sandychoii avatar Sep 05 '24 08:09 sandychoii

This is a copy-paste of my reply to this issue: https://github.com/Lightning-AI/pytorch-lightning/issues/18803

I was having the same error message when using MeanAveragePrecision() on Databricks.

For me it worked adding the following three kwargs when the metric was initialized:

  • compute_on_cpu=False
  • sync_on_compute=False
  • dist_sync_on_step=True

All three arguments are needed to solve it in my case.

My code now looks like:

metric = MeanAveragePrecision(
    iou_type="segm",
    class_metrics=True,
    compute_on_cpu=False,
    sync_on_compute=False,
    dist_sync_on_step=True,
)

Holer90 avatar Sep 17 '24 13:09 Holer90

Thank you @Holer90 for sharing. Unfortunately your solution doesn't seem to work for me. It'd be useful to know a bit more about your configuration. In particular, what's your Trainer flags configuration, e.g. devices, strategy etc.? Thanks

mdifatta avatar Sep 19 '24 09:09 mdifatta

I have also encountered the same issue when trying to use compute_on_cpu=True in a DDP setup. I tried to initialize the metric with compute_on_cpu and process_group inside the on_fit_start hook in Lightning:

def on_fit_start(self):
    cpu_comm = torch.distributed.new_group(backend="gloo") 
    self.metric = SomeMetric(..., compute_on_cpu=True, process_group=cpu_comm)

But this did not resolve the problem. The root cause seems to be that a duplicated instance of the metric, wrapped in the _ResultMetric class, is initialized during the evaluation loop, but its metadata does not include the compute_on_cpu or process_group arguments. So inside _ResultMetric the gloo process group is not passed through, and it defaults to the world group (torch.distributed.group.WORLD), whose backend is nccl.

This seems to be where the _ResultMetric is initialized: https://github.com/Lightning-AI/pytorch-lightning/blob/32e7d32956e1685d36f2ab0ca3770baa2f76ce10/pytorch_lightning/trainer/connectors/logger_connector/result.py#L503

Any updates on how to solve this?

championsnet avatar Apr 14 '25 10:04 championsnet

Any updates? Same issue

MeteorsHub avatar Jun 27 '25 05:06 MeteorsHub

@baskrahmer or @Isalia20 mind having a look, pls ^^

Borda avatar Jun 27 '25 14:06 Borda

@Borda This is an outdated-package issue; with the latest lightning and torchmetrics it doesn't reproduce. I tried with lightning==2.1.0 and got this error, but with the latest lightning and torchmetrics (the lightning package was the issue here) the bug is gone, so we can close this issue.

Isalia20 avatar Jun 28 '25 21:06 Isalia20

@Isalia20 I'm facing the same issue. I'm using torchmetrics==1.7.3 and getting the error No backend type associated with device type cpu when trying to use MeanAveragePrecision from torchmetrics.detection.mean_ap in a DDP training setup.

proevgenii avatar Jul 03 '25 21:07 proevgenii

Which lightning version are you using? I suggest updating it to the latest version.

Isalia20 avatar Jul 03 '25 21:07 Isalia20

I'm not using lightning at all. Just torchmetrics and torch

proevgenii avatar Jul 04 '25 02:07 proevgenii

I'm not using lightning at all. Just torchmetrics and torch

@proevgenii could you pls share your sample code to reproduce?

Borda avatar Jul 04 '25 08:07 Borda

And the pip list as well please.

Isalia20 avatar Jul 04 '25 08:07 Isalia20

I'm getting the same error, using Lightning and torchmetrics with:

accelerator: auto 
devices: 1  
strategy: ddp_find_unused_parameters_true

Here's the pip list:

pytorchvideo                             0.1.5
torch                                    2.7.1+cu118
torchmetrics                             1.7.3
torchvision                              0.22.1+cu118
lightning                                2.5.2
lightning-utilities                      0.14.3

wiVlad avatar Jul 09 '25 11:07 wiVlad

@Borda @rballeba Sorry for the delayed response.

my pip list

By running pip list | grep -e torch -e lightning I got:

torch                              2.6.0
torchmetrics                       1.7.3
torchvision                        0.21.0

1. Here's code to reproduce:

import os

import torch
import torch.distributed as dist
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# ===============================================================================
# DDP Boilerplate Setup
# ===============================================================================


def ddp_setup():
    """
    Initializes the DDP process group and sets the device for the current process.
    """
    dist.init_process_group(backend="nccl")
    # torchrun sets the LOCAL_RANK environment variable.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    print(f"Rank {dist.get_rank()}: Initialized on device: cuda:{local_rank}")
    return dist.get_rank(), dist.get_world_size()


def cleanup():
    """Cleans up the DDP process group."""
    dist.destroy_process_group()


# ===============================================================================
# Minimal Logic to Reproduce the Error
# ===============================================================================


def generate_fake_data(rank):
    """
    Generates a small list of fake prediction and ground truth data.
    Crucially, all tensors are explicitly moved to the CPU, mimicking the `preds_postprocess` and `gt_postprocess` methods.
    """
    # Fake Ground Truth
    device = f"cuda:{rank}"
    gt = [
        {
            "boxes": torch.tensor(
                [[21, 32, 101, 120]], dtype=torch.float32, device=device
            ).cpu(),
            "labels": torch.tensor([0], dtype=torch.int64, device=device).cpu(),
        }
    ]

    # Fake Predictions
    preds = [
        {
            "boxes": torch.tensor(
                [[25, 30, 100, 115], [150, 150, 200, 200]],
                dtype=torch.float32,
                device=device,
            ).cpu(),
            "scores": torch.tensor(
                [0.9, 0.85], dtype=torch.float32, device=device
            ).cpu(),
            "labels": torch.tensor([0, 1], dtype=torch.int64, device=device).cpu(),
        }
    ]
    return preds, gt


def run_reproduction_logic(rank: int, world_size: int):
    """
    This function mimics the evaluation flow in a DDP environment.
    """
    print(f"Rank {rank}: Generating fake data...")
    # 1. Each process generates its local data
    local_preds, local_gt = generate_fake_data(rank)

    # Use a barrier to ensure all processes have generated data before gathering.
    dist.barrier()

    if rank == 0:
        print("\n--- Main Process (Rank 0) ---")
        print("Attempting to gather data from all processes...")

    # 2. Gather the lists of CPU tensors from all processes to rank 0.
    #    This is exactly what happens in the DDP evaluate function.
    gathered_preds = [None] * world_size
    gathered_gt = [None] * world_size
    dist.all_gather_object(gathered_preds, local_preds)
    dist.all_gather_object(gathered_gt, local_gt)

    # 3. Only the main process (rank 0) computes the final metric.
    if rank == 0:
        # Combine the lists from all processes
        all_preds = [item for sublist in gathered_preds for item in sublist]
        all_gt = [item for sublist in gathered_gt for item in sublist]

        print(
            f"Data gathered. Total images: {len(all_preds)}. Initializing MeanAveragePrecision metric..."
        )

        # Initialize the metric.
        metric = MeanAveragePrecision(box_format="xyxy")

        # Update the metric with the tensors.
        metric.update(all_preds, all_gt)
        print("Metric updated successfully.")

        print("\n!!! ATTEMPTING TO TRIGGER THE ERROR by calling .compute() !!!")
        try:
            result = metric.compute()
            print("--- UNEXPECTED SUCCESS ---")
            print("The error was not reproduced. Result:", result)
        except Exception as e:
            print("\n---  SUCCESSFULLY REPRODUCED ERROR ---")
            print(f"Caught expected exception: {type(e).__name__}")
            print(f"Error message: {e}")

    dist.barrier()


if __name__ == "__main__":
    rank, world_size = ddp_setup()
    run_reproduction_logic(rank, world_size)
    cleanup()

It can be run with:

torchrun --nproc-per-node=2 reproduce_error.py

2. Actually, I tried to fix it by not moving the tensors to .cpu() and leaving the multi-process handling to the metric. This is a partial fix:

import os

import torch
import torch.distributed as dist
from torchmetrics.detection.mean_ap import MeanAveragePrecision


def ddp_setup():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    print(f"Rank {dist.get_rank()}: Initialized on device: cuda:{local_rank}")
    return dist.get_rank(), dist.get_world_size()


def cleanup():
    dist.destroy_process_group()


def generate_fake_data(rank):
    """Generates fake data on the specific GPU for the given rank."""
    device = f"cuda:{rank}"
    gt = [
        {
            "boxes": torch.tensor(
                [[21, 32, 101, 120]], dtype=torch.float32, device=device
            ),
            "labels": torch.tensor([0], dtype=torch.int64, device=device),
        }
    ]
    print(gt[0]["boxes"].shape)
    preds = [
        {
            "boxes": torch.randn(
                (100, 4),
                # [[25, 30, 100, 115], [150, 150, 200, 200]],
                dtype=torch.float32,
                device=device,
            ),
            "scores": torch.randn((100), dtype=torch.float32, device=device),
            "labels": torch.tensor(list(range(100)), dtype=torch.int64, device=device),
        }
    ]
    return preds, gt


def run_correct_ddp_logic(rank: int, world_size: int):
    """
    This function demonstrates the correct, non-blocking way to use torchmetrics with DDP.
    """
    main_device = f"cuda:{rank}"
    print(f"Rank {rank}: Generating fake data on {main_device}...")
    local_preds, local_gt = generate_fake_data(rank)

    # --- THE CORRECT DDP PATTERN FOR TORCHMETRICS ---

    # 1. Initialize the metric object on ALL processes.
    metric = MeanAveragePrecision(box_format="xyxy")

    # 2. Update the local metric instance with local data on ALL processes.
    print(f"Rank {rank}: Updating local metric instance.")
    metric.update(local_preds, local_gt)
    dist.barrier()

    print(f"Rank {rank}: Calling metric.compute()...")
    # 3. Call .compute() on ALL processes.
    final_metrics = metric.compute()

    # Note: .compute() must be called on ALL ranks, so the result is available
    # (and printed) on every rank, not only rank 0.
    print(f"\n--- Rank {rank} Results ---")
    print("Metric calculation complete.")
    print("Final mAP Result:", final_metrics)

    dist.barrier()


if __name__ == "__main__":
    rank, world_size = ddp_setup()
    run_correct_ddp_logic(rank, world_size)
    cleanup()

To run:

torchrun --nproc-per-node=2  err_fix.py 

I'm not sure if this is the correct way to calculate the metric, but it gives results that are not far from the single-GPU setup, and it resolves the initial error.

proevgenii avatar Jul 10 '25 17:07 proevgenii