Potential Bug in ROAD during Linear Imputation
Description
When trying to calculate ROAD scores for precomputed explanations, I get an error due to a shape mismatch when the image is masked/imputed based on the relevance scores. Specifically, I pass `x_batch` (images) with shape (B, C, W, H) and `a_batch` (explanations) with shape (B, 1, W, H).
In `ROAD.evaluate_batch()`, `a_batch` is first broadcast to the shape of `x_batch`:
```python
if x_batch.shape != a_batch.shape:
    a_batch = np.broadcast_to(a_batch, x_batch.shape)
```
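To illustrate (tiny toy shapes, not the actual metric code): this broadcast simply repeats the single-channel relevance map once per color channel, so every channel ends up carrying identical attribution values:

```python
import numpy as np

B, C, W, H = 2, 3, 4, 4
a_batch = np.random.rand(B, 1, W, H)
a_broadcast = np.broadcast_to(a_batch, (B, C, W, H))

# Every channel now holds an identical copy of the single-channel map.
assert np.array_equal(a_broadcast[:, 0], a_broadcast[:, 1])
assert np.array_equal(a_broadcast[:, 0], a_broadcast[:, 2])
```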
Next, `a_batch` is flattened to shape (B, C*W*H):
```python
# Flatten the attributions.
batch_size = a_batch.shape[0]
a_batch = a_batch.reshape(batch_size, -1)
n_features = a_batch.shape[-1]
```
Finally, the indices used for masking/`noisy_linear_imputation` are computed from the flattened `a_batch`, which spans the channel dimension in addition to height and width. In other words, the indices refer to color-channel values, not to pixels:
```python
ordered_indices = np.argsort(-a_batch, axis=1)
```
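A small toy example (shapes chosen purely for illustration) of what these indices address: they run over all C*W*H channel values, while a single image only has W*H pixels:

```python
import numpy as np

B, C, W, H = 2, 3, 4, 4
a_batch = np.broadcast_to(np.random.rand(B, 1, W, H), (B, C, W, H))
a_flat = a_batch.reshape(B, -1)        # shape (B, C*W*H) = (2, 48)

ordered_indices = np.argsort(-a_flat, axis=1)
print(ordered_indices.max())           # up to 47, but an image has only W*H = 16 pixels
```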
The code then iterates over `x_batch` and passes single images of shape (C, W, H) to the perturb function:
```python
for x_element, top_k_index in zip(x_batch, top_k_indices):
    x_perturbed_element = self.perturb_func(  # type: ignore
        arr=x_element,
        indices=top_k_index,
    )
```
The perturb function (`noisy_linear_imputation()`), however, behaves as if the passed indices refer to pixels rather than color-channel values:
```python
arr_flat = arr.reshape((arr.shape[0], -1))
mask = np.ones(arr_flat.shape[1])
mask[indices] = 0
```
This throws an IndexError ("index out of bounds"), since the passed indices, which address the image including its color channels, are too large for indexing pixels independently of the color channels.
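For illustration, here is the mismatch reproduced in isolation with tiny shapes (the modulo mapping at the end is just one conceivable way to recover pixel indices, not current Quantus behavior):

```python
import numpy as np

C, W, H = 3, 4, 4
arr = np.random.rand(C, W, H)
indices = np.array([5, 20, 40])              # channel-value indices, valid range [0, C*W*H) = [0, 48)

arr_flat = arr.reshape((arr.shape[0], -1))   # shape (3, 16)
mask = np.ones(arr_flat.shape[1])            # only W*H = 16 entries
try:
    mask[indices] = 0                        # 20 and 40 are out of bounds -> IndexError
except IndexError as e:
    print(e)

# One conceivable mapping back to pixel indices: drop the channel offset via modulo.
pixel_indices = np.unique(indices % (W * H))
mask[pixel_indices] = 0                      # now indexes pixels, as the function expects
```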
This is a bug, I guess? Also, what is the desired behavior? I assume that we would like to impute pixels rather than single color-channel values, but that is also not 100% clear to me from the Rong et al. (2022) paper.
Steps to reproduce the behavior
```python
import numpy as np
import torch
import quantus
from torchvision import models

metric = quantus.ROAD(
    abs=False,
    normalise=True,
    disable_warnings=True,
    display_progressbar=False,
    return_aggregate=False,
)

x_batch = np.random.rand(32, 3, 299, 299)
a_batch = np.random.rand(32, 1, 299, 299)       # single-channel attributions
y_batch = np.random.randint(0, 1000, size=32)   # integer class labels
model = models.resnet50(weights="IMAGENET1K_V2")  # from torchvision
device = "cuda" if torch.cuda.is_available() else "cpu"

scores = metric(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=a_batch,
    device=device,
    softmax=False,
)
```
Minimum acceptance criteria
Fix the implementation so that pixels are indexed correctly, or add an option for deciding whether to impute entire pixels or single color-channel values. One possible direction is sketched below.
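For what it's worth, a rough sketch of one possible fix (purely illustrative from my side, the helper name is made up and not part of the Quantus API): collapse the channel axis before sorting, so that the ranking addresses pixels, which is what `noisy_linear_imputation` expects.

```python
import numpy as np

def pixel_ranking(a_batch: np.ndarray) -> np.ndarray:
    """Return per-sample pixel indices sorted by descending relevance.

    a_batch: attributions of shape (B, C, W, H) or (B, 1, W, H).
    """
    # Collapse the channel dimension, e.g. by taking the mean; summing or taking
    # the max would be equally plausible, depending on the desired semantics.
    a_pixel = a_batch.mean(axis=1)                   # shape (B, W, H)
    a_pixel = a_pixel.reshape(a_batch.shape[0], -1)  # shape (B, W*H)
    return np.argsort(-a_pixel, axis=1)              # pixel indices, valid for the imputation mask

# Usage sketch:
# ordered_indices = pixel_ranking(a_batch)   # instead of argsort over (B, C*W*H)
```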
- @annahedstroem might be apt to review these changes.
Best regards, Tjade