
Some metrics cannot handle GPU tensors

Open jannik-el opened this issue 4 months ago • 3 comments

Metrics Requiring CPU Tensors:

  1. torchsurv.metrics.ConcordanceIndex()
  • Cannot handle GPU tensors
  • Requires manual .cpu() transfer before computation
  2. torchsurv.metrics.Auc()
  • Cannot handle GPU tensors
  • Internal tensor operations create CPU tensors causing device mismatch
  • Requires manual .cpu() transfer before computation

Specific Error Patterns:

  • BatchNorm Error: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
  • Tensor Cat Error: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!
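
The second error pattern typically appears when a function allocates a fresh tensor without a `device=` argument and then combines it with a GPU input. The sketch below illustrates that general pattern only; `prepend_zero_buggy` and `prepend_zero_fixed` are hypothetical names, and this is not torchsurv's actual code.

```python
import torch


def prepend_zero_buggy(x: torch.Tensor) -> torch.Tensor:
    # torch.zeros(1) is allocated on the default device (normally the CPU),
    # so torch.cat raises a device-mismatch RuntimeError when x lives on cuda:0.
    return torch.cat([torch.zeros(1), x])


def prepend_zero_fixed(x: torch.Tensor) -> torch.Tensor:
    # Allocating the new tensor with x's dtype and device keeps all operands together.
    return torch.cat([torch.zeros(1, dtype=x.dtype, device=x.device), x])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.tensor([1.0, 2.0], device=device)
print(prepend_zero_fixed(x))  # works on either device
```

On a CPU-only machine both versions run, which is why this kind of bug tends to go unnoticed until the inputs are moved to a GPU.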

Minimal Example for testing if metric will work with tensor on GPU

import torch
import numpy as np
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc

print("🔧 Checking device availability...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create sample survival data
n_samples = 1000
n_features = 10

print("\n📊 Creating sample survival data...")
# Features (risk scores from a model)
log_hazards = torch.randn(n_samples, 1)
# Events (1 = event occurred, 0 = censored)
events = torch.bernoulli(torch.full((n_samples,), 0.3))  # 30% event rate
# Survival times
times = torch.exponential(torch.ones(n_samples))

print(f"Sample data created: {n_samples} observations, {events.sum().item():.0f} events")

# Move tensors to GPU if available
print(f"\n🚀 Moving tensors to {device}...")
log_hazards_gpu = log_hazards.to(device)
events_gpu = events.bool().to(device)
times_gpu = times.to(device)

print(f"✅ Tensors on device: {log_hazards_gpu.device}")

# Attempt 1: Try TorchSurv metrics with GPU tensors (THIS WILL FAIL)
print("\n❌ ATTEMPT 1: Using TorchSurv metrics directly with GPU tensors")
print("=" * 60)

try:
    print("   🧪 Testing ConcordanceIndex with GPU tensors...")
    cindex_metric = ConcordanceIndex()
    cindex = cindex_metric(log_hazards_gpu, events_gpu, times_gpu)
    print(f"   ✅ C-index: {cindex:.4f}")
except Exception as e:
    print(f"   💥 ConcordanceIndex FAILED: {type(e).__name__}: {e}")

try:
    print("   🧪 Testing AUC with GPU tensors...")
    auc_metric = Auc()
    auc = auc_metric(log_hazards_gpu, events_gpu, times_gpu, new_time=torch.tensor(1.0).to(device))
    print(f"   ✅ AUC: {auc:.4f}")
except Exception as e:
    print(f"   💥 AUC FAILED: {type(e).__name__}: {e}")

# Attempt 2: Workaround - move tensors to CPU first (THIS WORKS)
print("\n✅ ATTEMPT 2: Workaround - moving tensors to CPU first")
print("=" * 60)

print("   ⚠️  Moving tensors from GPU to CPU for TorchSurv compatibility...")
log_hazards_cpu = log_hazards_gpu.cpu()
events_cpu = events_gpu.cpu()
times_cpu = times_gpu.cpu()

try:
    print("   🧪 Testing ConcordanceIndex with CPU tensors...")
    cindex_metric = ConcordanceIndex()
    cindex = cindex_metric(log_hazards_cpu, events_cpu, times_cpu)
    print(f"   ✅ C-index: {cindex:.4f}")
except Exception as e:
    print(f"   💥 ConcordanceIndex FAILED: {type(e).__name__}: {e}")

try:
    print("   🧪 Testing AUC with CPU tensors...")
    auc_metric = Auc()
    auc = auc_metric(log_hazards_cpu, events_cpu, times_cpu, new_time=torch.tensor(1.0))
    print(f"   ✅ AUC: {auc:.4f}")
except Exception as e:
    print(f"   💥 AUC FAILED: {type(e).__name__}: {e}")
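
The manual `.cpu()` transfers in Attempt 2 could be wrapped in a small helper so that call sites stay short. This is a minimal sketch under assumptions: `call_on_cpu` is a hypothetical name, not a torchsurv function, and `fake_metric` is a stand-in callable used only to keep the block self-contained.

```python
import torch


def call_on_cpu(metric, *args, **kwargs):
    """Call `metric` after moving every tensor argument to the CPU.

    Hypothetical helper, not part of torchsurv. Non-tensor arguments
    are passed through unchanged.
    """

    def to_cpu(value):
        # detach() drops the autograd graph; the metrics are evaluated, not optimized.
        return value.detach().cpu() if isinstance(value, torch.Tensor) else value

    cpu_args = [to_cpu(a) for a in args]
    cpu_kwargs = {k: to_cpu(v) for k, v in kwargs.items()}
    return metric(*cpu_args, **cpu_kwargs)


def fake_metric(estimate, event, new_time=None):
    # Stand-in for a metric object; it only reports where its inputs live.
    return [estimate.device.type, event.device.type, new_time.device.type]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(call_on_cpu(
    fake_metric,
    torch.randn(5, device=device),
    torch.ones(5, dtype=torch.bool, device=device),
    new_time=torch.tensor(1.0, device=device),
))
```

With the tensors from the example above, a call would then look like `call_on_cpu(ConcordanceIndex(), log_hazards_gpu, events_gpu, times_gpu)`.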

jannik-el · Aug 22 '25 07:08

These two are the only metrics I am currently working with; it is quite possible that others are affected too.

jannik-el · Aug 22 '25 07:08

As context, I was going crazy for the last 3 days trying to figure out why the heck my CPU overhead was so high, until I realized that the metric processing was being completed on the CPU, not the GPU. So my model with only 3000 params was taking 30 min to train on a T4 GPU. I don't have time for that, haha.

jannik-el · Aug 22 '25 07:08
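
While the metrics run on the CPU, one way that may reduce the overhead described above is to collect model outputs on the device during an epoch and transfer them once, instead of evaluating the metric on every batch. The sketch below assumes a loop that yields `(estimate, event, time)` tuples per batch; `gather_epoch_outputs` is a hypothetical helper and the training loop from the comment above is not shown in this thread.

```python
import torch


def gather_epoch_outputs(batches):
    """Concatenate per-batch (estimate, event, time) tuples and move them to the CPU once.

    Hypothetical helper: `batches` is a list of tuples of tensors that are
    all on the same device (for example, collected during a validation loop).
    """
    estimates, events, times = zip(*batches)
    # torch.cat runs on the tensors' own device; only the final result is
    # detached and transferred, so there is one host copy per tensor per epoch.
    return tuple(torch.cat(parts).detach().cpu() for parts in (estimates, events, times))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batches = [
    (
        torch.randn(4, device=device),
        torch.ones(4, dtype=torch.bool, device=device),
        torch.rand(4, device=device),
    )
    for _ in range(3)
]
estimate_all, event_all, time_all = gather_epoch_outputs(batches)
print(estimate_all.shape, event_all.dtype, time_all.device)
```

The gathered CPU tensors can then be passed to the metric once per epoch.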

Hi @jannik-el ,

Thanks for the issue. I can add GPU compatibility to the metrics. Since they are non-differentiable and should not be used as optimization metrics, we did not expect users to compute them on GPUs, but rather at the end, for inference with CPUs.

I will create a PR and add the new features soon.

The reproducible example is much appreciated!

tcoroller · Aug 22 '25 13:08