
[bug] `quantus.evaluate` with method `GradCAM` gives error for common input

Open · vedal opened this issue 2 years ago · 2 comments

With the current version (quantus==0.1.4), the following code:

import torch
from torch import nn
from torchvision import models
import quantus
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        resnet18 = models.resnet18()
        children = list(resnet18.children())

        self.backbone = nn.Sequential(*children[:-2])
        self.head = nn.Sequential(
            *children[-2:-1],
            nn.Flatten(start_dim=1),
            children[-1]
        )

    def forward(self, batch):
        return self.head(self.backbone(batch))


model = Net()

x_batch = torch.rand(1, 3, 256, 256)
y_batch = [1]

a_batch = quantus.explain(
                model=model,
                inputs=x_batch,
                targets=y_batch,
                method='GradCAM',
                gc_layer=list(model.named_modules())[-6][1],  # the backbone's last conv layer
                normalise=True,
)


quantus.evaluate(
    metrics={
        "PointingGame": quantus.PointingGame(disable_warnings=True),
    },
    xai_methods={"GradCAM": a_batch},
    model=model.head,
    x_batch=model.backbone(x_batch),
    y_batch=y_batch,
    s_batch=np.ones(shape=(1, 1, 8, 8)),
    agg_func=np.mean,
    **{"explain_func": quantus.explain}
)

produces the error

ValueError: Ambiguous input shape. Cannot infer channel-first/channel-last order.

This error makes sense for attributions at the input layer, but it is probably unintended behavior for GradCAM, since GradCAM is usually applied to an intermediate convolutional layer.
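
For context, a quick sketch of the shapes involved (the exact inference heuristic inside Quantus may differ, but the ambiguity stems from the backbone's output shape):

feat = model.backbone(x_batch)
print(x_batch.shape)  # torch.Size([1, 3, 256, 256]) -> clearly channel-first
print(feat.shape)     # torch.Size([1, 512, 8, 8])   -> 512 channels vs. 8x8 spatial extent: ambiguous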

vedal · May 30 '22 08:05

Thanks so much @vedal, I'm looking into this and will get back to you!

annahedstroem · Jun 02 '22 13:06

@annahedstroem Maybe it should be possible to pass the gc_layer parameter to quantus.evaluate directly when using the GradCAM method, like what is done in quantus.explain? This would bypass the need to split the model into a backbone and a head. In the end, I suppose it's a library design question: whether to keep evaluate simple or to increase its flexibility.

vedal · Jun 03 '22 09:06


Thanks for your feedback, @vedal, and sorry for the long wait. You don't need model.backbone; you should just pass x_batch=x_batch to the metric. As for the attribution dimensions, Quantus requires the non-channel dimensions of the input (x_batch) and the attributions (a_batch) to match. To achieve this with GradCAM, you can pass the interpolate arguments as in the following corrected version of your code:

import torch
from torch import nn
from torchvision import models
import quantus
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        resnet18 = models.resnet18()
        children = list(resnet18.children())

        self.backbone = nn.Sequential(*children[:-2])
        self.head = nn.Sequential(
            *children[-2:-1],
            nn.Flatten(start_dim=1),
            children[-1]
        )

    def forward(self, batch):
        return self.head(self.backbone(batch))


model = Net()

x_batch = torch.rand(1, 3, 256, 256)
y_batch = [1]

a_batch = quantus.explain(
                model=model,
                inputs=x_batch,
                targets=y_batch,
                method='GradCAM',
                interpolate=x_batch.shape[2:],      # upsample the attribution map to the input's spatial size
                interpolate_method="nearest",       # nearest-neighbour upsampling
                gc_layer=list(model.named_modules())[-6][1],
                normalise=True,
)


quantus.evaluate(
    metrics={
        "PointingGame": quantus.PointingGame(disable_warnings=True),
    },
    xai_methods={"GradCAM": a_batch},
    model=model.head,
    x_batch=x_batch,
    y_batch=y_batch,
    s_batch=np.ones(shape=(1, 1, 256, 256)),
    agg_func=np.mean,
    **{"explain_func": quantus.explain}
)

To learn more about how interpolation is performed, please refer here. A pull request has been opened to explicitly warn the user about this issue and the need to pass the interpolate argument when using GradCAM.
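
As a rough illustration of what the interpolate arguments do (a sketch only, not Quantus's actual implementation), the coarse GradCAM map from the conv layer's feature grid is upsampled to the input resolution so that its spatial dimensions match x_batch:

import torch
import torch.nn.functional as F

coarse_map = torch.rand(1, 1, 8, 8)                                     # GradCAM output on the 8x8 feature grid
full_map = F.interpolate(coarse_map, size=(256, 256), mode="nearest")   # nearest-neighbour upsampling
print(full_map.shape)                                                   # torch.Size([1, 1, 256, 256])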

dilyabareeva · Dec 06 '22 13:12

hi @dilyabareeva, thanks for looking into this! Yes, looking into the source code, I noticed that several parameters, like model, aren't actually used when feeding in a_batch. In the following code, I compute attributions using quantus.explain and evaluate them using quantus.evaluate. Note how I need to feed in five (5!) placeholder arguments to run quantus.evaluate:

a_batch = quantus.explain(
    model=model,
    inputs=x_batch,
    targets=y_batch,
    gc_layer=list(model.named_modules())[-3][1],
    method="GradCam",
)


results = quantus.evaluate(
    metrics={"PointingGame": quantus.PointingGame()},
    a_batch=a_batch,
    s_batch=s_batch,
    channel_first=True,

    # the five (5) parameters below are unused but required by Quantus when the a_batch arg is used.
    model=None,
    x_batch=np.empty_like(a_batch),  # z_batch,
    y_batch=None,
    xai_methods=["xai_name"],
    explain_func_kwargs=dict(),
)["xai_name"]

Suggestion: I suggest a rewrite of quantus.evaluate that doesn't allow for re-computing a_batch, since this can already be done in quantus.explain. The name "evaluate" in particular suggests that inputs are being evaluated, not computed from scratch.
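
A rough sketch of what such an evaluation-only call could look like (a hypothetical interface, not the current Quantus API; the attributions argument name is purely illustrative):

results = quantus.evaluate(
    metrics={"PointingGame": quantus.PointingGame()},
    attributions={"GradCAM": a_batch},  # pre-computed with quantus.explain; hypothetical argument name
    s_batch=s_batch,
    channel_first=True,
)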

Note my general suggestion in #186

vedal · Dec 09 '22 08:12

@vedal Thanks for your feedback! The purpose of quantus.evaluate is to compare multiple XAI methods using multiple metrics, hence the xai_methods and metrics dictionaries. While it is possible to pass pre-calculated attributions to quantus.evaluate as values in the xai_methods dictionary, the model, x_batch and y_batch arguments are still necessary for the calculation of the metrics themselves. You are welcome to review the pending updates in the related PR #214 (issue #199) and leave comments if you see fit. Closing this issue, as the GradCAM issue has been addressed in the merged PR #205.
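
For illustration, a comparison of two pre-computed attribution sets with one metric could look roughly like this (a sketch based only on the argument names used earlier in this thread; a_batch_gradcam and a_batch_saliency are placeholders for attributions computed beforehand with quantus.explain):

results = quantus.evaluate(
    metrics={
        "PointingGame": quantus.PointingGame(disable_warnings=True),
    },
    xai_methods={
        "GradCAM": a_batch_gradcam,    # pre-computed attributions, keyed by method name
        "Saliency": a_batch_saliency,  # placeholder for another method's attributions
    },
    model=model,        # still required: metrics may need to query the model
    x_batch=x_batch,    # still required: metrics operate on the inputs
    y_batch=y_batch,
    s_batch=s_batch,
    agg_func=np.mean,
    explain_func=quantus.explain,
)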

dilyabareeva · Jan 20 '23 22:01