
Prediction Interpretability

Open dhkhey opened this issue 1 year ago • 9 comments

Has anyone looked into the interpretability of SetFit models? I'm wondering whether you can do things like visualizing attention to see which words have the most influence in predicting a particular class. Any direction would be appreciated!

dhkhey avatar Mar 21 '23 16:03 dhkhey

Hello.

I have implemented the integrated gradients method on top of SetFit, and it works as one would expect. I find it useful for seeing which words the model latches onto per class, and it has given me some nice insights.

For example, I was once predicting classes from concatenated text with "," as the separator. The model learned to associate lots of commas with one class, which was the correct behaviour.
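
If it helps, the method itself is short. Here is a generic, minimal sketch (an illustration, not my exact code; `f` is any differentiable function mapping a batch of inputs to per-example scores, e.g. a class probability):

```python
import torch

def integrated_gradients(f, x, baseline, steps=50):
    # Approximate IG_i = (x_i - x'_i) * integral_0^1 dF/dx_i(x' + a*(x - x')) da
    # with a Riemann sum over `steps` points on the straight-line path.
    alphas = torch.linspace(0.0, 1.0, steps).view(-1, *([1] * x.dim()))
    path = (baseline + alphas * (x - baseline)).detach().requires_grad_(True)
    f(path).sum().backward()          # f must accept the batched path
    avg_grad = path.grad.mean(dim=0)  # average gradient along the path
    return (x - baseline) * avg_grad
```

For text, `x` would be the token embeddings of the input and the baseline something neutral (e.g. pad-token embeddings); summing the result over the hidden dimension gives one score per token.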

kgourgou avatar Mar 22 '23 14:03 kgourgou

@kgourgou Hey, that is exactly the behavior I want. Did you use a library for it? I'm trying to use Captum, but I'm having trouble applying it to the SetFit model.

dhkhey avatar Mar 24 '23 01:03 dhkhey

Unfortunately I didn't; I hacked together some code to do it. The hardest part was decomposing the model into parts so that I could pass the gradients through but stop just before the positional embeddings are applied (at the time only an sklearn head was allowed, so I had to implement a torch head as well). It would probably be easier and cleaner to do now that SetFit implements a torch head.
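
For the conversion step, the idea is roughly this (a sketch, not the repo's exact code; binary case shown, and the function name is made up):

```python
import numpy as np
import torch

def binary_torch_head(coef: np.ndarray, intercept: np.ndarray) -> torch.nn.Module:
    # Copy a fitted sklearn LogisticRegression (binary case: clf.coef_ has
    # shape (1, n_features)) into an equivalent torch head, so that gradients
    # can flow through it.
    linear = torch.nn.Linear(coef.shape[1], 1)
    with torch.no_grad():
        linear.weight.copy_(torch.from_numpy(coef).float())
        linear.bias.copy_(torch.from_numpy(intercept).float())
    return torch.nn.Sequential(linear, torch.nn.Sigmoid())  # outputs P(class 1)
```

You would call it as `binary_torch_head(clf.coef_, clf.intercept_)` for a fitted `LogisticRegression` `clf`.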

I can upload the code to a public repo if that would be useful to you.

kgourgou avatar Mar 24 '23 10:03 kgourgou

@kgourgou Hi. Yes, if you don't mind, seeing your code would be incredibly helpful. I'm new to model interpretability, so I'm still trying to learn. I have it working on a simple BERT model, but I'm a bit lost on passing the embeddings of the SetFit model. I appreciate it so much!

dhkhey avatar Mar 24 '23 15:03 dhkhey

I'll share here once I have news on this.

kgourgou avatar Mar 26 '23 19:03 kgourgou

I would love to have something similar; otherwise I'll work on it and share my attempt here.

AymericBasset avatar Mar 30 '23 17:03 AymericBasset

Hopefully this will give you some ideas: https://github.com/kgourgou/setfit-integrated-gradients

See demo.ipynb

cc @dhkhey

kgourgou avatar Mar 31 '23 18:03 kgourgou

@kgourgou Thank you so much! I'm sorry, I have one more question.
When I try running your code for multiclass classification (3 different labels), I get the following error:

```
     19         self.linear = torch.nn.Linear(coef.shape[0], 1)
     20         with torch.no_grad():
---> 21             self.linear.weight.copy_(torch.from_numpy(coef))
     22             self.linear.bias.copy_(torch.from_numpy(intercept))
     23

RuntimeError: output with shape [1] doesn't match the broadcast shape [3]
```

It works great for binary classification, though. Any idea where the issue might be coming from?

dhkhey avatar Apr 06 '23 17:04 dhkhey

My pleasure!

Yes, there are two issues.

First, the current implementation is a hack: it just copies the learned sklearn head over to an equivalent PyTorch head. If your head is already written in PyTorch, you don't need this conversion; you can just pass gradients through the head as normal. You do still need to be able to pass an arbitrary embedding vector (instead of text) and get a class probability, which is what I tried to do in model_pass() (I think that's in setfit_extensions.py?).
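
Schematically, something like this (a hypothetical signature, not the repo's exact code; `body` is the underlying HF transformer):

```python
import torch

def model_pass(token_embeddings: torch.Tensor,
               attention_mask: torch.Tensor,
               body: torch.nn.Module,
               head: torch.nn.Module) -> torch.Tensor:
    # Token embeddings in (instead of text), class probability out. Passing
    # `inputs_embeds` lets the transformer add positional embeddings itself,
    # so gradients flow from the output all the way back to the embeddings.
    out = body(inputs_embeds=token_embeddings, attention_mask=attention_mask)
    mask = attention_mask.unsqueeze(-1).float()
    pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
    return head(pooled)  # e.g. the converted head from earlier in the thread
```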

The second issue is that the code is implemented for binary classification. Specifically, I only compute attributions with respect to class 1. For multiclass, you probably want separate attributions for each possible class. For example, if you are modelling this with a softmax, you would probably want attributions for each class probability (an alternative is to compute attributions only for the top class), as in the sketch below. This shouldn't be a hard edit to make.
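
Roughly, the multiclass edit looks like this (hypothetical names; `integrated_gradients` is the generic sketch from earlier in the thread, and `embedding` is a stand-in for whatever you attribute over):

```python
import torch

n_classes, n_features = 3, 768   # e.g. the three labels from the error above

# Build the Linear as (n_features -> n_classes), i.e. Linear(coef.shape[1],
# coef.shape[0]) when copying sklearn weights -- that is what avoids the
# "shape [1] doesn't match [3]" broadcast error.
head = torch.nn.Linear(n_features, n_classes)

embedding = torch.randn(n_features)   # stand-in for a real embedding
baseline = torch.zeros_like(embedding)

# One attribution map per class: attribute each softmax probability separately
# (or only the top class, if that's all you need).
per_class = {
    c: integrated_gradients(lambda e: torch.softmax(head(e), dim=-1)[..., c],
                            embedding, baseline)
    for c in range(n_classes)
}
```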

Hopefully this is somewhat helpful.

kgourgou avatar Apr 06 '23 20:04 kgourgou