DeepLift / DeepLiftShap not converging when `torch.exp` or element-wise multiplications are present
🐛 Bug
When `torch.exp` is present in the model in any form, including softmax, logsoftmax, and logsumexp operations, the convergence deltas get quite large for my model. I've checked that there isn't an overflow -- most values I'm exponentiating are around -7 to 7.
Separately, I'm finding that element-wise multiplications produce the same issue.
Neither operation causes a problem for the TFDeepExplainer in the SHAP repo, so I don't think there is a theoretical obstacle to supporting either.
To Reproduce
Steps to reproduce the behavior:
The models below produce high deltas. They wrap a simple convolutional model that produces two outputs; the wrappers extract one of those outputs and apply simple transformations to reduce it to a single value.
```python
class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model

    def forward(self, X):
        logits = self.model(X)[0][:, 0]
        logits = logits - logits.mean(dim=-1, keepdims=True)
        y = torch.exp(logits)
        return torch.sum(y, dim=-1)
```
Here's the multiplication model.
```python
class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.elu = torch.nn.ELU()

    def forward(self, X):
        logits = self.model(X)[0]
        logits = logits - logits.mean(dim=-1, keepdims=True)
        return torch.sum(logits * logits, dim=-1)
```
Ultimately, I'm trying to get the following working:
```python
class ProfileWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ProfileWrapper, self).__init__()
        self.model = model

    def forward(self, X):
        logits = self.model(X)[0]
        logits = logits - logits.mean(dim=-1, keepdims=True)
        y = torch.nn.functional.softmax(logits, dim=-1)
        return (y * logits).sum(axis=-1)
```
Any help getting that working, even if it requires some sort of hack, would be greatly appreciated.
Expected behavior
I would expect to get low deltas for an exp operation.
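For reference, the deltas above come from a call along these lines. This is a minimal self-contained sketch: `ToyModel`, the tensor shapes, and the baselines are hypothetical stand-ins for my actual two-output convolutional model and data, not the real setup.

```python
import torch
from captum.attr import DeepLift

# Hypothetical stand-in for the real two-output convolutional model.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(4, 8, kernel_size=3, padding=1)

    def forward(self, X):
        y = self.conv(X)
        return y, y.sum(dim=(1, 2))

wrapped = Wrapper(ToyModel())  # Wrapper from the first snippet above
X = torch.randn(16, 4, 100)
baselines = torch.zeros(16, 4, 100)

dl = DeepLift(wrapped)
attributions, deltas = dl.attribute(
    X, baselines=baselines, return_convergence_delta=True
)
# Large absolute deltas here are the non-convergence being reported.
print(deltas.abs().max())
```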
Environment
Describe the environment used for Captum
- Captum / PyTorch Version (e.g., 1.0 / 0.4.0): captum=0.5.0, torch=1.11.0
- OS (e.g., Linux): Linux
- How you installed Captum / PyTorch (`conda`, `pip`, source): pip
- Build command you used (if compiling from source):
- Python version: 3.8.3
- CUDA/cuDNN version: not relevant
- GPU models and configuration: not relevant
- Any other relevant information: not relevant
## Additional context
Thank you for the questions, @jmschrei! I think you are seeing this issue because we don't have these operations in the list of supported operators, whereas TFDeepLift does. It should be pretty easy to add them. In fact, we should probably let users dynamically add such modules.
@NarineK if I added `exp: nonlinear` as an entry in that dictionary, would that fix it? Likewise, if I wanted `torch.sigmoid` to work, do I need to add that as well, or does adding `exp: nonlinear` cover that too?
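Concretely, is something like the following what you have in mind? This is an untested sketch; I'm assuming the dictionary in question is `SUPPORTED_NON_LINEAR` in `captum/attr/_core/deep_lift.py` and that the generic `nonlinear` rule would be the right multiplier for `exp`.

```python
import torch
from captum.attr._core.deep_lift import SUPPORTED_NON_LINEAR, nonlinear

# Wrap torch.exp in a module so DeepLift's module hooks can see it.
class Exp(torch.nn.Module):
    def forward(self, X):
        return torch.exp(X)

# Register the new module with the generic "rescale" nonlinearity rule.
SUPPORTED_NON_LINEAR[Exp] = nonlinear

# The Wrapper above would then need to call an Exp() instance instead of
# torch.exp directly, since only module calls get hooked.
```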
Thanks
In a different context I am noticing, even in captum 0.5.0 (where several of my other issues have been resolved), that multiplications seem to be handled improperly.
This works, in the sense of having low convergence deltas:
```python
class GeLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation = torch.nn.Sigmoid()

    def forward(self, X):
        return self.activation(1.702 * X)
```
This does not work:
```python
class GeLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation = torch.nn.Sigmoid()

    def forward(self, X):
        return self.activation(1.702 * X) * X
```
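For reference, here is roughly how I'm comparing the two variants. This is a minimal sketch with a toy network; `Net`, the layer sizes, and the baselines are placeholders, not my actual model.

```python
import torch
from captum.attr import DeepLiftShap

# Toy network using whichever GeLU variant is defined above.
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)
        self.act = GeLU()

    def forward(self, X):
        return self.act(self.linear(X)).sum(dim=-1)

net = Net()
X = torch.randn(32, 16)
baselines = torch.zeros(20, 16)  # DeepLiftShap expects a distribution of baselines

dls = DeepLiftShap(net)
attributions, deltas = dls.attribute(
    X, baselines=baselines, return_convergence_delta=True
)
# Per the report above, the first GeLU variant keeps the deltas near zero,
# while the variant with the extra `* X` multiplication does not.
print(deltas.abs().max())
```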
@NarineK sorry to bother you, but any thoughts on the above issue regarding multiplications?
It would be great to fix these issues soon, since there is a significant user base that relies on DeepLiftShap from Captum for models in genomics. Thanks in advance for your efforts.
Hello @NarineK. TfModisCo, the best tool in regulatory genomics for identifying transcription factor motifs learned by neural networks, won't function well until this DeepLift issue is fixed. Right now, if we shift from TensorFlow to PyTorch, we have to sacrifice the best model-interpretation tool at our disposal.
Yes it will be of great help to have this fixed! Thanks in advance!!
Would be great to have an update on this issue. Thanks!
Hi developers. Is it possible to get an update on this, or an indication that this is on your radar?
Any news on this bug? We also rely on captum for our research, would appreciate any fix.
I don't think Captum is being actively developed anymore.
Are there similar issues in shap.DeepExplainer too?
No there's no problem with the implementation in the original SHAP repo. It's just in Captum.
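For anyone who wants to cross-check, the comparison I have in mind is roughly the following. The model and data here are toy placeholders; shap's `DeepExplainer` dispatches to its PyTorch deep explainer for a `torch.nn.Module`.

```python
import shap
import torch

# Toy stand-ins for the actual model and data.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
)
background = torch.zeros(20, 16)  # reference examples
X = torch.randn(5, 16)            # examples to explain

explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(X)
```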