BUG: Backward hook with shap.DeepExplainer on simple models with PyTorch

Status: Open. PietroManganelliConforti opened this issue 1 year ago.

Issue Description

Hi,

I'm trying to implement a Deep Explainer for a ResNet50 imported from torchvision and run on CIFAR-100. The basic implementation does not work because of an in-place modification that PyTorch does not support (though Keras/TF probably does). Do you know if there is a way to fix this? The error used to be a warning, but now it stops my code.

Thanks in advance.

Minimal Reproducible Example

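import torch
from torchvision import models, datasets, transforms
import shap

# weights= replaces the deprecated pretrained=True argument
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.eval()  # explain the model in inference mode (fixed BatchNorm statistics)

dataset = datasets.CIFAR100(root="./",
                            download=True,
                            train=True,
                            transform=transforms.ToTensor())

dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=1)

# Take one batch of 8 images; it serves as both background and explained samples
images, labels = next(iter(dataloader))

e = shap.DeepExplainer(model, images)
shap_values = e.shap_values(images)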

Traceback

Traceback (most recent call last):
  File "xai.py", line 339, in <module>
    shap_values = e.shap_values(test_images)
  File "/usr/local/lib/python3.8/dist-packages/shap/explainers/_deep/__init__.py", line 125, in shap_values
    return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
  File "/usr/local/lib/python3.8/dist-packages/shap/explainers/_deep/deep_pytorch.py", line 191, in shap_values
    sample_phis = self.gradient(feature_ind, joint_x)
  File "/usr/local/lib/python3.8/dist-packages/shap/explainers/_deep/deep_pytorch.py", line 107, in gradient
    outputs = self.model(*X)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/work/project/models/resnet.py", line 168, in forward
    x = self.relu(x)    # 32x32
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py", line 98, in forward
    return F.relu(input, inplace=self.inplace)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 1440, in relu
    result = torch.relu_(input)
RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
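The error can be reproduced without shap or torchvision. This sketch assumes, as the BackwardHookFunctionBackward frame in the traceback suggests, that DeepExplainer attaches full backward hooks (register_full_backward_hook) to the model's leaf modules; a do-nothing hook stands in for shap's, and the helper name relu_with_backward_hook is hypothetical. With the hook attached, the hooked module's input is wrapped in a view, and nn.ReLU(inplace=True) then modifies that view in place, which PyTorch rejects.

```python
import torch
import torch.nn as nn


def relu_with_backward_hook(inplace: bool) -> str:
    """Run a hooked ReLU forward/backward; return 'ok' or the first sentence of the error."""
    relu = nn.ReLU(inplace=inplace)
    # Do-nothing full backward hook, standing in for the hooks DeepExplainer attaches
    relu.register_full_backward_hook(lambda module, grad_input, grad_output: None)
    x = torch.randn(2, 8, requires_grad=True)
    try:
        # x * 2.0 is a non-leaf tensor, like the output of a conv or batch-norm layer
        relu(x * 2.0).sum().backward()
        return "ok"
    except RuntimeError as err:
        return str(err).split(". ")[0]


print(relu_with_backward_hook(inplace=False))
print(relu_with_backward_hook(inplace=True))
```

With inplace=False the call succeeds; with inplace=True it fails with the same "Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace" error as in the traceback.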

Expected Behavior

No response

Bug report checklist

  • [X] I have checked that this issue has not already been reported.
  • [X] I have confirmed this bug exists on the latest release of shap.
  • [x] I have confirmed this bug exists on the master branch of shap.
  • [ ] I'd be interested in making a PR to fix this bug

Installed Versions

I just installed SHAP with pip, so I'm using the latest available version.

Thanks for the reproducible example. I can confirm this on master.

I will have a look at this in the upcoming weeks.

CloseChoice, Jan 23 '24

Puuhh, this seems pretty tough to fix. I tried setting every ReLU activation to inplace=False, but that still fails at the line

out += identity

of the ResNet module with the same error. Strangely, this works if I replace it with

out = out + identity

but then the shapes mismatch.

CloseChoice, Feb 12 '24
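A likely reason out = out + identity works while out += identity fails: a full backward hook wraps a module's output as well as its input, so an in-place add on the output of a hooked layer (here a BatchNorm2d, as bn3 is in a ResNet block) hits the same view check, whereas out + identity allocates a new tensor. A minimal sketch; ResidualTail and runs_with_hook are hypothetical names, and the empty hook stands in for the ones shap attaches.

```python
import torch
import torch.nn as nn


class ResidualTail(nn.Module):
    """Hypothetical stand-in for the end of a ResNet block: BatchNorm output plus a skip connection."""

    def __init__(self, channels: int, use_inplace_add: bool):
        super().__init__()
        self.bn = nn.BatchNorm2d(channels)
        self.use_inplace_add = use_inplace_add

    def forward(self, x):
        out = self.bn(x)
        if self.use_inplace_add:
            out += x       # modifies the hooked BatchNorm output in place
        else:
            out = out + x  # allocates a new tensor instead
        return out


def runs_with_hook(use_inplace_add: bool) -> bool:
    model = ResidualTail(4, use_inplace_add)
    # Full backward hook on the leaf module, similar in kind to DeepExplainer's hooks
    model.bn.register_full_backward_hook(lambda module, grad_in, grad_out: None)
    x = torch.randn(2, 4, 8, 8, requires_grad=True)
    try:
        model(x).sum().backward()
        return True
    except RuntimeError as err:
        if "BackwardHookFunctionBackward" not in str(err):
            raise  # some unrelated failure
        return False


print(runs_with_hook(use_inplace_add=False))
print(runs_with_hook(use_inplace_add=True))
```

This only shows why the in-place error disappears; the later shape mismatch is a separate problem.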

I am attempting to apply DeepExplainer to ResNet50 and DenseNet121, and I am getting the BackwardHookFunctionBackward error. I tried to disable the in-place operation by setting inplace=False on every ReLU activation, using the following approach:

import torch.nn as nn
from torchvision import models
from torchvision.models import DenseNet121_Weights

class ModifiedDenseNet121(nn.Module):
    def __init__(self, pretrained=True, input_channels=3, num_classes=1000):
        super().__init__()
        self.densenet = models.densenet121(weights=DenseNet121_Weights.DEFAULT if pretrained else None)
        if input_channels != 3:
            # Replace the stem convolution to accept a different number of input channels
            self.densenet.features[0] = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.densenet.classifier = nn.Linear(self.densenet.classifier.in_features, num_classes)
        # Fix the usage of ReLU activation with inplace=True
        for module in self.densenet.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = False

    def forward(self, x):
        return self.densenet(x)

However, I'm still encountering the error RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. Any guidance or updates on how to resolve this would be greatly appreciated. Thank you!

samiraat, May 02 '24
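A probable reason the loop over nn.ReLU modules is not enough for DenseNet: torchvision's DenseNet.forward applies F.relu(features, inplace=True) directly to the output of self.features, whose last layer (norm5) is a BatchNorm2d that would carry a backward hook. That functional call is not an nn.ReLU module, so setting module.inplace cannot reach it. The hypothetical wrapper below reuses the pretrained features and classifier but performs the final ReLU out of place; it is a sketch under that assumption, not a tested shap fix.

```python
import torch
import torch.nn as nn


class DenseNetNoInplace(nn.Module):
    """Hypothetical wrapper around a torchvision DenseNet that replaces the in-place
    functional ReLU in DenseNet.forward with an out-of-place nn.ReLU module."""

    def __init__(self, densenet):
        super().__init__()
        self.features = densenet.features     # reuse the (pretrained) feature extractor
        self.relu = nn.ReLU(inplace=False)    # replaces F.relu(features, inplace=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten(1)
        self.classifier = densenet.classifier
        # Also switch every nn.ReLU module inside the feature extractor to out-of-place
        for module in self.features.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = False

    def forward(self, x):
        out = self.features(x)
        out = self.relu(out)
        out = self.pool(out)
        out = self.flatten(out)
        return self.classifier(out)
```

Usage would be along the lines of model = DenseNetNoInplace(ModifiedDenseNet121().densenet).eval(). Using a module rather than a functional call may also let DeepExplainer treat the activation like the other hooked ReLUs, although that is an assumption about its internals.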

If you really need this, maybe have a look at Captum. It would be great if you could report back your findings so that we can borrow a couple of ideas from them.

CloseChoice, May 02 '24

+1. Any update?

rubencart, Sep 13 '24

I'm trying to use it on an EfficientNet-based model. I had to go into efficientnet.py under torchvision.models and change:

result += input

to

result = result + input

And it worked.

dvirla, Sep 19 '24

https://github.com/shap/shap/issues/3725#issuecomment-2202052886

Found the solution: make sure no layer is forwarded twice. This matters especially for activation functions, which are typically stored as a class attribute and used multiple times in the forward method.

For the shape mismatch, this solution may also help; it works on ResNet50/101/152.

wsynuiag, Nov 11 '24
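A small diagnostic for the "not forwarded twice" advice is to count how often each leaf module runs during one forward pass. The function and toy model names are hypothetical. Why reuse matters is an assumption about shap's internals: its forward hook appears to store each module's latest input and output, so a module called twice keeps only the last call's tensors. That would fit the 512 vs 2048 channel mismatch reported for ResNet152 in this thread.

```python
from collections import Counter

import torch
import torch.nn as nn


def find_reused_modules(model: nn.Module, *inputs) -> dict:
    """Return {module_name: call_count} for leaf modules that run more than once per forward pass."""
    counts = Counter()
    handles = []
    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf module (no submodules)
            handles.append(
                module.register_forward_hook(
                    lambda mod, inp, out, name=name: counts.update([name])
                )
            )
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()  # leave the model without extra hooks
    return {name: n for name, n in counts.items() if n > 1}


class TinyBlock(nn.Module):
    """Toy block that, like torchvision's Bottleneck, reuses one ReLU module."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.fc1(x))
        return self.relu(self.fc2(out))  # second call to the same module


print(find_reused_modules(TinyBlock(), torch.randn(2, 4)))  # {'relu': 2}
```

On torchvision's resnet50, this should flag each Bottleneck's relu (called three times per block), while the stem relu runs only once.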

How did you fix it in ResNet152 exactly? I'm having the same problem with shape mismatches (but in the backward pass of a ReLU): the grad_output has 512 channels instead of the 2048 channels seen in the forward pass 😨

nicogross, Nov 12 '24

First, I changed out += identity to out = out + identity as above. Then I found that the ReLU layer in Bottleneck is used multiple times, so I replaced the original self.relu in Bottleneck.__init__ with:

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.relu3 = nn.ReLU(inplace=True)

and the forward function:

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = self.relu3(out)

        return out

It's worth mentioning that my experiments are based on LAV models (with no bugs so far, at least), where ResNet is just the backbone.

wsynuiag, Nov 13 '24

Thanks for your fast reply. I made the same changes, but now I'm getting assertion errors in the additivity check. I set check_additivity=False since the results make sense to me and ResNet152 is deep (maybe some rounding errors occurred...), but I'm wondering how errors like this happen, e.g. a difference of 0.034852 exceeding the tolerance of 0.01! Did you have problems with the additivity of the computed Shapley values too?

nicogross, Nov 13 '24

Yes, #3725 also mentioned this error. If you figure it out, I would be interested as well.

wsynuiag, Nov 14 '24

The problem here is the adapted backpropagation through the maxpool (the only maxpool, which sits before layer1). When I let SHAP explain the input to layer1, the SHAP values sum to the output difference, but when I let it explain the input to the maxpool, they no longer add up.

nicogross, Nov 25 '24

It seems the unpooling can't handle overlapping regions (when stride < kernel_size); zero-padding seems to be fine.

Some code I experimented with:

import torch
import torch.nn as nn
import shap

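x = torch.randn([1,64,112,112])*100 # inputs to explain: standard normal scaled to std 100 (not bounded)
r = torch.randn([1,64,112,112])*100 # background/reference inputs from the same distribution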

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # change kernel_size, stride and padding here
        self.maxpool = nn.MaxPool2d(kernel_size=10, stride=10, padding=5)
        self.avgpool = nn.AdaptiveAvgPool3d((1,1,1))
    def forward(self, x):
        out = self.maxpool(x)
        out = self.avgpool(out)
        return out

model = MyModel()

explainer   = shap.DeepExplainer(model, r)
shap_values = explainer.shap_values(x, check_additivity=False)

print(shap_values.sum())
print((model(x) - model(r)).item())

The unpooling behaves nondeterministically for repeated input indices (see this open PyTorch issue):

Note: This operation may behave nondeterministically when the input indices has repeat values. See https://github.com/pytorch/pytorch/issues/80827 and Reproducibility for more information.

nicogross, Nov 25 '24
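The repeated indices that note refers to are easy to produce: with overlapping windows (stride < kernel_size), one input position can be the argmax of several output windows. ResNet's stem maxpool (kernel_size=3, stride=2, padding=1) is overlapping. Whether this fully explains the additivity gap is nicogross's hypothesis, not something this sketch proves; has_repeated_indices is a hypothetical helper that only inspects PyTorch's max-pool indices.

```python
import torch
import torch.nn as nn


def has_repeated_indices(kernel_size: int, stride: int, padding: int = 0) -> bool:
    """Check whether max pooling maps several output windows to the same input position."""
    x = torch.zeros(1, 1, 12, 12)
    x[0, 0, 5, 5] = 1.0  # one strict peak: every window containing it selects it
    pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding, return_indices=True)
    _, indices = pool(x)
    flat = indices.flatten()
    return flat.unique().numel() < flat.numel()


print(has_repeated_indices(2, 2))     # non-overlapping windows -> False
print(has_repeated_indices(3, 2, 1))  # ResNet stem configuration -> True
```

Non-overlapping windows are disjoint, so their argmax positions can never coincide; overlapping windows that share the peak all report the same flat index.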

Maybe https://github.com/shap/shap/issues/1479#issuecomment-9322340825 can explain this error; it seems to be caused by the reused layer.

1zero224, Nov 27 '24

Re dvirla's efficientnet.py workaround (changing result += input to result = result + input): thanks, this does resolve the issue.

Davido111200, Dec 09 '24