ReLU activations causing divergence for DeepLiftShap
🐛 Bug
I am observing a large divergence when using DeepLiftShap on a model with ReLU activations (or any other type of activation), but not when using torch.nn.Identity instead. This is puzzling to me because I don't think I'm doing anything non-standard, yet the attribution results are wacky.
To Reproduce
Steps to reproduce the behavior:
```python
import torch
from captum.attr import DeepLiftShap

activation = torch.nn.Identity


class PermuteFlatten(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X.permute(0, 2, 1).reshape(X.shape[0], -1)


class TestModule(torch.nn.Module):
    def __init__(self, n_outputs, k):
        super().__init__()
        self.activation = activation()
        self.conv = torch.nn.Conv1d(4, k, kernel_size=17, padding='same', bias=False)
        self.bn = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp = torch.nn.MaxPool1d(3)
        self.activation1 = activation()
        self.conv1 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn1 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp1 = torch.nn.MaxPool1d(2)
        self.activation2 = activation()
        self.conv2 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn2 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp2 = torch.nn.MaxPool1d(2)
        self.activation3 = activation()
        self.conv3 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn3 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp3 = torch.nn.MaxPool1d(2)
        self.activation4 = activation()
        self.conv4 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn4 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp4 = torch.nn.MaxPool1d(2)
        self.activation5 = activation()
        self.conv5 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn5 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp5 = torch.nn.MaxPool1d(2)
        self.activation6 = activation()
        self.conv6 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn6 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp6 = torch.nn.MaxPool1d(2)
        self.activation7 = activation()
        self.conv7 = torch.nn.Conv1d(k, k, kernel_size=1, bias=False)
        self.bn7 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.activation8 = activation()
        self.pf = PermuteFlatten()
        self.fc = torch.nn.Linear(k*(1344 // (3 * 2 ** 6)), 32, bias=False)
        self.bn8 = torch.nn.BatchNorm1d(32, eps=0.001)
        self.activation9 = activation()
        self.fc1 = torch.nn.Linear(32, n_outputs)

    def forward(self, X):
        X = self.activation(X)
        X = self.activation1(self.mp(self.bn(self.conv(X))))
        X = self.activation2(self.mp1(self.bn1(self.conv1(X))))
        X = self.activation3(self.mp2(self.bn2(self.conv2(X))))
        X = self.activation4(self.mp3(self.bn3(self.conv3(X))))
        X = self.activation5(self.mp4(self.bn4(self.conv4(X))))
        X = self.activation6(self.mp5(self.bn5(self.conv5(X))))
        X = self.activation7(self.mp6(self.bn6(self.conv6(X))))
        X = self.activation8(self.bn7(self.conv7(X)))
        X = self.pf(X)
        X = self.activation9(self.bn8(self.fc(X)))
        X = self.fc1(X)
        return X


model = TestModule(20, 288).double().eval()

X = torch.randn(1, 4, 1344).double() * 1000
reference = torch.randn(10, 4, 1344).double() * 1000

dl = DeepLiftShap(model)
X_attr, delta = dl.attribute(X, reference, target=0, return_convergence_delta=True)

y = model(X)
y_ref = model(reference)

print(y[:,0] - y_ref[:,0])
print(X_attr.sum())
print()
print(delta)
```
This returns (exact values will vary, since no random seed is set):
```
tensor([-0.6306, 0.3156, 0.1099, 0.3692, 0.2338, 1.4345, 0.1836, -0.1585,
        -1.2692, -1.3378], dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(-0.0749, dtype=torch.float64, grad_fn=<SumBackward0>)

tensor([ 3.3307e-16, -2.4980e-15, 5.9119e-15, -8.3267e-16, -2.3315e-15,
        -7.9936e-15, -1.0769e-14, -2.1372e-15, 1.1546e-14, -7.7716e-15],
       dtype=torch.float64)
```
If you change the activation to a ReLU:
```python
activation = torch.nn.ReLU
```
Then you get:
```
tensor([ 0.0661, 0.0156, 0.0222, 0.0308, -0.0484, 0.0219, -0.0122, 0.0256,
         0.0703, -0.0111], dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(0.0348, dtype=torch.float64, grad_fn=<SumBackward0>)

tensor([ 0.1226, -0.0261, -0.0167, -0.0194, -0.0280, 0.0237, -0.0341, -0.0470,
         0.1896, 0.0032], dtype=torch.float64)
```
HOWEVER:
If you manually change two of the activations (activation and activation8 in the code below) to ReLUs, while keeping the rest as Identity, it converges. Using:
```python
class TestModule(torch.nn.Module):
    def __init__(self, n_outputs, k):
        super().__init__()
        self.activation = torch.nn.ReLU()
        self.conv = torch.nn.Conv1d(4, k, kernel_size=17, padding='same', bias=False)
        self.bn = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp = torch.nn.MaxPool1d(3)
        self.activation1 = activation()
        self.conv1 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn1 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp1 = torch.nn.MaxPool1d(2)
        self.activation2 = activation()
        self.conv2 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn2 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp2 = torch.nn.MaxPool1d(2)
        self.activation3 = activation()
        self.conv3 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn3 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp3 = torch.nn.MaxPool1d(2)
        self.activation4 = activation()
        self.conv4 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn4 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp4 = torch.nn.MaxPool1d(2)
        self.activation5 = activation()
        self.conv5 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn5 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp5 = torch.nn.MaxPool1d(2)
        self.activation6 = activation()
        self.conv6 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn6 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp6 = torch.nn.MaxPool1d(2)
        self.activation7 = activation()
        self.conv7 = torch.nn.Conv1d(k, k, kernel_size=1, bias=False)
        self.bn7 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.activation8 = torch.nn.ReLU()
        self.pf = PermuteFlatten()
        self.fc = torch.nn.Linear(k*(1344 // (3 * 2 ** 6)), 32, bias=False)
        self.bn8 = torch.nn.BatchNorm1d(32, eps=0.001)
        self.activation9 = activation()
        self.fc1 = torch.nn.Linear(32, n_outputs)

    def forward(self, X):
        X = self.activation(X)
        X = self.activation1(self.mp(self.bn(self.conv(X))))
        X = self.activation2(self.mp1(self.bn1(self.conv1(X))))
        X = self.activation3(self.mp2(self.bn2(self.conv2(X))))
        X = self.activation4(self.mp3(self.bn3(self.conv3(X))))
        X = self.activation5(self.mp4(self.bn4(self.conv4(X))))
        X = self.activation6(self.mp5(self.bn5(self.conv5(X))))
        X = self.activation7(self.mp6(self.bn6(self.conv6(X))))
        X = self.activation8(self.bn7(self.conv7(X)))
        X = self.pf(X)
        X = self.activation9(self.bn8(self.fc(X)))
        X = self.fc1(X)
        return X
```
returns:
```
tensor([-1.2153, -0.2924, -0.5858, -0.2723, -0.0496, -0.7380, -0.7172, -1.2634,
        -0.5631, -0.9969], dtype=torch.float64, grad_fn=<SubBackward0>)
tensor(-0.6694, dtype=torch.float64, grad_fn=<SumBackward0>)

tensor([-1.9984e-15, 5.6621e-15, 4.4409e-16, 6.6613e-16, 2.6437e-15,
        5.2180e-15, -4.4409e-16, 0.0000e+00, 3.8858e-15, 3.3307e-16],
       dtype=torch.float64)
```
Expected behavior
I would expect the convergence delta to be within machine precision of 0 regardless of the number of ReLU layers.
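For concreteness, here is a minimal sketch of this expectation, based on DeepLIFT's summation-to-delta (completeness) property; `check_deltas` is a hypothetical helper, not a Captum API:

```python
import torch


def check_deltas(delta: torch.Tensor, atol: float = 1e-8) -> None:
    # Summation-to-delta: for each (input, baseline) pair, the attributions
    # should sum to f(input)[target] - f(baseline)[target], so the returned
    # convergence delta should be ~0. With float64 inputs, |delta| ~ 1e-14
    # (the Identity case above) is fine; |delta| ~ 1e-1 (the ReLU case) is not.
    max_delta = delta.abs().max().item()
    assert max_delta < atol, f"max |delta| = {max_delta:.3e}"
```

With the repro above, this check passes for the Identity model and fails for the ReLU model.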
Environment
Describe the environment used for Captum
- Captum / PyTorch Version (e.g., 1.0 / 0.4.0): 0.6.0 / 1.13.1, respectively
- OS (e.g., Linux): Linux
- How you installed Captum / PyTorch (conda, pip, source): pip
- Build command you used (if compiling from source):
- Python version: 3.9.13
Additional context
@jmschrei, sorry that you are seeing this issue. I think that this is similar to the old issues that you were having. The problem is the reuse of the ReLU: we rewrite the gradients for ReLUs, but we do not rewrite the gradients for the Identity. https://github.com/pytorch/captum/blob/master/captum/attr/_core/deep_lift.py#L1045
In PyTorch, when we hook a module, we don't know where in the execution graph it is called. If we reuse a module such as a ReLU and hook it, the grad input and output get mixed up. @vivekmig and I discussed fixing this by counting the calls to the same module. Newer versions of PyTorch might have other solutions for this that we need to look into.
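For readers following along, here is a minimal sketch (hypothetical two-layer modules, not code from this issue) of the reuse pattern being described: DeepLift hooks module objects, so a single ReLU instance invoked at several call sites is ambiguous to a module hook, whereas one instance per call site is not.

```python
import torch


class ReusedReLU(torch.nn.Module):
    """One ReLU module object used at two call sites in forward()."""

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()   # a single module object...
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        x = self.relu(self.fc1(x))    # ...invoked here
        return self.relu(self.fc2(x))  # ...and reused here


class SeparateReLUs(torch.nn.Module):
    """A distinct ReLU module per call site, which module hooks can tell apart."""

    def __init__(self):
        super().__init__()
        self.relu1 = torch.nn.ReLU()
        self.relu2 = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        x = self.relu1(self.fc1(x))
        return self.relu2(self.fc2(x))
```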
If you change it to the code below, where the ReLUs are not reused, you'll see small deltas:
```python
import torch
from captum.attr import DeepLiftShap

# activation = torch.nn.Identity


class PermuteFlatten(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X.permute(0, 2, 1).reshape(X.shape[0], -1)


class TestModule(torch.nn.Module):
    def __init__(self, n_outputs, k):
        super().__init__()
        self.activation = torch.nn.ReLU()
        self.conv = torch.nn.Conv1d(4, k, kernel_size=17, padding='same', bias=False)
        self.bn = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp = torch.nn.MaxPool1d(3)
        self.activation1 = torch.nn.ReLU()
        self.conv1 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn1 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp1 = torch.nn.MaxPool1d(2)
        self.activation2 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn2 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp2 = torch.nn.MaxPool1d(2)
        self.activation3 = torch.nn.ReLU()
        self.conv3 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn3 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp3 = torch.nn.MaxPool1d(2)
        self.activation4 = torch.nn.ReLU()
        self.conv4 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn4 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp4 = torch.nn.MaxPool1d(2)
        self.activation5 = torch.nn.ReLU()
        self.conv5 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn5 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp5 = torch.nn.MaxPool1d(2)
        self.activation6 = torch.nn.ReLU()
        self.conv6 = torch.nn.Conv1d(k, k, kernel_size=5, padding='same', bias=False)
        self.bn6 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.mp6 = torch.nn.MaxPool1d(2)
        self.activation7 = torch.nn.ReLU()
        self.conv7 = torch.nn.Conv1d(k, k, kernel_size=1, bias=False)
        self.bn7 = torch.nn.BatchNorm1d(k, eps=0.001)
        self.activation8 = torch.nn.ReLU()
        self.pf = PermuteFlatten()
        self.fc = torch.nn.Linear(k*(1344 // (3 * 2 ** 6)), 32, bias=False)
        self.bn8 = torch.nn.BatchNorm1d(32, eps=0.001)
        self.activation9 = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(32, n_outputs)

    def forward(self, X):
        X = self.activation(X)
        X = self.activation1(self.mp(self.bn(self.conv(X))))
        X = self.activation2(self.mp1(self.bn1(self.conv1(X))))
        X = self.activation3(self.mp2(self.bn2(self.conv2(X))))
        X = self.activation4(self.mp3(self.bn3(self.conv3(X))))
        X = self.activation5(self.mp4(self.bn4(self.conv4(X))))
        X = self.activation6(self.mp5(self.bn5(self.conv5(X))))
        X = self.activation7(self.mp6(self.bn6(self.conv6(X))))
        X = self.activation8(self.bn7(self.conv7(X)))
        X = self.pf(X)
        X = self.activation9(self.bn8(self.fc(X)))
        X = self.fc1(X)
        return X


model = TestModule(20, 288).double().eval()

X = torch.randn(1, 4, 1344).double() * 1000
reference = torch.randn(10, 4, 1344).double() * 1000

dl = DeepLiftShap(model)
X_attr, delta = dl.attribute(X, reference, target=0, return_convergence_delta=True)

y = model(X)
y_ref = model(reference)

print(y[:,0] - y_ref[:,0])
print(X_attr.sum())
print()
print(delta)
```
Hi @NarineK, thanks for getting back to me. When I explicitly call torch.nn.ReLU() for each activation, I don't see much of a difference. Is it working on your end?
Your code:
```
y[:,0] - y_ref[:,0]: tensor([-0.0183, -0.0890, -0.0488, -0.0487, -0.0483, 0.0016, -0.0432, -0.0276,
                             -0.1023, -0.0659], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(-0.0518, dtype=torch.float64, grad_fn=<SumBackward0>)
deltas: tensor([-0.0475, 0.0608, -0.1042, -0.0746, -0.0878, 0.1340, 0.1053, 0.0663,
                0.0264, -0.1065], dtype=torch.float64)
```
My code, but using ReLU instead of Identity:
```
y[:,0] - y_ref[:,0]: tensor([ 0.0837, -0.0177, 0.0327, 0.0403, 0.0081, 0.1140, 0.0547, 0.0875,
                              0.0992, 0.0937], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(0.1714, dtype=torch.float64, grad_fn=<SumBackward0>)
deltas: tensor([ 0.0173, -0.0053, 0.0902, 0.2354, 0.1151, 0.1205, 0.0132, 0.1550,
                 0.2431, 0.1332], dtype=torch.float64)
```
My code using Identity:
```
y[:,0] - y_ref[:,0]: tensor([-3.4017, -5.3239, -5.5095, -4.8942, -2.2603, -3.9636, -5.7602, -2.6764,
                             -4.1172, -5.0349], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(-4.2942, dtype=torch.float64, grad_fn=<SumBackward0>)
deltas: tensor([-5.3291e-15, -5.3291e-15, -2.1316e-14, -6.2172e-15, 2.6645e-15,
                3.1086e-15, 2.6645e-15, -4.4409e-15, -7.9936e-15, -1.0658e-14],
               dtype=torch.float64)
```
Hi, I also tried running the code on my side; all of them seem to produce deltas pretty close to 0.
NarineK's:
```
delta=tensor([-3.6082e-16, 5.8981e-17, 6.3144e-16, 2.1858e-16, 2.1858e-16,
              -2.4460e-16, 1.7347e-17, 2.4286e-17, 8.3267e-17, 4.5797e-16],
             dtype=torch.float64)
```
jmschrei's, with ReLU:
```
delta=tensor([-6.3144e-16, 1.1796e-16, 2.0817e-17, 8.3267e-17, -9.7145e-17,
              -2.9143e-16, 2.4460e-16, -1.5266e-16, -1.1102e-16, 8.3267e-17],
             dtype=torch.float64)
```
jmschrei's, with Identity:
```
tensor([ 3.2196e-15, 5.1070e-15, 1.8874e-15, 6.2172e-15, 7.3275e-15,
        -2.2204e-15, -2.2204e-16, 9.5479e-15, 1.2212e-14, 7.5495e-15],
       dtype=torch.float64)
```
Tested on PyTorch 1.11.0 and Captum 0.5.0.
Great observation @yztxwd. If I downgrade to captum==0.5.0, it seems to work.
NarineK's:
```
-3.2650e-05, 2.4080e-02, 5.3694e-02, -4.3929e-02, -1.2449e-02],
       dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(0.0158, dtype=torch.float64, grad_fn=<SumBackward0>)
delta: tensor([-2.9143e-16, 1.9429e-16, -2.9143e-16, -6.5919e-17, -1.8041e-16,
               4.1113e-16, 6.8695e-16, -3.8858e-16, 2.8449e-16, 7.1991e-16],
              dtype=torch.float64)
```
Mine, with ReLUs:
```
y[:,0] - y_ref[:,0]: tensor([ 0.0022, 0.0741, -0.0473, -0.0717, 0.0102, -0.0224, -0.0175, 0.0131,
                              -0.0086, 0.0529], dtype=torch.float64, grad_fn=<SubBackward0>)
X_attr.sum(): tensor(-0.0015, dtype=torch.float64, grad_fn=<SumBackward0>)
delta: tensor([ 5.3429e-16, 4.4409e-16, 1.1102e-16, 5.2736e-16, -4.5103e-16,
                -6.9389e-18, 2.4286e-16, 7.2858e-17, -4.1980e-16, -2.0817e-16],
               dtype=torch.float64)
```
However, I'm noticing that there are more warning messages with 0.5.0. Specifically, there are two additional ones related to max pooling being an "invalid module". Is it possible that whatever fix was implemented for this has a bug in it?
```
/users/jmschr/anaconda3/lib/python3.9/site-packages/captum/attr/_core/deep_lift.py:336: UserWarning: Setting forward, backward hooks and attributes on non-linear
               activations. The hooks and attributes will be removed
            after the attribution is finished
  warnings.warn(
/users/jmschr/anaconda3/lib/python3.9/site-packages/captum/attr/_core/deep_lift.py:467: UserWarning: An invalid module MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False) is detected. Saved gradients will
            be used as the gradients of the module's input tensor.
            See MaxPool1d as an example.
  warnings.warn(
/users/jmschr/anaconda3/lib/python3.9/site-packages/captum/attr/_core/deep_lift.py:467: UserWarning: An invalid module MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is detected. Saved gradients will
            be used as the gradients of the module's input tensor.
            See MaxPool1d as an example.
  warnings.warn(
```
Yes, I agree that it is related to the PyTorch version. I'm using the master version of PyTorch, which is 2.0.0a0. I think that you shouldn't worry about the warnings: there was a bug in an early version of PyTorch which we fixed, and we emit that warning for it. https://github.com/pytorch/captum/commit/288cd3a6754d85cbcb0ce74784aa014876284b6c#diff-536b508d8c8a83f406729a6bd359e07d9798cc922d3ad3bf2d3501dde23033a4L467 @vivekmig switched to tensor hooks, and I think there might be an issue with that for your version, @jmschrei! Let me ping Vivek! Tensor hooks are more reliable than module hooks, which is why we switched to them in Captum's latest version, but there might be a bug related to them. We'll debug and fix it. Thank you for bringing it up.
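For readers unfamiliar with the distinction mentioned above, here is a rough, self-contained illustration of a module backward hook versus a tensor hook (this is not Captum's implementation; the hook bodies are placeholders):

```python
import torch

relu = torch.nn.ReLU()

# Module backward hook: registered once on the module object; if that object
# were reused at several call sites, the hook could not tell them apart.
def module_hook(module, grad_input, grad_output):
    return grad_input  # a DeepLift-style rule would return modified gradients here

module_handle = relu.register_full_backward_hook(module_hook)

x = torch.randn(4, requires_grad=True)
out = relu(x)

# Tensor hook: attached to one specific output tensor, so it is tied to an
# exact point in the autograd graph rather than to the module object.
def tensor_hook(grad):
    return grad  # likewise, a gradient-rewrite rule would go here

tensor_handle = out.register_hook(tensor_hook)

out.sum().backward()
module_handle.remove()
tensor_handle.remove()
```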
Sounds good. Thanks for all your quick responses.