
Make SmoothMPRT faster for torch (see implementation)

Open · annahedstroem opened this issue 1 year ago · 0 comments

Description of the problem

  • Make SmoothMPRT faster with torch

Description of a solution

  • Test and call this function within the SmoothMPRT metric (a usage sketch follows the code below):
    # Imports needed at module level (the ModelInterface import path may
    # vary across Quantus versions).
    import numpy as np
    import torch

    import quantus
    from quantus.helpers.model.model_interface import ModelInterface

    def explain_smooth_batch_torch(
        self,
        model: ModelInterface,
        x_batch: np.ndarray,
        y_batch: np.ndarray,
        std: float,
        **kwargs,
    ) -> np.ndarray:
        """
        Compute explanations, normalise and take absolute (if was configured so during metric initialization.)
        This method should primarily be used if you need to generate additional explanation
        in metrics body. It encapsulates typical for Quantus pre- and postprocessing approach.
        It will do few things:
            - call model.shape_input (if ModelInterface instance was provided)
            - unwrap model (if ModelInterface instance was provided)
            - call explain_func
            - expand attribution channel

        Parameters
        ----------
        model:
            A model that is subject to explanation.
        x_batch:
            A np.ndarray which contains the input data to be explained.
        y_batch:
            A np.ndarray which contains the output labels to be explained.
        std : float
            Standard deviation of the Gaussian noise.
        kwargs: optional, dict
            Keyword arguments forwarded to the explain function.

        Returns
        -------
        a_batch:
            Batch of explanations ready to be evaluated.
        """
        if not isinstance(x_batch, torch.Tensor):
            x_batch = torch.as_tensor(x_batch, dtype=torch.float).to(self.device)

        if not isinstance(y_batch, torch.Tensor):
            y_batch = torch.as_tensor(y_batch).to(self.device)

        # Initialise lazily: the attribution shape may differ from the input
        # shape, so do not pre-allocate with zeros_like(x_batch).
        a_batch_smooth = None
        for n in range(self.nr_samples):
            # The last epsilon is defined as zero to compute the true output,
            # so that SmoothGrad with nr_samples == 1 equals the plain gradient.
            if n == self.nr_samples - 1:
                epsilon = torch.zeros_like(x_batch)
            else:
                epsilon = torch.randn_like(x_batch) * std

            a_batch = quantus.explain(model, x_batch + epsilon, y_batch, **kwargs)

            # The explain function may return a np.ndarray; cast to a tensor
            # on the same device before accumulating the running average.
            if not isinstance(a_batch, torch.Tensor):
                a_batch = torch.as_tensor(a_batch).to(self.device)

            if a_batch_smooth is None:
                a_batch_smooth = a_batch / self.nr_samples
            else:
                a_batch_smooth += a_batch / self.nr_samples

        # Convert back to a np.ndarray, as stated in the signature.
        return a_batch_smooth.detach().cpu().numpy()
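
  • For reference, a rough usage sketch of how the method could be called.
    This is hypothetical wiring: `model`, `x_batch`, and `y_batch` are
    placeholders, the constructor argument assumes SmoothMPRT exposes the
    nr_samples attribute the snippet above relies on, and "Saliency" is just
    one example method for quantus.explain:

    import numpy as np
    import quantus

    metric = quantus.SmoothMPRT(nr_samples=5)  # assumed constructor argument

    a_batch = metric.explain_smooth_batch_torch(
        model=model,          # placeholder torch model
        x_batch=x_batch,      # np.ndarray of inputs
        y_batch=y_batch,      # np.ndarray of labels
        std=0.1,
        method="Saliency",    # forwarded to quantus.explain
    )
    assert isinstance(a_batch, np.ndarray)

    Keeping the noise sampling (torch.randn_like) and the accumulation on the
    device avoids per-sample numpy round-trips, which is where the speed-up
    over the current implementation is expected to come from.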

Minimum acceptance criteria

  • Specify what is necessary for the issue to be closed.
  • @-mention the person who is apt to review these changes, e.g., @annahedstroem.
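
  • One way to make the speed-up criterion concrete (a minimal sketch;
    `explain_smooth_batch` is a hypothetical name for the existing numpy-based
    path, and CUDA timings would additionally need torch.cuda.synchronize()):

    import time

    def timed(fn, *args, **kwargs):
        # Wall-clock timing of a single call.
        start = time.perf_counter()
        fn(*args, **kwargs)
        return time.perf_counter() - start

    t_numpy = timed(metric.explain_smooth_batch, model, x_batch, y_batch,
                    std=0.1, method="Saliency")
    t_torch = timed(metric.explain_smooth_batch_torch, model, x_batch, y_batch,
                    std=0.1, method="Saliency")
    print(f"numpy: {t_numpy:.3f}s, torch: {t_torch:.3f}s")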

annahedstroem · Nov 24 '23, 10:11