Description of the problem
- Make the SmoothMPRT metric faster by using torch to generate the Gaussian noise and average the explanations.
Description of a solution
- Test the function below and call it from within the SmoothMPRT metric.
```python
from typing import Union

import numpy as np
import torch

import quantus
from quantus.helpers.model.model_interface import ModelInterface


# Method intended for the SmoothMPRT metric class; it relies on the
# instance attributes `self.device` and `self.nr_samples`.
def explain_smooth_batch_torch(
    self,
    model: ModelInterface,
    x_batch: Union[np.ndarray, torch.Tensor],
    y_batch: Union[np.ndarray, torch.Tensor],
    std: float,
    **kwargs,
) -> np.ndarray:
    """
    Compute SmoothGrad-style explanations with torch: explain `self.nr_samples`
    noisy copies of `x_batch` and average the results.

    The Gaussian noise is generated on `self.device`. The last sample uses zero
    noise, so the unperturbed input is always included and, with
    `self.nr_samples == 1`, the result equals the plain explanation.

    Parameters
    ----------
    model:
        A model that is subject to explanation.
    x_batch:
        A np.ndarray or torch.Tensor which contains the input data that are explained.
    y_batch:
        A np.ndarray or torch.Tensor which contains the output labels that are explained.
    std: float
        Standard deviation of the Gaussian noise.
    kwargs: optional, dict
        Keyword arguments passed on to `quantus.explain`.

    Returns
    -------
    a_batch_smooth:
        Batch of averaged explanations ready to be evaluated.
    """
    # Move inputs and labels to the metric's device once, before the loop.
    x_batch = torch.as_tensor(x_batch, dtype=torch.float32, device=self.device)
    y_batch = torch.as_tensor(y_batch, device=self.device)

    # Start from None: the explanation shape may differ from x_batch
    # (e.g. if channels are reduced), so it is set by the first result.
    a_batch_smooth = None
    for n in range(self.nr_samples):
        # The last epsilon is zero so the true input is explained, and
        # nr_samples == 1 reduces to the plain (non-smoothed) explanation.
        if n == self.nr_samples - 1:
            epsilon = torch.zeros_like(x_batch)
        else:
            epsilon = torch.randn_like(x_batch) * std
        a_batch = quantus.explain(model, x_batch + epsilon, y_batch, **kwargs)
        # quantus.explain returns a np.ndarray; convert before accumulating.
        a_batch = torch.as_tensor(a_batch, device=self.device)
        if a_batch_smooth is None:
            a_batch_smooth = a_batch / self.nr_samples
        else:
            a_batch_smooth += a_batch / self.nr_samples
    # Return a np.ndarray, matching the declared return type.
    return a_batch_smooth.cpu().numpy()
```
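One way to test the averaging logic without Quantus or a trained model is a sketch like the following. It assumes a hypothetical `smooth_explain` helper that copies the noise-and-average loop of `explain_smooth_batch_torch` but takes the explanation function as an argument, and a hypothetical `gradient_explain` stand-in for `quantus.explain` that returns plain input gradients. For a bias-free linear model the input gradient does not depend on the input, so the smoothed result should equal the target-class weight rows for any noise level. With `nr_samples=1` it should equal the unperturbed explanation exactly.

```python
import torch


def smooth_explain(explain_fn, x_batch, y_batch, nr_samples, std):
    """Hypothetical helper: average explain_fn over nr_samples noisy copies.

    Mirrors the loop in explain_smooth_batch_torch: the last sample uses
    zero noise, so nr_samples == 1 gives the plain explanation.
    """
    a_smooth = None
    for n in range(nr_samples):
        if n == nr_samples - 1:
            epsilon = torch.zeros_like(x_batch)
        else:
            epsilon = torch.randn_like(x_batch) * std
        a_batch = explain_fn(x_batch + epsilon, y_batch)
        if a_smooth is None:
            a_smooth = a_batch / nr_samples
        else:
            a_smooth = a_smooth + a_batch / nr_samples
    return a_smooth


def gradient_explain(model):
    """Hypothetical stand-in for quantus.explain: plain input gradients."""
    def explain(x, y):
        x = x.detach().clone().requires_grad_(True)
        out = model(x)
        # Sum the logits of the target classes and backpropagate to the input.
        out.gather(1, y.unsqueeze(1)).sum().backward()
        return x.grad.detach()
    return explain


torch.manual_seed(0)
model = torch.nn.Linear(4, 3, bias=False)
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
explain = gradient_explain(model)

# For a linear model the input gradient is the weight row of the target
# class, independent of the noise, so the average must equal it.
expected = model.weight[y].detach()
smoothed = smooth_explain(explain, x, y, nr_samples=10, std=0.5)
print(torch.allclose(smoothed, expected, atol=1e-6))  # True
```

A nonlinear model (e.g. with a ReLU) would be needed to check that the noise actually changes the result; the linear case only pins down the averaging and the zero-noise last sample.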
Minimum acceptance criteria
- Specify what is necessary for the issue to be closed.
- @mentions of the person that is apt to review these changes e.g., @annahedstroem