ReLU activations causing divergence for DeepLiftShap

Open jmschrei opened this issue 2 years ago • 6 comments

🐛 Bug

I am observing a large divergence when using DeepLiftShap on a model with ReLU activations (or any type of activation) but not when using torch.nn.Identity instead. This is pretty puzzling to me because I don't think I'm doing anything non-standard, but the attribution results are wacky.

To Reproduce

Steps to reproduce the behavior:

import torch
from captum.attr import DeepLiftShap

activation = torch.nn.Identity

class PermuteFlatten(torch.nn.Module):
	def __init__(self):
		super().__init__()

	def forward(self, X):
		return X.permute(0, 2, 1).reshape(X.shape[0], -1)


class TestModule(torch.nn.Module):
	def __init__(self, n_outputs, k):
		super().__init__()

		self.activation = activation()

		self.conv = torch.nn.Conv1d(4, k, kernel_size=17, padding='same', bias=False)
		self.bn = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp = torch.nn.MaxPool1d(3)
		self.activation1 = activation()

		self.conv1 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn1 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp1 = torch.nn.MaxPool1d(2)
		self.activation2 = activation()

		self.conv2 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn2 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp2 = torch.nn.MaxPool1d(2)
		self.activation3 = activation()

		self.conv3 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn3 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp3 = torch.nn.MaxPool1d(2)
		self.activation4 = activation()

		self.conv4 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn4 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp4 = torch.nn.MaxPool1d(2)
		self.activation5 = activation()

		self.conv5 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn5 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp5 = torch.nn.MaxPool1d(2)
		self.activation6 = activation()

		self.conv6 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn6 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp6 = torch.nn.MaxPool1d(2)
		self.activation7 = activation()

		self.conv7 = torch.nn.Conv1d(k, k, kernel_size=1, bias=False)
		self.bn7 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.activation8 = activation()

		self.pf = PermuteFlatten() 

		self.fc = torch.nn.Linear(k*(1344 // (3 * 2 ** 6)), 32, bias=False)
		self.bn8 = torch.nn.BatchNorm1d(32, eps=0.001)
		self.activation9 = activation()
		self.fc1 = torch.nn.Linear(32, n_outputs)
																								  
	def forward(self, X):
		X = self.activation(X)
		X = self.activation1(self.mp(self.bn(self.conv(X))))
		X = self.activation2(self.mp1(self.bn1(self.conv1(X))))
		X = self.activation3(self.mp2(self.bn2(self.conv2(X))))
		X = self.activation4(self.mp3(self.bn3(self.conv3(X))))
		X = self.activation5(self.mp4(self.bn4(self.conv4(X))))
		X = self.activation6(self.mp5(self.bn5(self.conv5(X))))
		X = self.activation7(self.mp6(self.bn6(self.conv6(X))))		
		X = self.activation8(self.bn7(self.conv7(X)))

		X = self.pf(X)

		X = self.activation9(self.bn8(self.fc(X)))
		X = self.fc1(X)
		return X

model = TestModule(20, 288).double().eval()

X = torch.randn(1, 4, 1344).double() * 1000
reference = torch.randn(10, 4, 1344).double() * 1000

dl = DeepLiftShap(model)
X_attr, delta = dl.attribute(X, reference, target=0, return_convergence_delta=True)

y = model(X)
y_ref = model(reference)

print(y[:,0] - y_ref[:,0])
print(X_attr.sum())
print()
print(delta)

This should return:

tensor([-0.6306,  0.3156,  0.1099,  0.3692,  0.2338,  1.4345,  0.1836, -0.1585,
        -1.2692, -1.3378], dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(-0.0749, dtype=torch.float64, grad_fn=<SumBackward0>)

tensor([ 3.3307e-16, -2.4980e-15,  5.9119e-15, -8.3267e-16, -2.3315e-15,
        -7.9936e-15, -1.0769e-14, -2.1372e-15,  1.1546e-14, -7.7716e-15],
       dtype=torch.float64)

If you change the activation to a ReLU:

activation = torch.nn.ReLU

Then you get:

tensor([ 0.0661,  0.0156,  0.0222,  0.0308, -0.0484,  0.0219, -0.0122,  0.0256,
         0.0703, -0.0111], dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0348, dtype=torch.float64, grad_fn=<SumBackward0>)

tensor([ 0.1226, -0.0261, -0.0167, -0.0194, -0.0280,  0.0237, -0.0341, -0.0470,
         0.1896,  0.0032], dtype=torch.float64)

HOWEVER:

If you manually change only two of the activations to ReLUs (self.activation and self.activation8 in the code below), while keeping the rest as Identity, it converges. Using

class TestModule(torch.nn.Module):
	def __init__(self, n_outputs, k):
		super().__init__()

		self.activation = torch.nn.ReLU()

		self.conv = torch.nn.Conv1d(4, k, kernel_size=17, padding='same', bias=False)
		self.bn = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp = torch.nn.MaxPool1d(3)
		self.activation1 = activation()

		self.conv1 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn1 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp1 = torch.nn.MaxPool1d(2)
		self.activation2 = activation()

		self.conv2 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn2 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp2 = torch.nn.MaxPool1d(2)
		self.activation3 = activation()

		self.conv3 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn3 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp3 = torch.nn.MaxPool1d(2)
		self.activation4 = activation()

		self.conv4 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn4 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp4 = torch.nn.MaxPool1d(2)
		self.activation5 = activation()

		self.conv5 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn5 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp5 = torch.nn.MaxPool1d(2)
		self.activation6 = activation()

		self.conv6 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn6 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp6 = torch.nn.MaxPool1d(2)
		self.activation7 = activation()

		self.conv7 = torch.nn.Conv1d(k, k, kernel_size=1, bias=False)
		self.bn7 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.activation8 = torch.nn.ReLU()

		self.pf = PermuteFlatten() 

		self.fc = torch.nn.Linear(k*(1344 // (3 * 2 ** 6)), 32, bias=False)
		self.bn8 = torch.nn.BatchNorm1d(32, eps=0.001)
		self.activation9 = activation()
		self.fc1 = torch.nn.Linear(32, n_outputs)
																								  
	def forward(self, X):
		X = self.activation(X)
		X = self.activation1(self.mp(self.bn(self.conv(X))))
		X = self.activation2(self.mp1(self.bn1(self.conv1(X))))
		X = self.activation3(self.mp2(self.bn2(self.conv2(X))))
		X = self.activation4(self.mp3(self.bn3(self.conv3(X))))
		X = self.activation5(self.mp4(self.bn4(self.conv4(X))))
		X = self.activation6(self.mp5(self.bn5(self.conv5(X))))
		X = self.activation7(self.mp6(self.bn6(self.conv6(X))))		
		X = self.activation8(self.bn7(self.conv7(X)))

		X = self.pf(X)

		X = self.activation9(self.bn8(self.fc(X)))
		X = self.fc1(X)
		return X

returns

tensor([-1.2153, -0.2924, -0.5858, -0.2723, -0.0496, -0.7380, -0.7172, -1.2634,
        -0.5631, -0.9969], dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(-0.6694, dtype=torch.float64, grad_fn=<SumBackward0>)

tensor([-1.9984e-15,  5.6621e-15,  4.4409e-16,  6.6613e-16,  2.6437e-15,
         5.2180e-15, -4.4409e-16,  0.0000e+00,  3.8858e-15,  3.3307e-16],
       dtype=torch.float64)

Expected behavior

I would expect the convergence delta to be within machine precision of 0 regardless of the number of ReLU layers.
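As a hand check of what "within machine precision of 0" means for the numbers printed above, here is a minimal sketch that compares the attribution total against the output difference averaged over the 10 references. It only does arithmetic on the printed (rounded) values. The helper names mean and completeness_gap are made up for illustration, and the averaging over references is an assumption read off these outputs, not Captum's internal code.

```python
# Values copied from the printed outputs above (rounded to 4 decimals there).
identity_diffs = [-0.6306, 0.3156, 0.1099, 0.3692, 0.2338,
                  1.4345, 0.1836, -0.1585, -1.2692, -1.3378]
identity_attr_sum = -0.0749

relu_diffs = [0.0661, 0.0156, 0.0222, 0.0308, -0.0484,
              0.0219, -0.0122, 0.0256, 0.0703, -0.0111]
relu_attr_sum = 0.0348
relu_deltas = [0.1226, -0.0261, -0.0167, -0.0194, -0.0280,
               0.0237, -0.0341, -0.0470, 0.1896, 0.0032]


def mean(values):
    """Plain arithmetic mean (hypothetical helper)."""
    return sum(values) / len(values)


def completeness_gap(attr_sum, diffs):
    """Attribution total minus the reference-averaged output difference."""
    return attr_sum - mean(diffs)


identity_gap = completeness_gap(identity_attr_sum, identity_diffs)
relu_gap = completeness_gap(relu_attr_sum, relu_diffs)

print(f"Identity gap: {identity_gap:.4f}")
print(f"ReLU gap: {relu_gap:.4f}")
print(f"Mean of reported ReLU deltas: {mean(relu_deltas):.4f}")
```

For the Identity run the gap is zero up to the rounding of the printed values, while for the ReLU run the gap is around 0.017 and roughly matches the mean of the reported deltas.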

Environment

Describe the environment used for Captum

  • Captum / PyTorch Version (e.g., 1.0 / 0.4.0): 0.6.0 / 1.13.1, respectively
  • OS (e.g., Linux): Linux
  • How you installed Captum / PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.9.13

Additional context

jmschrei avatar Jan 26 '23 07:01 jmschrei

@jmschrei, sorry that you are seeing this issue. I think this is similar to the old issues you were having. The problem is the reuse of the ReLU: we rewrite the gradients for ReLUs, but we do not rewrite the gradients for Identity. https://github.com/pytorch/captum/blob/master/captum/attr/_core/deep_lift.py#L1045

In PyTorch, when we hook a module, we don't know where in the execution graph it is called. If we reuse a module such as ReLU and hook it, the grad input and output get mixed up. @vivekmig and I discussed fixing this by counting the calls to the same module. Newer versions of PyTorch might have other solutions for this that we need to look into.

If you change it to the code below, where the ReLUs are not reused, you'll see small deltas:

import torch
from captum.attr import DeepLiftShap

#activation = torch.nn.Identity

class PermuteFlatten(torch.nn.Module):
	def __init__(self):
		super().__init__()

	def forward(self, X):
		return X.permute(0, 2, 1).reshape(X.shape[0], -1)


class TestModule(torch.nn.Module):
	def __init__(self, n_outputs, k):
		super().__init__()

		self.activation = torch.nn.ReLU()

		self.conv = torch.nn.Conv1d(4, k, kernel_size=17, padding='same', bias=False)
		self.bn = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp = torch.nn.MaxPool1d(3)
		self.activation1 = torch.nn.ReLU()

		self.conv1 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn1 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp1 = torch.nn.MaxPool1d(2)
		self.activation2 = torch.nn.ReLU()

		self.conv2 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn2 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp2 = torch.nn.MaxPool1d(2)
		self.activation3 = torch.nn.ReLU()

		self.conv3 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn3 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp3 = torch.nn.MaxPool1d(2)
		self.activation4 = torch.nn.ReLU()

		self.conv4 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn4 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp4 = torch.nn.MaxPool1d(2)
		self.activation5 = torch.nn.ReLU()

		self.conv5 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn5 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp5 = torch.nn.MaxPool1d(2)
		self.activation6 = torch.nn.ReLU()

		self.conv6 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
		self.bn6 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.mp6 = torch.nn.MaxPool1d(2)
		self.activation7 = torch.nn.ReLU()

		self.conv7 = torch.nn.Conv1d(k, k, kernel_size=1, bias=False)
		self.bn7 = torch.nn.BatchNorm1d(k, eps=0.001)
		self.activation8 = torch.nn.ReLU()

		self.pf = PermuteFlatten() 

		self.fc = torch.nn.Linear(k*(1344 // (3 * 2 ** 6)), 32, bias=False)
		self.bn8 = torch.nn.BatchNorm1d(32, eps=0.001)
		self.activation9 = torch.nn.ReLU()
		self.fc1 = torch.nn.Linear(32, n_outputs)
																								  
	def forward(self, X):
		X = self.activation(X)
		X = self.activation1(self.mp(self.bn(self.conv(X))))
		X = self.activation2(self.mp1(self.bn1(self.conv1(X))))
		X = self.activation3(self.mp2(self.bn2(self.conv2(X))))
		X = self.activation4(self.mp3(self.bn3(self.conv3(X))))
		X = self.activation5(self.mp4(self.bn4(self.conv4(X))))
		X = self.activation6(self.mp5(self.bn5(self.conv5(X))))
		X = self.activation7(self.mp6(self.bn6(self.conv6(X))))		
		X = self.activation8(self.bn7(self.conv7(X)))

		X = self.pf(X)

		X = self.activation9(self.bn8(self.fc(X)))
		X = self.fc1(X)
		return X

model = TestModule(20, 288).double().eval()

X = torch.randn(1, 4, 1344).double() * 1000
reference = torch.randn(10, 4, 1344).double() * 1000

dl = DeepLiftShap(model)
X_attr, delta = dl.attribute(X, reference, target=0, return_convergence_delta=True)

y = model(X)
y_ref = model(reference)

print(y[:,0] - y_ref[:,0])
print(X_attr.sum())
print()
print(delta)
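
The module-reuse situation described in this comment can be sketched with a toy model. This is only an illustration under assumptions: the Reused class and the count_calls hook are hypothetical and are not Captum code. It shows that a hook registered on one module object fires once per call site, and that the hook itself gets no information about which call site triggered it.

```python
import torch


class Reused(torch.nn.Module):
    """Toy model (hypothetical) that calls one ReLU object twice."""

    def __init__(self):
        super().__init__()
        self.act = torch.nn.ReLU()  # a single module object...
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = self.act(self.fc1(x))     # ...called here
        return self.act(self.fc2(x))  # ...and called again here


calls = []


def count_calls(module, inputs, output):
    # The hook only sees the module and its tensors, not the call site.
    calls.append(tuple(output.shape))


model = Reused()
handle = model.act.register_forward_hook(count_calls)
model(torch.randn(3, 4))
handle.remove()

print(len(calls))
print(calls)
```

Note that in the original script each activation() call constructs a separate module instance, so no single ReLU object is shared between layers there.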

NarineK avatar Jan 30 '23 05:01 NarineK

Hi @NarineK , thanks for getting back to me. When I explicitly call torch.nn.ReLU() for each activation I don't see much difference. Is it working on your end?

Your code:

y[:,0] - y_ref[:,0]: tensor([-0.0183, -0.0890, -0.0488, -0.0487, -0.0483,  0.0016, -0.0432, -0.0276,
        -0.1023, -0.0659], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(-0.0518, dtype=torch.float64, grad_fn=<SumBackward0>)

deltas: tensor([-0.0475,  0.0608, -0.1042, -0.0746, -0.0878,  0.1340,  0.1053,  0.0663,
         0.0264, -0.1065], dtype=torch.float64)

My code but with using ReLU instead of Identity:

y[:,0] - y_ref[:,0]: tensor([ 0.0837, -0.0177,  0.0327,  0.0403,  0.0081,  0.1140,  0.0547,  0.0875,
         0.0992,  0.0937], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(0.1714, dtype=torch.float64, grad_fn=<SumBackward0>)

deltas: tensor([ 0.0173, -0.0053,  0.0902,  0.2354,  0.1151,  0.1205,  0.0132,  0.1550,
         0.2431,  0.1332], dtype=torch.float64)

My code using Identity:

y[:,0] - y_ref[:,0]: tensor([-3.4017, -5.3239, -5.5095, -4.8942, -2.2603, -3.9636, -5.7602, -2.6764,
        -4.1172, -5.0349], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(-4.2942, dtype=torch.float64, grad_fn=<SumBackward0>)

deltas: tensor([-5.3291e-15, -5.3291e-15, -2.1316e-14, -6.2172e-15,  2.6645e-15,
         3.1086e-15,  2.6645e-15, -4.4409e-15, -7.9936e-15, -1.0658e-14],
       dtype=torch.float64)

jmschrei avatar Jan 30 '23 17:01 jmschrei

Hi, I also tried running the code on my side, and all of them seem to produce deltas pretty close to 0.

NarineK's:

delta=tensor([-3.6082e-16,  5.8981e-17,  6.3144e-16,  2.1858e-16,  2.1858e-16,
        -2.4460e-16,  1.7347e-17,  2.4286e-17,  8.3267e-17,  4.5797e-16],
       dtype=torch.float64)

jmschrei's: ReLU

delta=tensor([-6.3144e-16,  1.1796e-16,  2.0817e-17,  8.3267e-17, -9.7145e-17,
        -2.9143e-16,  2.4460e-16, -1.5266e-16, -1.1102e-16,  8.3267e-17],
       dtype=torch.float64)

Identity

tensor([ 3.2196e-15,  5.1070e-15,  1.8874e-15,  6.2172e-15,  7.3275e-15,
        -2.2204e-15, -2.2204e-16,  9.5479e-15,  1.2212e-14,  7.5495e-15],
       dtype=torch.float64)

Tested on PyTorch 1.11.0 and Captum 0.5.0.
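
Since the results here differ by installed version, a small sketch for reporting the versions alongside the deltas may help when comparing runs. It uses only the standard library. The installed_versions helper is a made-up name, and it assumes the packages were installed under the distribution names torch and captum.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "captum")):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(installed_versions())
```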

yztxwd avatar Jan 30 '23 18:01 yztxwd

Great observation @yztxwd. If I downgrade to captum==0.5.0, it seems to work.

NarineK's

        -3.2650e-05,  2.4080e-02,  5.3694e-02, -4.3929e-02, -1.2449e-02],
       dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum():  tensor(0.0158, dtype=torch.float64, grad_fn=<SumBackward0>)

delta:  tensor([-2.9143e-16,  1.9429e-16, -2.9143e-16, -6.5919e-17, -1.8041e-16,
         4.1113e-16,  6.8695e-16, -3.8858e-16,  2.8449e-16,  7.1991e-16],
       dtype=torch.float64)

mine with ReLUs

y[:,0] - y_ref[:,0]:  tensor([ 0.0022,  0.0741, -0.0473, -0.0717,  0.0102, -0.0224, -0.0175,  0.0131,
        -0.0086,  0.0529], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum():  tensor(-0.0015, dtype=torch.float64, grad_fn=<SumBackward0>)

delta:  tensor([ 5.3429e-16,  4.4409e-16,  1.1102e-16,  5.2736e-16, -4.5103e-16,
        -6.9389e-18,  2.4286e-16,  7.2858e-17, -4.1980e-16, -2.0817e-16],
       dtype=torch.float64)

However, I'm noticing that there are more warning messages with 0.5.0. Specifically, there are two additional ones related to max pooling being an "invalid module". Is it possible that whatever fix was implemented for this has a bug in it?

/users/jmschr/anaconda3/lib/python3.9/site-packages/captum/attr/_core/deep_lift.py:336: UserWarning: Setting forward, backward hooks and attributes on non-linear
               activations. The hooks and attributes will be removed
            after the attribution is finished
  warnings.warn(
/users/jmschr/anaconda3/lib/python3.9/site-packages/captum/attr/_core/deep_lift.py:467: UserWarning: An invalid module MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False) is detected. Saved gradients will
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  warnings.warn(
/users/jmschr/anaconda3/lib/python3.9/site-packages/captum/attr/_core/deep_lift.py:467: UserWarning: An invalid module MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is detected. Saved gradients will
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  warnings.warn(
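
To compare which warnings each version emits, they can be recorded instead of read off the console. The following is a minimal sketch using the standard warnings module. Here run_attribution is a hypothetical stand-in for the dl.attribute(...) call and emits a placeholder warning, and collect_warnings is an illustrative helper name.

```python
import warnings


def run_attribution():
    # Hypothetical stand-in for dl.attribute(...): emits a placeholder warning.
    warnings.warn("stand-in warning: invalid module detected", UserWarning)
    return "attributions"


def collect_warnings(fn):
    """Run fn and return (result, [(category name, message), ...])."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # do not suppress repeated warnings
        result = fn()
    return result, [(w.category.__name__, str(w.message)) for w in caught]


result, messages = collect_warnings(run_attribution)
for category, text in messages:
    print(category, "|", text)
```

Running the same collection under both Captum versions and diffing the two message lists would show exactly which warnings were added or dropped.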

jmschrei avatar Jan 30 '23 18:01 jmschrei

Yes, I agree that it is PyTorch-version related. I'm using the master version of PyTorch, which is 2.0.0a0. I don't think you need to worry about the warnings. There was a bug in an early version of PyTorch, which we fixed and for which we send out that warning. https://github.com/pytorch/captum/commit/288cd3a6754d85cbcb0ce74784aa014876284b6c#diff-536b508d8c8a83f406729a6bd359e07d9798cc922d3ad3bf2d3501dde23033a4L467 @vivekmig switched to tensor hooks, and I think there might be an issue with that for your version, @jmschrei! Let me ping Vivek. Tensor hooks are more reliable than module hooks, which is why we switched to them in Captum's latest version, but there might be a bug related to that. We'll debug and fix it. Thank you for bringing it up.

NarineK avatar Jan 31 '23 04:01 NarineK

Sounds good. Thanks for all your quick responses.

jmschrei avatar Jan 31 '23 05:01 jmschrei