
Support Model-Agnostic Explainers

Open · abarbosa94 opened this issue 1 year ago · 1 comment

Hi there, firstly-- kudos for this amazing project :)

Description of the problem

After reviewing the examples and API usage, I understood that I need to rely on TensorFlow or PyTorch. Moreover, even ModelInterface is coupled to computer-vision models.

For example:

class ModelInterface(ABC, Generic[M]):
    """Base ModelInterface for torch and tensorflow models."""

    def __init__(
        self,
        model: M,
        channel_first: Optional[bool] = True,
        softmax: bool = False,
        model_predict_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialisation of ModelInterface class.

        Parameters
        ----------
        model: torch.nn.Module, tf.keras.Model
            A model that will be wrapped in the ModelInterface.
        channel_first: boolean, optional
             Indicates whether the image dimensions are channel first or channel last. Inferred from the input shape if None.

Also, explain_func is heavily coupled to the existing explainer libraries, making it difficult for a user who wants to rely on a library other than captum, zennit, or tf_explain, such as shap.
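
For illustration, a custom explanation function would have to follow roughly the signature below (a minimal sketch based on the quantus.explain convention; the shap wiring is my own assumption, not Quantus code), yet the model it receives is still assumed to be a torch or tensorflow object:

import numpy as np
import shap

def shap_explain_func(model, inputs, targets, **kwargs) -> np.ndarray:
    # Sketch: mirrors the (model, inputs, targets, **kwargs) signature of
    # quantus.explain; targets is unused here. The catch is that `model`
    # is still the torch/tf object handed over by ModelInterface, so a
    # framework-agnostic explainer like shap needs extra glue to call it.
    explainer = shap.Explainer(model, inputs)  # assumes shap can wrap model
    return explainer(inputs).values  # attributions, same shape as inputs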

Description of a solution

Before finding this project, I had worked with Captum. It is flexible enough that, even though the library provides its own attribution methods, the metrics API lets the user define any explanation method they want and integrate it easily: one only needs to return a Tensor of attributions.

Take this as an example:

from captum.metrics import sensitivity_max
import numpy as np
import shap
import torch

# X_test, X_test_tensor, and model (an sklearn-style classifier) are
# assumed to be defined elsewhere.
masker = shap.maskers.Independent(X_test)

def model_log_odds(x):
    # Log-odds of the positive class; the epsilon avoids log(0).
    log_prob = np.log(model.predict_proba(x) + 1e-20)
    return log_prob[:, 1] - log_prob[:, 0]

exact_shap = shap.explainers.Exact(model_log_odds, masker)

def shap_exact_function(inputs):
    # Explain the batch with shap and return the attributions as a torch
    # Tensor, which is all that captum's metrics require.
    shap_values = exact_shap(inputs.cpu().numpy())
    return torch.tensor(shap_values.values)

sensitivity_score = sensitivity_max(
    explanation_func=shap_exact_function,
    inputs=X_test_tensor,
    perturb_func=my_custom_perturbation_fn_generator,  # adjust as needed
    n_perturb_samples=100,
)

As I'm new to the project, I'm unsure whether Quantus supports such an approach. I looked into the docs but didn't find it.

The main advantage of such a design is that the library automatically becomes more agnostic, making it easy to experiment with data other than images, such as text or tabular data. Moreover, it also becomes possible to use models built with frameworks other than TensorFlow or PyTorch; see the sketch below.
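
As a sketch of what I mean (the PredictFnModel wrapper below is hypothetical, not an existing Quantus class), a model-agnostic interface could reduce to a plain predict callable, which would immediately cover sklearn-style models on tabular data:

import numpy as np
from typing import Callable

class PredictFnModel:
    # Hypothetical wrapper: anything exposing predict(x) -> np.ndarray
    # qualifies; no assumptions about torch/tf or channel-first images.

    def __init__(self, predict_fn: Callable[[np.ndarray], np.ndarray]):
        self.predict_fn = predict_fn

    def predict(self, x: np.ndarray) -> np.ndarray:
        return self.predict_fn(x)

# e.g. wrapped = PredictFnModel(sklearn_model.predict_proba)

Everything framework-specific (tensors, devices, channel order) would then live behind the callable instead of inside the metric.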

Please let me know what you think. Thanks!

abarbosa94 · Dec 14 '23 19:12

Hi @abarbosa94, I agree with both statements. These are known design flaws, and we've discussed them multiple times. Unfortunately, we've never managed to come to a conclusion.

aaarrti · Feb 19 '24 10:02