
DeepLift / DeepLiftShap not converging when `torch.exp` present or multiplications

Open jmschrei opened this issue 2 years ago • 13 comments

🐛 Bug

When torch.exp is present in the model in any form, including softmax, logsoftmax, and logsumexp operations, the deltas seem to get pretty big for my model. I've checked to make sure there isn't an overflow -- most values I'm exp-ing are around -7 to 7.

Separately, I'm finding that element-wise multiplications are producing the same issue.

Neither issue causes a problem for TFDeepExplainer in the SHAP repo, so it should be possible to handle both operations correctly in principle.

To Reproduce

Steps to reproduce the behavior:

The models below produce high deltas. They are wrappers around a simple convolutional model that produces two outputs; each wrapper extracts one of those outputs and applies simple transformations to reduce it to a single value per example.

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model

    def forward(self, X):
        logits = self.model(X)[0][:, 0]
        logits = logits - logits.mean(dim=-1, keepdims=True)        
        y = torch.exp(logits)
        return torch.sum(y, dim=-1)
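
For context, here is roughly how I am measuring the deltas (a minimal sketch; `model`, the input batch `X`, and the `baselines` tensor are stand-ins for my actual setup, and in practice I use DeepLiftShap with shuffled references rather than DeepLift with an arbitrary baseline):

from captum.attr import DeepLift

wrapper = Wrapper(model)  # `model` is the underlying convolutional model (stand-in)
dl = DeepLift(wrapper)

# X: a batch of inputs, baselines: a matching batch of references (both stand-ins)
attributions, deltas = dl.attribute(
    X, baselines=baselines, return_convergence_delta=True
)
print(deltas.abs().max())  # unexpectedly large once torch.exp is involved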

Here's the multiplication model.

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.elu = torch.nn.ELU()

    def forward(self, X):
        logits = self.model(X)[0]
        logits = logits - logits.mean(dim=-1, keepdims=True)        
        return torch.sum(logits * logits, dim=-1)

Ultimately, I'm trying to get the following working:

class ProfileWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ProfileWrapper, self).__init__()
        self.model = model

    def forward(self, X):
        logits = self.model(X)[0]
        logits = logits - logits.mean(dim=-1, keepdims=True)
        y = torch.nn.functional.softmax(logits, dim=-1)
        return (y * logits).sum(axis=-1)

Any help getting that working, even if it requires some sort of hack, would be greatly appreciated.
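
One hack I have been considering (sketched below; I have not verified that it actually lowers the deltas, which is part of what I am asking) is to route the softmax through the `torch.nn.Softmax` module instead of the functional call, since my understanding is that Captum's DeepLift attaches its rules via module hooks. Even then, the element-wise `y * logits` product presumably remains unhandled:

class ProfileWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ProfileWrapper, self).__init__()
        self.model = model
        # nn.Softmax is a module, so DeepLift's hooks can see it,
        # unlike torch.nn.functional.softmax
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X):
        logits = self.model(X)[0]
        logits = logits - logits.mean(dim=-1, keepdims=True)
        y = self.softmax(logits)
        return (y * logits).sum(dim=-1)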

Expected behavior

I would expect to get low convergence deltas even when an exp operation is present.

Environment

Describe the environment used for Captum


 - Captum / PyTorch Version (e.g., 1.0 / 0.4.0): captum=0.5.0, torch=1.11.0
 - OS (e.g., Linux): Linux
 - How you installed Captum / PyTorch (`conda`, `pip`, source): pip
 - Build command you used (if compiling from source):
 - Python version: 3.8.3
 - CUDA/cuDNN version: not relevant
 - GPU models and configuration: not relevant
 - Any other relevant information: not relevant

Additional context

jmschrei avatar Dec 07 '22 08:12 jmschrei

Thank you for the questions, @jmschrei! I think you are seeing this issue because we don't have these operations in the list of supported operators, whereas TFDeepLift does. It should be pretty easy to add them. In fact, we should probably let users dynamically register those modules.

NarineK avatar Dec 23 '22 05:12 NarineK

@NarineK, if I added `exp: nonlinear` as an entry in that dictionary, would that fix it? Likewise, if I wanted `torch.sigmoid` to work, would I need to add that as well, or does adding `exp: nonlinear` cover it too?
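
For concreteness, here is the kind of entry I have in mind (a sketch only: the names `SUPPORTED_NON_LINEAR` and `nonlinear` are based on my reading of captum/attr/_core/deep_lift.py, and since that dictionary appears to be keyed by nn.Module classes, torch.exp would presumably first need to be wrapped in a module; I also have not checked whether the rescale rule is actually correct for exp):

import torch

from captum.attr._core.deep_lift import SUPPORTED_NON_LINEAR, nonlinear

class Exp(torch.nn.Module):
    # Module wrapper so that DeepLift's module hooks can see the exp operation.
    def forward(self, X):
        return torch.exp(X)

# Hypothetical registration: treat Exp like the other element-wise nonlinearities.
SUPPORTED_NON_LINEAR[Exp] = nonlinear

The model would then have to call an `Exp()` instance instead of `torch.exp` directly so that a hook can actually be attached to it.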

Thanks

jmschrei avatar Jan 26 '23 04:01 jmschrei

In a different context I am noticing, even in captum 0.5.0 (where some of my other issues have been resolved), that multiplications seem to be handled improperly.

This works, in the sense of having low convergence deltas:

class GeLU(torch.nn.Module):
	def __init__(self):
		super().__init__()
		self.activation = torch.nn.Sigmoid()

	def forward(self, X):
		return self.activation(1.702 * X)

This does not work:

class GeLU(torch.nn.Module):
	def __init__(self):
		super().__init__()
		self.activation = torch.nn.Sigmoid()

	def forward(self, X):
		return self.activation(1.702 * X) * X
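
For reference, this is roughly how I am comparing the two variants (a minimal sketch with a toy network; the layer sizes, random inputs, and zero baselines are placeholders):

import torch
from captum.attr import DeepLift

# Toy network that uses the GeLU variant under test.
net = torch.nn.Sequential(torch.nn.Linear(10, 5), GeLU(), torch.nn.Linear(5, 1))

X = torch.randn(8, 10)
baselines = torch.zeros(8, 10)

attributions, delta = DeepLift(net).attribute(
    X, baselines=baselines, target=0, return_convergence_delta=True
)
print(delta.abs().max())  # small with the first GeLU, large once the extra `* X` is added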

jmschrei avatar Jan 30 '23 19:01 jmschrei

@NarineK sorry to bother you, but any thoughts on the above issue regarding multiplications?

jmschrei avatar Mar 03 '23 17:03 jmschrei

It would be great to fix these issues soon, since there is a significant user base that uses DeepLiftShap from Captum for models in genomics. Thanks in advance for your efforts.

akundaje avatar Mar 13 '23 19:03 akundaje

Hello @NarineK. TfModisCo, the best tool in regulatory genomics for identifying transcription factor motifs learned by neural networks, won't function well until the DeepLift issue is fixed. Right now, if we shift from TensorFlow to PyTorch, we have to sacrifice the best model interpretation tool at our disposal.

muntakimrafi avatar Mar 13 '23 20:03 muntakimrafi

Yes it will be of great help to have this fixed! Thanks in advance!!

panushri25 avatar Mar 14 '23 00:03 panushri25

Would be great to have an update on this issue. Thanks!

anupamajha1 avatar Mar 14 '23 19:03 anupamajha1

Hi developers. Is it possible to get an update on this, or an indication that this is on your radar?

jmschrei avatar Apr 21 '23 17:04 jmschrei

Any news on this bug? We also rely on captum for our research, would appreciate any fix.

gokceneraslan avatar May 23 '23 00:05 gokceneraslan

I don't think Captum is being actively developed anymore.

jmschrei avatar May 23 '23 00:05 jmschrei

Are there similar issues in shap.DeepExplainer too?

gokceneraslan avatar May 23 '23 04:05 gokceneraslan

No, there's no problem with the implementation in the original SHAP repo; it's just in Captum.

akundaje avatar May 23 '23 10:05 akundaje