transformers-interpret

Multiclass True Labels: Bug Fix

Open · elemets opened this issue 8 months ago · 0 comments

When visualising the explainability results of a multiclass model, I can't work out whether there is a way to display the true labels correctly. When I set the `true_class` argument of the explainer's `.visualize()` method, every row's True Label is set to that single value instead of each row getting its own label.

[Screenshot 2024-06-11 at 10:03:33 AM]

I assume this behaviour is because the method was designed for binary classification. Inside `.visualize()` it looks easy to support the multiclass case: `true_class` just needs to be indexed with the row index `i`. I edited the method to do this:

    def visualize(self, html_filepath: str = None, true_class: list = None):
        """
        Visualizes word attributions. If running in a notebook, the table is displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as an HTML file.

        If the true labels are known, pass them to `true_class` as a list
        with one entry per row (i.e. per label).
        """
        # strip the byte-level BPE space marker "Ġ" from decoded tokens
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        score_viz = [
            self.attributions[i].visualize_attributions(  # type: ignore
                self.pred_probs_list[i],
                "",  # including a predicted class name does not make sense for this explainer
                (
                    # patched: take the i-th true label instead of repeating one value for every row
                    "n/a" if not true_class else true_class[i]
                ),  # no true class name for this explainer by default
                self.labels[i],
                tokens,
            )
            for i in range(len(self.attributions))
        ]
        # ... rest of the method omitted

These are the results I now get:

[Screenshot 2024-06-11 at 10:17:24 AM]

I hope this is useful; if the solution looks fine, we can integrate it. Note that when passing a NumPy array to `true_class`, it needs to be wrapped in `list()`; otherwise `not true_class` raises a `ValueError` because the truth value of an array with more than one element is ambiguous.
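The `list()` workaround is needed because `not true_class` asks for the truthiness of the whole argument, which NumPy refuses for arrays with more than one element. Indexing with `true_class[i]` also changes what a plain string does: `"joy"[i]` yields single characters rather than the whole label. A minimal sketch of one way to handle both cases, assuming a hypothetical helper `normalize_true_labels` (not part of transformers-interpret) that builds the per-row list once, before the list comprehension:

```python
from typing import Any, List, Optional, Sequence, Union


def normalize_true_labels(
    true_class: Optional[Union[str, Sequence[Any]]], n_rows: int
) -> List[Any]:
    """Return exactly one true-label entry per visualisation row.

    Hypothetical helper, not part of transformers-interpret:
    - None       -> "n/a" for every row (the explainer's default)
    - a str      -> that label repeated for every row (the original behaviour);
                    an empty string falls back to "n/a", as `not true_class` did
    - a sequence -> one entry per row; any iterable works, including a 1-D
                    NumPy array, because it is converted with list() and is
                    never tested for truthiness
    """
    if true_class is None:
        return ["n/a"] * n_rows
    if isinstance(true_class, str):
        return [true_class or "n/a"] * n_rows
    labels = list(true_class)
    if len(labels) != n_rows:
        raise ValueError(
            f"true_class has {len(labels)} entries but there are {n_rows} rows to label"
        )
    return labels
```

Inside `visualize()`, this would replace the inline conditional with something like `true_labels = normalize_true_labels(true_class, len(self.attributions))`, passing `true_labels[i]` to `visualize_attributions`. That wiring is likewise an untested assumption; the length check is there so a mismatched label list fails loudly instead of silently mislabelling rows.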

elemets · Jun 11 '24, 17:06