
endless_cl_sim semantic segmentation accuracy calculation bug

Open • ZexinLi0w0 opened this issue 7 months ago • 1 comment

🐛 Describe the bug

When working with the Endless Continual Learning Simulator (EndlessCLSim), specifically the semantic segmentation scenario, I integrated accuracy_metrics into the evaluation plugin as follows:

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=logger,
    )

This crashes because semantic segmentation requires a pixel-level comparison, unlike traditional classification.

I checked the accuracy calculation in the source code (avalanche/evaluation/metrics/accuracy.py); it uses:

        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
        total_patterns = len(true_y)
        self._mean_accuracy.update(true_positives / total_patterns, total_patterns)

However, this assumes the labels are one-dimensional, i.e. a single class index per sample. In the semantic segmentation task, accuracy should be calculated per pixel.
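
Even once both tensors are label maps of the same shape, dividing by len() is not a per-pixel accuracy, because len() only returns the batch size. A minimal sketch with toy 2×2×3 label maps (illustrative values only, not benchmark data):

```python
import torch

# Toy label maps, shape [batch_size=2, height=2, width=3] (illustrative values only).
true_y = torch.tensor([[[0, 1, 2], [2, 1, 0]],
                       [[1, 1, 1], [0, 0, 0]]])
predicted_y = torch.tensor([[[0, 1, 2], [2, 2, 0]],
                            [[1, 1, 0], [0, 0, 0]]])

# Number of pixels whose predicted class matches the ground truth.
correct = float(torch.sum(torch.eq(predicted_y, true_y)))

# Dividing by len() uses only the batch size, so the "accuracy" can exceed 1.
per_sample_ratio = correct / len(true_y)   # 10 / 2 = 5.0
# Per-pixel accuracy divides by the total number of pixels instead.
pixel_acc = correct / true_y.numel()       # 10 / 12, about 0.833

print(per_sample_ratio, pixel_acc)
```

This is why the workaround further down switches total_patterns from len(true_y) to true_y.numel().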

Here is an example showing why this does not work:

For the training process, the inputs have shapes predicted_y: [batch_size, num_classes, height, width] and true_y: [batch_size, height, width].

The code above changes them to predicted_y: [batch_size, height, width] and true_y: [batch_size, width], because torch.max(true_y, 1) treats the height axis of the label map as if it were a class axis.

This causes a dimension-mismatch crash at the line:

 true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
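
The crash can be reproduced without Avalanche or the simulator data. A minimal sketch, assuming random tensors with the sizes used in the reproduction script below (training_bs=16, 8 classes, 135×240 images); it applies the same conversion as the metric and then compares the results. Because torch.eq broadcasts, comparing [16, 135, 240] with [16, 240] fails since height (135) does not match batch_size (16); if the two happened to be equal, the comparison would run but silently produce a meaningless result.

```python
import torch

# Sizes assumed from the reproduction script (training_bs=16, 8 classes, 135x240 images).
batch_size, num_classes, height, width = 16, 8, 135, 240

predicted_y = torch.randn(batch_size, num_classes, height, width)    # logits
true_y = torch.randint(0, num_classes, (batch_size, height, width))  # label map

# Same conversion as the current accuracy metric.
if len(predicted_y.shape) > 1:
    predicted_y = torch.max(predicted_y, 1)[1]  # -> [16, 135, 240]
if len(true_y.shape) > 1:
    true_y = torch.max(true_y, 1)[1]            # -> [16, 240]: argmax over height

try:
    torch.eq(predicted_y, true_y)
    crashed = False
except RuntimeError as err:
    crashed = True
    print("Shape mismatch:", err)
```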

🐜 To Reproduce

A minimal working example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from avalanche.benchmarks.classic import EndlessCLSim
from avalanche.training import Naive
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    loss_metrics,
    ram_usage_metrics,
    timing_metrics,
    MAC_metrics,
)
from avalanche.logging import InteractiveLogger, CSVLogger
from avalanche.models import pytorchcv_wrapper
import argparse
import random
import numpy as np

# Set seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=int, default=0, help="Use CUDA device index; -1 for CPU")
    parser.add_argument("--semseg", action="store_true", default=True, help="Enable semantic segmentation mode")
    parser.add_argument("--dataset_root", type=str, default=".", help="Dataset root")
    parser.add_argument("--scenario", type=str, default="Classes", choices=["Classes", "Illumination", "Weather"])
    parser.add_argument("--training_bs", type=int, default=16)
    parser.add_argument("--eval_bs", type=int, default=16)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--epoch", type=int, default=1)
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.cuda}" if args.cuda != -1 else "cpu")

    # Build model: use resnet20 (from pytorchcv_wrapper) with default settings for cifar10
    # (default backbone is designed for classification; we adjust it for semantic segmentation)
    model = pytorchcv_wrapper.resnet("cifar10", depth=20, pretrained=False)

    if args.semseg:
        # For semantic segmentation, we remove the final pooling and replace the classifier head.
        # In pytorchcv's CIFAR ResNet the final average pool is part of model.features,
        # so it is replaced there to keep the feature-map resolution.
        model.features.final_pool = nn.Identity()
        # Here we assume that the feature extractor outputs features with 64 channels.
        # We replace the classifier head with a segmentation head:
        model.output = nn.Sequential(
            nn.Conv2d(64, 512, kernel_size=3, padding=1),  # Increase feature depth
            nn.ReLU(),
            nn.Conv2d(512, 8, kernel_size=1)  # 8 segmentation classes
        )

        # Override the forward function: extract features, apply segmentation head,
        # and upsample the result to the original input spatial dimensions.
        def _seg_forward(x):
            input_size = x.shape[-2:]  # e.g., (135, 240)
            x = model.features(x)      # features (final pool is now Identity): [N, 64, H_feat, W_feat]
            x = model.output(x)        # logits, shape: [N, num_classes, H_feat, W_feat]
            x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
            return x

        model.forward = _seg_forward

    # Move the model to the device only after the head has been replaced,
    # so the new layers end up on the same device as the backbone.
    model.to(device)

    # Create the EndlessCLSim benchmark (only semantic segmentation is enabled)
    benchmark = EndlessCLSim(
        scenario=args.scenario,
        sequence_order=None,
        task_order=None,
        semseg=args.semseg,
        dataset_root=args.dataset_root,
    )

    # Retrieve training and testing streams
    train_stream = benchmark.train_stream
    test_stream = benchmark.test_stream

    # Set up optimizer and loss (using CrossEntropyLoss, which expects:
    # model output shape: [N, num_classes, H, W] and target shape: [N, H, W])
    optimizer = Adam(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # Set up loggers (optional)
    interactive_logger = InteractiveLogger()
    csv_logger = CSVLogger("log_semseg.csv")
    logger = [interactive_logger, csv_logger]

    # Set up the evaluation plugin (the same accuracy metrics as in my full code)
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=logger,
    )

    # Create the continual learning strategy (Naive in this case)
    cl_strategy = Naive(
        model,
        optimizer,
        criterion,
        train_mb_size=args.training_bs,
        train_epochs=args.epoch,
        eval_mb_size=args.eval_bs,
        device=device,
        evaluator=eval_plugin,
    )

    print("Starting experiment...")
    for experience in train_stream:
        cl_strategy.train(experience)
        res = cl_strategy.eval(test_stream)
        print("Evaluation results:", res)

    print("Experiment completed.")
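
To check the shape contract stated in the script's comments (model output [N, num_classes, H, W], CrossEntropyLoss target [N, H, W]) without downloading the simulator data, here is a minimal sketch with a hypothetical ToySegNet. A single strided convolution stands in for the resnet20 backbone (an assumption, not the real feature extractor); the segmentation head and the bilinear upsampling mirror the script.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySegNet(nn.Module):
    """Hypothetical stand-in: strided conv backbone (64 channels) + the script's head."""

    def __init__(self, num_classes: int = 8):
        super().__init__()
        # Downsamples 135x240 to 34x60, roughly like a stride-4 backbone.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=4, padding=1),
            nn.ReLU(),
        )
        # Same segmentation head as in the reproduction script.
        self.output = nn.Sequential(
            nn.Conv2d(64, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, num_classes, kernel_size=1),
        )

    def forward(self, x):
        input_size = x.shape[-2:]
        x = self.output(self.features(x))  # [N, num_classes, H_feat, W_feat]
        # Upsample the logits back to the input resolution.
        return F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)


model = ToySegNet()
images = torch.randn(2, 3, 135, 240)
targets = torch.randint(0, 8, (2, 135, 240))  # [N, H, W] class indices

logits = model(images)                         # [N, num_classes, H, W]
loss = nn.CrossEntropyLoss()(logits, targets)  # accepts the K-dimensional case
print(tuple(logits.shape), loss.item())
```

Defining forward on a subclass, as here, is an alternative to reassigning model.forward on an instance. One reason to prefer it: copy.deepcopy treats plain functions as atomic, so a deep copy of the patched model would keep a forward that closes over the original model object.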

🐝 Expected behavior

For the training process, the inputs have shapes predicted_y: [batch_size, num_classes, height, width] and true_y: [batch_size, height, width]. The accuracy calculation should convert them to predicted_y: [batch_size, height, width] and true_y: [batch_size, height, width].

For the evaluation process, the inputs have shapes predicted_y: [batch_size, height, width] and true_y: [batch_size, num_classes, height, width]. The accuracy calculation should likewise convert them to predicted_y: [batch_size, height, width] and true_y: [batch_size, height, width].
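
The expected conversions can be sketched with a small helper. The name to_label_map is hypothetical (not part of Avalanche), and the evaluation target is assumed to be one-hot encoded along dim=1; the issue only states its shape.

```python
import torch
import torch.nn.functional as F


def to_label_map(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: collapse [N, C, H, W] scores to an [N, H, W] label map.

    Tensors that are already [N, H, W] label maps are returned unchanged.
    """
    return torch.argmax(t, dim=1) if t.dim() == 4 else t


N, C, H, W = 2, 8, 135, 240

# Training: [N, C, H, W] logits vs. [N, H, W] labels.
train_pred = to_label_map(torch.randn(N, C, H, W))
train_true = to_label_map(torch.randint(0, C, (N, H, W)))

# Evaluation (shapes as reported above): [N, H, W] labels vs. [N, C, H, W] target,
# assumed here to be one-hot along dim=1.
labels = torch.randint(0, C, (N, H, W))
one_hot = F.one_hot(labels, C).permute(0, 3, 1, 2).float()  # [N, H, W, C] -> [N, C, H, W]
eval_pred = to_label_map(labels)
eval_true = to_label_map(one_hot)

print(train_pred.shape, train_true.shape, eval_pred.shape, eval_true.shape)
print(torch.equal(eval_true, labels))  # the one-hot target maps back to its labels
```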

As a workaround, I currently modify avalanche/evaluation/metrics/accuracy.py to bypass this error:

        '''
        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]
        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]
        true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
        total_patterns = len(true_y)
        self._mean_accuracy.update(true_positives / total_patterns, total_patterns)
        '''

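        # Reduce 4-D tensors ([batch_size, num_classes, height, width]) to label maps
        # ([batch_size, height, width]) with an argmax over the class axis (dim=1).
        # Note: 2-D classification logits ([batch_size, num_classes]) are no longer
        # converted, so this workaround only covers the segmentation case.
        if predicted_y.dim() > 3:
            predicted_y = torch.argmax(predicted_y, dim=1)
        if true_y.dim() > 3:
            true_y = torch.argmax(true_y, dim=1)
        if predicted_y.shape != true_y.shape:
            raise ValueError(f"Size mismatch: predicted_y shape {predicted_y.shape} vs true_y shape {true_y.shape}")

        true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
        # Divide by the number of pixels (all elements), not by the batch size.
        total_patterns = true_y.numel()
        self._mean_accuracy.update(true_positives / total_patterns, total_patterns)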
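
The workaround only applies argmax to 4-D tensors, so classification logits [batch_size, num_classes] compared against labels [batch_size] would now hit the ValueError. A hedged sketch of a rule that keeps both cases working, assuming the only layouts are [N] or [N, H, W] for class indices and [N, C] or [N, C, H, W] for per-class scores with the class axis at dim=1 (other tasks, e.g. [N, C, T] sequences, would need a different rule). The function names are hypothetical, not Avalanche API.

```python
import torch
import torch.nn.functional as F


def _to_labels(t: torch.Tensor) -> torch.Tensor:
    # Assumed layouts: [N] and [N, H, W] already hold class indices, while
    # [N, C] and [N, C, H, W] hold per-class scores (logits or one-hot) at dim=1.
    return torch.argmax(t, dim=1) if t.dim() in (2, 4) else t


def accuracy_counts(predicted_y: torch.Tensor, true_y: torch.Tensor):
    """Return (correct, total) for a classification or segmentation batch."""
    predicted_y = _to_labels(predicted_y)
    true_y = _to_labels(true_y)
    if predicted_y.shape != true_y.shape:
        raise ValueError(
            f"Size mismatch: predicted_y shape {tuple(predicted_y.shape)} "
            f"vs true_y shape {tuple(true_y.shape)}"
        )
    correct = float(torch.sum(torch.eq(predicted_y, true_y)))
    # numel() counts samples for classification and pixels for segmentation.
    return correct, true_y.numel()


# Classification: [N, C] logits vs. [N] labels -> predictions [0, 1, 2].
cls_logits = torch.tensor([[2.0, 1.0, 0.0], [0.0, 3.0, 1.0], [1.0, 0.0, 5.0]])
print(accuracy_counts(cls_logits, torch.tensor([0, 2, 2])))  # (2.0, 3)

# Segmentation: [N, C, H, W] scores vs. [N, H, W] label map.
seg_labels = torch.randint(0, 8, (2, 5, 6))
seg_scores = F.one_hot(seg_labels, 8).permute(0, 3, 1, 2).float()
print(accuracy_counts(seg_scores, seg_labels))  # (60.0, 60)
```

Inside the metric, the result would feed the existing call as self._mean_accuracy.update(correct / total, total).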

ZexinLi0w0 • Mar 04 '25 18:03