support unet with Upsample or ConvTranspose2d layer?
I want to use LRP to explain a semantic segmentation task with a UNet model (PyTorch). I tested LRP in Captum, but it does not support nn.Upsample and nn.ConvTranspose2d. I would like to know whether a semantic segmentation model like UNet can be supported, and if not, how it should be implemented. Any help would be appreciated!
Hey @bugsuse,
I think this should just work if you use any of the composite rules. Since the gradient of nn.Upsample is constant (depending only on the upsampled size), it will probably scale the full attribution a little, but this should not be a problem. nn.ConvTranspose2d is simply a linear layer and is therefore supported.
If you simply use one of the Composites like EpsilonGammaBox or EpsilonAlpha2Beta1Flat (or anything, really), this should just work out of the box, though I have not tried it with UNet specifically yet. Unless you are using BatchNorm, you also do not have to supply a Canonizer. Have a look at the example.
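For reference, a minimal sketch of this out-of-the-box usage with a built-in composite and the Gradient attributor; model, data and output_relevance are placeholders for your own objects:

import torch
from zennit.attribution import Gradient
from zennit.composites import EpsilonAlpha2Beta1Flat

model.eval()

# any built-in composite works here; EpsilonGammaBox would additionally need low/high tensors
composite = EpsilonAlpha2Beta1Flat()

# entering the context registers the composite's hooks on the model
with Gradient(model, composite) as attributor:
    # the second argument selects which outputs are propagated backwards
    output, relevance = attributor(data, output_relevance)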
Just in case the gradients get too large with nn.Upsample, you can use the Norm rule and build your own composite:
import torch
from zennit.rules import Gamma, Epsilon, ZBox, Norm
from zennit.types import Convolution
from zennit.composites import SpecialFirstLayerMapComposite, LAYER_MAP_BASE


class UpsampledEpsilonGammaBox(SpecialFirstLayerMapComposite):
    '''An explicit composite using the ZBox rule for the first convolutional layer, the gamma rule for all following
    convolutional layers, and the epsilon rule for all fully connected layers.
    Additionally, this uses the `Norm` rule for `nn.Upsample`.

    Parameters
    ----------
    low: obj:`torch.Tensor`
        A tensor with the same size as the input, describing the lowest possible pixel values.
    high: obj:`torch.Tensor`
        A tensor with the same size as the input, describing the highest possible pixel values.
    '''
    def __init__(self, low, high, canonizers=None):
        layer_map = LAYER_MAP_BASE + [
            (Convolution, Gamma(gamma=0.25)),
            (torch.nn.Linear, Epsilon()),
            (torch.nn.Upsample, Norm()),
        ]
        first_map = [
            (Convolution, ZBox(low, high))
        ]
        super().__init__(layer_map, first_map, canonizers=canonizers)
But I would just try it with the gradient first, i.e., use one of the built-in composites.
Hey @bugsuse,
I have already used Zennit for a UNet with an nn.Upsample layer, and as @chr5tphr said, this worked even when just using the gradient (no extra rule). In fact, in my case it did not make a difference whether I used the gradient, the Norm() rule, or the Epsilon() rule for nn.Upsample.
@chr5tphr Thanks sooo much! I will try it!
@maxdreyer That's great! Do you have a relevant notebook or example blog post? Would it be possible to share it with me?
@bugsuse,
you can basically use the code in zennit/share/example/feed_forward.py. After loading your model and data, you only have to adapt the choice of output_relevance (the output that is propagated backwards).
For classification tasks, it makes sense to propagate zeros everywhere except at the index of the target class: output_relevance = torch.eye(n_outputs, device=device)[target]
For segmentation, it depends on what you are interested in. In my case, I propagated backwards the output channel of a specific class, like
output = model(input)
output_relevance = torch.zeros_like(output)
output_relevance[:, class_index, :, :] = output[:, class_index, :, :]
@maxdreyer I tested it following feed_forward.py, but it raised a RuntimeError:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-44-740b8083f87f> in <module>
45
46 # this will compute the modified gradient of model, with the on
---> 47 output, relevance = attributor(data.cuda(), output_relevance.cuda())
48
49 # sum over the color channel for visualization
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/attribution.py in __call__(self, input, attr_output)
130
131 if self.composite is None or self.composite.handles:
--> 132 return self.forward(input, attr_output_fn)
133
134 with self:
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/attribution.py in forward(self, input, attr_output_fn)
175 input = input.detach().requires_grad_(True)
176 output = self.model(input)
--> 177 gradient, = torch.autograd.grad((output,), (input,), grad_outputs=(attr_output_fn(output.detach()),))
178 return output, gradient
179
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
200 retain_graph = create_graph
201
--> 202 return Variable._execution_engine.run_backward(
203 outputs, grad_outputs_, retain_graph, create_graph,
204 inputs, allow_unused)
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/core.py in wrapper(grad_input, grad_output)
139 @functools.wraps(self.backward)
140 def wrapper(grad_input, grad_output):
--> 141 return hook_ref().backward(module, grad_input, hook_ref().stored_tensors['grad_output'])
142
143 if not isinstance(input, tuple):
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/core.py in backward(self, module, grad_input, grad_output)
279 input = in_mod(original_input).requires_grad_()
280 with mod_params(module, param_mod, **param_kwargs) as modified, torch.autograd.enable_grad():
--> 281 output = modified.forward(input)
282 output = out_mod(output)
283 inputs.append(input)
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
421
422 def forward(self, input: Tensor) -> Tensor:
--> 423 return self._conv_forward(input, self.weight)
424
425 class Conv3d(_ConvNd):
~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
417 weight, self.bias, self.stride,
418 _pair(0), self.dilation, self.groups)
--> 419 return F.conv2d(input, weight, self.bias, self.stride,
420 self.padding, self.dilation, self.groups)
421
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [16, 24, 3, 3], but got 3-dimensional input of size [16, 64, 64] instead
The example code is as follows:
device = 'cuda:0'
model.to(device)
model.eval()

# disable requires_grad for all parameters, we do not need their modified gradients
for param in model.parameters():
    param.requires_grad = False

output = model(input.cuda())
output_relevance = torch.zeros_like(output)

# create a composite if composite_name was set, otherwise we do not use a composite
composite = None
if composite_name is not None:
    composite_kwargs = {}
    if composite_name == 'upsample_epsilon_gamma_box':
        # the maximal input shape, needed for the ZBox rule
        shape = (batch_size, 64, 64)
        # the highest and lowest pixel values for the ZBox rule
        composite_kwargs['low'] = torch.zeros(*shape, device=device)
        composite_kwargs['high'] = torch.ones(*shape, device=device)

    # use torchvision specific canonizers, as supplied in the MODELS dict
    composite_kwargs['canonizers'] = [MODELS[model_name][1]()]

    # create a composite specified by a name; the COMPOSITES dict includes all preset composites provided by zennit.
    composite = COMPOSITES[composite_name](**composite_kwargs)

# specify some attributor-specific arguments
attributor_kwargs = {
    'smoothgrad': {'noise_level': 0.1, 'n_iter': 20},
    'integrads': {'n_iter': 20},
    'occlusion': {'window': (56, 56), 'stride': (28, 28)},
}.get(attributor_name, {})

attributor = ATTRIBUTORS[attributor_name](model, composite, **attributor_kwargs)

sample_index = 0
with attributor:
    for data, target in valid_loader:
        output_relevance = torch.zeros_like(torch.squeeze(target))
        output, relevance = attributor(data.cuda(), output_relevance.cuda())
        ## the rest is the same as `feed_forward.py`
        ...
Thanks for your kind help!
I expect you already made sure the model runs when doing a forward pass without Zennit (i.e., without the Attributor context)? Could you supply the model you are using? There may be a problem when passing tuples instead of tensors between layers; is this the case for you?
@bugsuse also, you set output_relevance twice:
at the beginning as output_relevance = torch.zeros_like(output), and later in the loop as output_relevance = torch.zeros_like(torch.squeeze(target)).
output_relevance should have the same shape as output, and torch.zeros_like(torch.squeeze(target)) could have the wrong shape.
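A quick way to check this before calling the attributor (just a sketch; model, data, target and device are your own objects):

output = model(data.to(device))
candidate = torch.zeros_like(torch.squeeze(target))
# the two shapes must be identical, otherwise the backward pass will fail
print(output.shape, candidate.shape)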
@chr5tphr Yeah, I directly load the pretrained weights. I'm using a UNet model. More code is below:
class UNetCanonizer(SequentialMergeBatchNorm):
    '''Canonizer for the UNet model. This is so far identical to a SequentialMergeBatchNorm.'''
MODELS = {
    'vgg16': (vgg16, VGGCanonizer),
    'vgg16_bn': (vgg16_bn, VGGCanonizer),
    'resnet50': (resnet50, ResNetCanonizer),
    'unet': (Unet(), UNetCanonizer)
}

ATTRIBUTORS = {
    'gradient': Gradient,
    'smoothgrad': SmoothGrad,
    'integrads': IntegratedGradients,
    'occlusion': Occlusion,
}

class BatchNormalize:
    def __init__(self, mean, std, device=None):
        self.mean = torch.tensor(mean, device=device)[None, :, None, None]
        self.std = torch.tensor(std, device=device)[None, :, None, None]

    def __call__(self, tensor):
        return (tensor - self.mean) / self.std

class AllowEmptyClassImageFolder(ImageFolder):
    '''Subclass of ImageFolder, which only finds non-empty classes, but with their correct indices given other empty
    classes. This counter-acts the changes in torchvision 0.10.0, in which DatasetFolder does not allow empty classes
    anymore by default. Versions before 0.10.0 do not expose `find_classes`, and thus this change does not change the
    functionality of `ImageFolder` in earlier versions.
    '''
    def find_classes(self, directory):
        with os.scandir(directory) as scanit:
            class_info = sorted((entry.name, len(list(os.scandir(entry.path)))) for entry in scanit if entry.is_dir())
        class_to_idx = {class_name: index for index, (class_name, n_members) in enumerate(class_info) if n_members}
        if not class_to_idx:
            raise FileNotFoundError(f'No non-empty classes found in \'{directory}\'.')
        return list(class_to_idx), class_to_idx
COMPOSITES.update({'upsample_epsilon_gamma_box': UpsampledEpsilonGammaBox})
model = Unet(num_channels_in, num_channels_out)
model = model.load_from_checkpoint('results/weight/unet_ci-unet-epoch=28-val_loss=101.73.ckpt',
                                   hparams_file='results/log/unet/version_0/hparams.yaml')
attributor_name = 'gradient'
composite_name = 'upsample_epsilon_gamma_box'
model_name = 'unet'
batch_size = 16
shuffle = False
relevance_norm = 'symmetric'
cmap = 'coldnhot'
level = 1.0
seed = 21
cpu = True
# create a composite if composite_name was set, otherwise we do not use a composite
composite = None
if composite_name is not None:
    composite_kwargs = {}
    if composite_name == 'upsample_epsilon_gamma_box':
        # the maximal input shape, needed for the ZBox rule
        shape = (batch_size, 64, 64)
        # the highest and lowest pixel values for the ZBox rule
        composite_kwargs['low'] = torch.zeros(*shape, device=device)
        composite_kwargs['high'] = torch.ones(*shape, device=device)

    # use torchvision specific canonizers, as supplied in the MODELS dict
    composite_kwargs['canonizers'] = [MODELS[model_name][1]()]

    # create a composite specified by a name; the COMPOSITES dict includes all preset composites provided by zennit.
    composite = COMPOSITES[composite_name](**composite_kwargs)

# specify some attributor-specific arguments
attributor_kwargs = {
    'smoothgrad': {'noise_level': 0.1, 'n_iter': 20},
    'integrads': {'n_iter': 20},
    'occlusion': {'window': (56, 56), 'stride': (28, 28)},
}.get(attributor_name, {})

# create an attributor, given the ATTRIBUTORS dict given above. If composite is None, the gradient will not be
# modified for the attribution
attributor = ATTRIBUTORS[attributor_name](model, composite, **attributor_kwargs)

# the current sample index for creating file names
sample_index = 0

# the accuracy
accuracy = 0.

# enter the attributor context outside the data loader loop, such that its canonizers and hooks do not need to be
# registered and removed for each step. This registers the composite (and applies the canonizer) to the model
# within the with-statement
with attributor:
    for data, target in valid_loader:
        # we use data without the normalization applied for visualization, and with the normalization applied as
        # the model input
        output_relevance = torch.zeros_like(torch.squeeze(target))

        # this will compute the modified gradient of model, with the on
        output, relevance = attributor(data.cuda(), output_relevance.cuda())

        # sum over the color channel for visualization
        relevance = np.array(relevance.sum(1).detach().cpu())

        # normalize between 0. and 1. given the specified strategy
        if relevance_norm == 'symmetric':
            # 0-aligned symmetric relevance, negative and positive can be compared, the original 0. becomes 0.5
            amax = np.abs(relevance).max((1, 2), keepdims=True)
            relevance = (relevance + amax) / 2 / amax
        elif relevance_norm == 'absolute':
            # 0-aligned absolute relevance, only the amplitude of relevance matters, the original 0. becomes 0.
            relevance = np.abs(relevance)
            relevance /= relevance.max((1, 2), keepdims=True)
        elif relevance_norm == 'unaligned':
            # do not align, the original minimum value becomes 0., the original maximum becomes 1.
            rmin = relevance.min((1, 2), keepdims=True)
            rmax = relevance.max((1, 2), keepdims=True)
            relevance = (relevance - rmin) / (rmax - rmin)

        for n in range(len(data)):
            fname = relevance_format.format(sample=sample_index + n)
            # zennit.image.imsave will create an appropriate heatmap given a cmap specification
            imsave(fname, relevance[n], vmin=0., vmax=1., level=level, cmap=cmap)
            if input_format is not None:
                fname = input_format.format(sample=sample_index + n)
                # if there are 3 color channels, imsave will not create a heatmap, but instead save the image with
                # its appropriate colors
                imsave(fname, data[n])
        sample_index += len(data)

        # update the accuracy
        accuracy += (output.argmax(1) == target).sum().detach().cpu().item()

accuracy /= len(dataset)
print(f'Accuracy: {accuracy:.2f}')
There may be a problem when passing tuples instead of tensors between layers, is this the case for you?
Are you saying that I should supply the tensor sizes for each layer of the UNet model as tuples to the attributor?
@maxdreyer yeah, but I'm sure output_relevance has the same shape as output. I think the error may have been raised by an intermediate layer, such as Upsample?
Are you saying that I should supply the tensor sizes for each layer of the UNet model as tuples to the attributor?
No, I was only referring to whether the layers in your UNet produce single outputs, or whether there are multiple ones.
Anyway, beyond the sanity check that torch.zeros_like(torch.squeeze(target)) and output must have the same shape, I can only guess at the problem without knowing the precise code of your UNet and the shape of your dataset. Try running pdb and check whether all the shapes are correct. The problem you are having seems to be related to the batch dimension getting consumed somewhere.
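If pdb is inconvenient, one way to narrow this down is to print every module's input and output shapes with temporary forward hooks; a minimal sketch, where model and data stand for your own model and a single batch:

import torch

def report_shapes(model, data):
    # attach forward hooks that print each module's input/output shapes, then remove them again
    handles = []
    for name, module in model.named_modules():
        def hook(module, inputs, output, name=name):
            in_shapes = [tuple(t.shape) for t in inputs if isinstance(t, torch.Tensor)]
            out_shape = tuple(output.shape) if isinstance(output, torch.Tensor) else type(output)
            print(f'{name}: {in_shapes} -> {out_shape}')
        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(data)
    finally:
        for handle in handles:
            handle.remove()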
@chr5tphr The UNet model has a single output with shape (batch_size, 1, width, height), where width and height are both 64. All of the code and data have now been uploaded to Colab, including the UNet model and the test data. I have tried to check it according to your suggestions. Could you help me check it? Thanks a lot!
@chr5tphr @maxdreyer You are both right! The UNet model expects a 5-D input and returns a 3-D output, which caused the problem above. I have fixed it by reshaping both the input and the output to 4-D, but the results are strange (as shown below): the relevance of the input channels does not seem to be consistent with the model output. Could this be influenced by the BatchNorm layer, or by anything else?
Do you only propagate zeros back? You should try it the way @maxdreyer suggested, but since you only have a single class, I would try to simply use the full output (or the ground truth), i.e.
output_relevance = target.to(device)
# or
output_relevance = model(data.to(device))
Otherwise, if this still produces unexpected results, maybe try epsilon_alpha2_beta1_flat or epsilon_plus and see how those behave.
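In the feed_forward.py-style setup above, switching composites is just a matter of changing the name; a small sketch reusing the COMPOSITES and MODELS dicts from your code:

composite_name = 'epsilon_alpha2_beta1_flat'  # or 'epsilon_plus'
# these presets do not need low/high tensors, only the canonizers
composite = COMPOSITES[composite_name](canonizers=[MODELS[model_name][1]()])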
@chr5tphr That result is already with what @maxdreyer suggested. I will try the epsilon_alpha2_beta1_flat or epsilon_plus rules. Thanks a lot!
@maxdreyer May I ask which rules you are using? Is the result what you expected?
@bugsuse using the epsilon_plus_flat rule, I received a heatmap that was similar to the output mask.
@maxdreyer Thanks so much! I will test it!
I tried it using the epsilon_plus_flat rule, but the result is strange. The relevance always seems to be concentrated at the center and does not correspond to the input. What could be the cause of this?
I do not really know anything about the data, but this could potentially mean the prediction for the positive mask values is very local. Alternatively, you could try to set all pixels of output_relevance to 1 for the target channel.
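That would look roughly like this, with class_index standing for the target output channel:

output_relevance = torch.zeros_like(output)
# propagate a uniform relevance of 1 from every pixel of the target channel
output_relevance[:, class_index] = 1.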
Hi @bugsuse, I see that you use the SequentialMergeBatchNorm canonizer. As far as I understand, this canonizer is only used for BatchNorm layers. Does your UNet implementation contain torch.add layers?
@chr5tphr for torch.cat we do not need a canonizer, right?
Best
@rachtibat Have a look at the supplied code in Colab; the UNet implementation includes BatchNorms, and the SequentialMergeBatchNorm is applied correctly (the BatchNorms in DoubleConv are assigned in the same order they are called, so they will be detected correctly).
No, concatenation does not need a canonizer.
But I just noticed that in the Down module
class Down(nn.Module):
    '''Downscaling with conv(stride = 2) then double conv'''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1),
            DoubleConv(in_channels, out_channels)
        )
there are two subsequent linear modules (three with the BatchNorm at the beginning of DoubleConv), which technically would need to be merged to obtain a canonical form independent of the implementation, similarly to the BatchNorm. However, I am not sure whether this would have a significant effect. Another alternative would be to use the epsilon rule in those layers; see the sketch below. @maxdreyer Can you check your UNet architecture and see whether you have something similar?
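One way to assign the Epsilon rule only to those specific downsampling convolutions is a name-based composite; a rough sketch using NameMapComposite, where the module names are made up and need to match your model's named_modules() (note that a plain name map leaves all other layers at the unmodified gradient):

from zennit.composites import NameMapComposite
from zennit.rules import Epsilon

# hypothetical names of the strided downsampling convolutions; check model.named_modules()
downsample_convs = ['down1.pool_conv.0', 'down2.pool_conv.0', 'down3.pool_conv.0', 'down4.pool_conv.0']

composite = NameMapComposite(
    name_map=[(downsample_convs, Epsilon())],
    canonizers=[UNetCanonizer()],
)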
@chr5tphr I'd like to predict future semantic segmentation from multiple past frames of multi-channel satellite observations.
I tested setting all pixels of output_relevance to 1 for the target channel, but the result shows no significant difference.
Here is the complete code, the pretrained weights, and the test data. I hope it helps!
@chr5tphr You are right, that might be a problem! In my implementation, an nn.MaxPool2d(2) is used instead of nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1).
@bugsuse thus, another option would be to replace the Conv2d layer with a MaxPool2d layer and retrain. If that is too much work, one could of course also merge the two subsequent linear layers or try different rules, as @chr5tphr has proposed.
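A sketch of what the replaced Down module could look like; DoubleConv is the block from the existing UNet implementation, and this change requires retraining:

import torch.nn as nn

class Down(nn.Module):
    '''Downscaling with maxpool then double conv'''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # replaces the strided Conv2d downsampling
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.pool_conv(x)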
@maxdreyer Thanks a lot! I will test it!
@bugsuse Hi, maybe you can try out the new pull request #45 with git fetch origin pull/45/head:YOUR_BRANCH_NAME. This produces much better heatmaps with LRP.
Version 0.3.3 is live on PyPI; you can just update to check.
@chr5tphr @rachtibat Cool! I will try it!