
Potential Bug in ROAD during Linear Imputation

Open · tjades opened this issue 2 months ago

Description

When I try to calculate ROAD scores for precomputed explanations, I get an error caused by a shape mismatch at the point where the image is masked/imputed based on the relevance scores. Specifically, I pass x_batch (images) with shape (B, C, W, H) and a_batch (explanations) with shape (B, 1, W, H).

In ROAD(...).evaluate_batch(...), a_batch is broadcast to x_batch's shape:

if x_batch.shape != a_batch.shape:
    a_batch = np.broadcast_to(a_batch, x_batch.shape)
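To make the effect of this step concrete, here is a small NumPy sketch with toy sizes (B=2, C=3, W=H=4, chosen only for illustration). It mirrors the check above and shows that broadcasting a (B, 1, W, H) explanation simply copies the same attribution map into every color channel.

```python
import numpy as np

# Toy shapes following the issue: images (B, C, W, H), explanations (B, 1, W, H).
B, C, W, H = 2, 3, 4, 4
x_batch = np.random.rand(B, C, W, H)
a_batch = np.random.rand(B, 1, W, H)

# Mirrors the check in ROAD.evaluate_batch: broadcast a_batch to x_batch's shape.
if x_batch.shape != a_batch.shape:
    a_batch = np.broadcast_to(a_batch, x_batch.shape)

print(a_batch.shape)  # (2, 3, 4, 4)
# Every channel now holds the same attribution map.
print(np.array_equal(a_batch[:, 0], a_batch[:, 2]))  # True
```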

Then a_batch is flattened to shape (B, C*W*H):


# Flatten the attributions.
batch_size = a_batch.shape[0]
a_batch = a_batch.reshape(batch_size, -1)
n_features = a_batch.shape[-1]
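With the same toy sizes, the flattening step yields one feature per color-channel value, so n_features is C*W*H rather than the number of pixels W*H. This sketch only replays the three lines above on a broadcast array.

```python
import numpy as np

B, C, W, H = 2, 3, 4, 4
a_batch = np.broadcast_to(np.random.rand(B, 1, W, H), (B, C, W, H))

# Mirrors the flattening step: one row per image, C*W*H columns.
batch_size = a_batch.shape[0]
a_flat = a_batch.reshape(batch_size, -1)
n_features = a_flat.shape[-1]

print(n_features)  # 48 == C * W * H
print(W * H)       # 16 pixels per image
```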

Finally, the indices for masking/noisy_linear_imputation are computed from the flattened a_batch, which spans the channel dimension in addition to height and width. In other words, the indices refer to individual color-channel values, not to pixels:

ordered_indices = np.argsort(-a_batch, axis=1)
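A quick way to see what these indices refer to is to replay the argsort on toy data. Assuming the flattened layout is (C, W, H), a flat index decodes as channel * W*H + pixel, so each row of ordered_indices is a permutation of range(C*W*H), and, because the attributions were copied across channels, the top C entries all point at the same pixel.

```python
import numpy as np

B, C, W, H = 2, 3, 4, 4
rng = np.random.default_rng(0)
# Attributions broadcast from (B, 1, W, H), as in the issue.
a_batch = np.broadcast_to(rng.random((B, 1, W, H)), (B, C, W, H))
a_flat = a_batch.reshape(B, -1)

# Mirrors the ordering step: most relevant entries first.
ordered_indices = np.argsort(-a_flat, axis=1)

# Each row is a permutation of range(C * W * H), so the largest index is 47,
# far beyond the 16 pixels of a 4x4 image.
print(ordered_indices.max())  # 47

# The top C entries are the same pixel in each channel: index % (W * H) agrees.
print(ordered_indices[0, :C] % (W * H))

# Decode one flat index back into (channel, row, column) positions.
c, row, col = np.unravel_index(ordered_indices[0, 0], (C, W, H))
```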

Then the code iterates over the batch and passes single images of shape (C, W, H) to the perturb_func:

for x_element, top_k_index in zip(x_batch, top_k_indices):
    x_perturbed_element = self.perturb_func(  # type: ignore
        arr=x_element,
        indices=top_k_index,
    )

The perturbation function (noisy_linear_imputation()), however, treats the indices it receives as pixel indices rather than as indices of color-channel values:

arr_flat = arr.reshape((arr.shape[0], -1))

mask = np.ones(arr_flat.shape[1])
mask[indices] = 0

This raises an "index out of bounds" error: the indices refer to the image including its color channels (values up to C*W*H - 1), so they are too large to index the W*H pixels of a single channel.
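The failure can be reproduced in isolation, without Quantus or a model. The sketch below copies the three masking lines quoted above and feeds them channel-level indices of the kind produced upstream; the specific index values are made up for illustration.

```python
import numpy as np

C, W, H = 3, 4, 4
arr = np.random.rand(C, W, H)  # one image, as passed to perturb_func

# Mirrors noisy_linear_imputation: the mask has one entry per pixel.
arr_flat = arr.reshape((arr.shape[0], -1))  # shape (C, W*H) = (3, 16)
mask = np.ones(arr_flat.shape[1])           # length W*H = 16

# Channel-level indices like those produced upstream (range 0..C*W*H-1).
indices = np.array([3, 20, 47])

raised = False
try:
    mask[indices] = 0
except IndexError as err:
    raised = True
    print("IndexError:", err)
```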

I assume this is a bug? Also, what is the desired behavior? My guess is that we want to impute whole pixels rather than individual color-channel values, but that is not 100% clear to me from the Rong et al. (2022) paper.
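If pixel-level imputation turns out to be the intended behavior, one possible workaround (an assumption on my side, not how Quantus currently works) is to map each channel-level index back to its pixel and keep only the first occurrence, which preserves the relevance order. The helper name to_pixel_indices is hypothetical, and the sketch assumes the flattened layout is (C, W, H), so that index % (W*H) gives the pixel.

```python
import numpy as np

def to_pixel_indices(ordered_indices: np.ndarray, n_pixels: int) -> np.ndarray:
    """Hypothetical helper: map flat (C*W*H) indices to pixel indices (W*H).

    Assumes a (C, W, H) flattening, so index % n_pixels is the pixel index.
    Keeps the first occurrence of each pixel so the relevance order is preserved.
    """
    pixel_idx = ordered_indices % n_pixels
    # np.unique returns the position of each value's first occurrence.
    _, first = np.unique(pixel_idx, return_index=True)
    return pixel_idx[np.sort(first)]

# Toy example with C=3 channels and W*H=16 pixels.
ordered = np.array([5, 21, 37, 2, 18, 34])
print(to_pixel_indices(ordered, 16))  # [5 2]
```

When the attributions differ per channel, this ranks each pixel by its most relevant channel; with attributions broadcast from (B, 1, W, H) it is equivalent to ranking the pixels directly.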

Steps to reproduce the behavior

import numpy as np
import quantus
import torch
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

metric = quantus.ROAD(
    abs=False,
    normalise=True,
    disable_warnings=True,
    display_progressbar=False,
    return_aggregate=False,
)

x_batch = np.random.rand(32, 3, 299, 299)
a_batch = np.random.rand(32, 1, 299, 299)
y_batch = np.random.randint(0, 1000, size=32)  # integer ImageNet class labels
model = models.resnet50(weights="IMAGENET1K_V2")

scores = metric(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=a_batch,
    device=device,
    softmax=False,
)

Minimum acceptance criteria

Fix the implementation so that pixels are indexed correctly, or add an option to choose between imputing entire pixels and imputing individual color-channel values.
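As a rough illustration of the second option, the ordering could be computed either per pixel or per channel value depending on a flag. The function order_features and its impute_pixels parameter are hypothetical and not part of the Quantus API; summing attributions over the channel axis is only one possible reduction (mean or max would also work).

```python
import numpy as np

def order_features(a_batch: np.ndarray, x_shape: tuple, impute_pixels: bool = True) -> np.ndarray:
    """Hypothetical sketch of the requested option (not Quantus API).

    impute_pixels=True:  rank the W*H pixels; per-channel attributions are summed.
    impute_pixels=False: rank all C*W*H channel values (the current behavior).
    """
    batch_size = a_batch.shape[0]
    if impute_pixels:
        # Reduce over the channel axis; works for (B, 1, W, H) and (B, C, W, H).
        a_flat = a_batch.sum(axis=1).reshape(batch_size, -1)
    else:
        a_flat = np.broadcast_to(a_batch, x_shape).reshape(batch_size, -1)
    # Most relevant features first.
    return np.argsort(-a_flat, axis=1)

x_shape = (2, 3, 4, 4)
a_batch = np.random.rand(2, 1, 4, 4)
print(order_features(a_batch, x_shape, impute_pixels=True).shape)   # (2, 16)
print(order_features(a_batch, x_shape, impute_pixels=False).shape)  # (2, 48)
```

With impute_pixels=True the resulting indices stay below W*H, which matches the pixel-sized mask in noisy_linear_imputation.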

  • @mentions of the person that is apt to review these changes e.g., @annahedstroem

Best regards,
Tjade

tjades, Oct 31 '25 11:10