transformers-interpret
MultiClass True Labels: Bug Fix
When visualising the explainability results of a multiclass model, I can't find a way to display the true labels correctly. Setting the `true_class` argument of the explainer's `.visualize()` method sets every row's True Label to that single value instead of setting each row individually.
I assume this is because the method was designed for binary classification. Inside `.visualize()` this is easy to support: in the multiclass case, `true_class` needs to be indexed by the loop variable `i`. I edited the method to make it work:
```python
from typing import List, Optional


def visualize(self, html_filepath: Optional[str] = None, true_class: Optional[List[str]] = None):
    """
    Visualizes word attributions. If in a notebook, the table will be displayed inline.
    Otherwise pass a valid path to `html_filepath` and the visualization will be saved
    as an HTML file.

    If the true labels are known, pass one per class (in `self.labels` order) to `true_class`.
    """
    # Strip the byte-level BPE space marker "Ġ" from decoded tokens
    tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
    score_viz = [
        self.attributions[i].visualize_attributions(  # type: ignore
            self.pred_probs_list[i],
            "",  # including a predicted class name does not make sense for this explainer
            (
                "n/a" if not true_class else true_class[i]
            ),  # true label for class i, or "n/a" when none is given
            self.labels[i],
            tokens,
        )
        for i in range(len(self.attributions))
    ]
    # ... rest of the method as before
```
These are the results I now get:
I hope this is useful; if the solution looks fine, we can integrate it. One caveat: when passing a NumPy array to `true_class`, it needs to be wrapped in `list()`, otherwise the `not true_class` check raises a `ValueError` saying the truth value of an array with more than one element is ambiguous.
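A minimal sketch of how the `list()` workaround could be avoided, assuming the `not true_class` check is replaced by an explicit `is None` test. `per_row_true_labels` is a hypothetical helper, not part of transformers-interpret. It returns one label string per visualized row, accepts either a list or a NumPy array, and rejects a bare string, which the patched `true_class[i]` would otherwise index character by character.

```python
from typing import List, Optional, Sequence, Union

import numpy as np


def per_row_true_labels(
    true_class: Optional[Union[Sequence[str], np.ndarray]], n_rows: int
) -> List[str]:
    """Hypothetical helper (not part of transformers-interpret).

    Returns one true-label string per visualized row. Uses `is None`
    rather than `not true_class`, so NumPy arrays are accepted as-is.
    """
    if true_class is None:
        return ["n/a"] * n_rows  # same default the patched visualize() uses
    if isinstance(true_class, str):
        # A bare string would otherwise be indexed one character per row.
        raise TypeError("true_class must be a sequence with one entry per label")
    if len(true_class) != n_rows:
        raise ValueError(f"expected {n_rows} true labels, got {len(true_class)}")
    return [str(label) for label in true_class]


labels = ["anger", "joy", "sadness"]  # hypothetical label names
arr = np.array(["0", "1", "0"])       # hypothetical per-label ground truth

try:
    "n/a" if not arr else arr[0]  # the check used in the patch
except ValueError as err:
    print("`not true_class` fails on arrays:", err)

print(per_row_true_labels(arr, len(labels)))   # ['0', '1', '0']
print(per_row_true_labels(None, len(labels)))  # ['n/a', 'n/a', 'n/a']
```

Inside the patched loop, one could compute `true_labels = per_row_true_labels(true_class, len(self.attributions))` once and pass `true_labels[i]`. The length check assumes the labels are supplied in `self.labels` order, which is an assumption rather than something the library enforces.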