zennit

Core: Second Order Gradients

Open · chr5tphr opened this issue 2 years ago · 4 comments

I am currently working on supporting second-order gradients, i.e. gradients of the modified gradients, which are used, for example, to compute adversarial explanations. The issue currently preventing second-order gradients is that the gradient modification introduced by the rules is also applied during the second-order backward pass. This will be prevented by temporarily disabling the modification while computing the second-order gradient, likely using something like a no_modification context for composites/attributors/rules.
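To make the problem concrete, here is a minimal plain-PyTorch sketch (not zennit's implementation) of a toy "rule" that doubles the gradient at an intermediate tensor via `Tensor.register_hook`. The global flag and the `no_modification` context manager are hypothetical stand-ins for the proposed context. For y = (3x)^2 the modified gradient is 36x, so its true derivative is 36; if the hook stays active, it fires again when the second-order pass reaches the hooked tensor and doubles that value.

```python
import contextlib

import torch

# Hypothetical global switch standing in for the proposed no_modification context;
# this is not zennit's API.
_MODIFY = True


@contextlib.contextmanager
def no_modification():
    """Temporarily disable the toy gradient modification."""
    global _MODIFY
    previous, _MODIFY = _MODIFY, False
    try:
        yield
    finally:
        _MODIFY = previous


def toy_rule(grad):
    # toy "rule": double the incoming gradient while modification is active
    return grad * 2. if _MODIFY else grad


def modified_and_second_order(x_value, disable_second_pass=True):
    x = torch.tensor(x_value, requires_grad=True)
    h = x * 3.
    h.register_hook(toy_rule)  # runs whenever a backward pass reaches h
    y = h ** 2
    # first-order pass with the modification active: 3 * 2 * (2 * h) = 36 * x
    relevance, = torch.autograd.grad(y, x, create_graph=True)
    context = no_modification() if disable_second_pass else contextlib.nullcontext()
    with context:
        # derivative of the modified gradient with respect to x
        second, = torch.autograd.grad(relevance, x)
    return relevance.item(), second.item()


print(modified_and_second_order(1.))
print(modified_and_second_order(1., disable_second_pass=False))
```

With the hook disabled only for the second pass, the first-order result keeps the modification while its derivative is left untouched.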

As also pointed out in #125, handles are not stored for the tensor backward hooks. Storing the hooks and removing them before the second-order backward pass would correctly compute the derivatives of the modified gradient, but then the same graph could no longer be used to compute the modified gradient for a different gradient output. Alongside the context, I am also considering enabling the complete removal of the tensor backward hooks.
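`Tensor.register_hook` returns a `torch.utils.hooks.RemovableHandle`, so storing that handle is enough to remove a tensor hook later. The toy sketch below (plain PyTorch with a hypothetical doubling "rule" on y = (3x)^2, not zennit's code) removes the hook right after the first-order pass: the second-order result is then the true derivative of the modified gradient, but re-using the retained graph for a different gradient output now yields the unmodified gradient, which is the trade-off described above.

```python
import torch


def remove_hook_demo(x_value):
    x = torch.tensor(x_value, requires_grad=True)
    h = x * 3.
    # toy "rule" doubling the gradient at h; keep the handle returned by register_hook
    handle = h.register_hook(lambda grad: grad * 2.)
    y = h ** 2
    # modified first-order gradient: 3 * 2 * (2 * h) = 36 * x; create_graph retains the graph
    relevance, = torch.autograd.grad(y, x, create_graph=True)
    # remove the hook so it no longer fires in later backward passes
    handle.remove()
    # second-order gradient of the modified gradient, left untouched by the rule
    second, = torch.autograd.grad(relevance, x, retain_graph=True)
    # re-using the graph with a different gradient output no longer applies the rule:
    # the unmodified gradient is grad_out * 3 * (2 * h) = grad_out * 18 * x
    again, = torch.autograd.grad(y, x, grad_outputs=torch.tensor(0.5))
    return relevance.item(), second.item(), again.item()


print(remove_hook_demo(1.))
```

With the hook still registered, the last call would have returned twice the value, i.e. the modified gradient.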

chr5tphr, Jun 03 '22 15:06

How far has the work progressed? This is exactly what I need for my research. It would be awesome if zennit had this functionality, or at least a workaround for further research and testing.

Edit: I tried and failed to get the gradient of an attribution map (relevances).

HeinrichAD, Aug 24 '22 08:08

Hey @HeinrichAD

thanks for the bump!

I have pushed a draft version in #159 and will try to get around to finishing it up in the following days. I have not fully tested it yet, but if you are feeling brave, you can try it out:

$ pip install git+https://github.com/chr5tphr/zennit.git@second-order-gradients

Here's an example of how to compute the gradient of (a function of) the attribution with respect to the input:

Code Example
import os

import torch
from torchvision.models import vgg11
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
from PIL import Image

from zennit.composites import EpsilonGammaBox
from zennit.image import imgify


fname = 'dornbusch-lighthouse.jpg'

if not os.path.exists(fname):
    torch.hub.download_url_to_file(
        'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_181_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',
        fname,
    )

# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the normalization transform
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    transform_norm,
])

# load the image
image = Image.open(fname)

# transform the PIL image and insert a batch-dimension
data = transform(image)[None]

# evaluation mode disables dropout, keeping the attribution deterministic
model = vgg11(weights='DEFAULT').eval()
composite = EpsilonGammaBox(low=-3., high=3.)

input = data.clone().requires_grad_(True)
target = torch.eye(1000)[[437]]
with composite.context(model) as modified_model:
    out = modified_model(input)
    relevance, = torch.autograd.grad(out, input, target, create_graph=True)
    # create a target heatmap, rolled 12 pixels south east
    target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
    loss = ((relevance - target_heat) ** 2).mean()
    # deactivate the rule hooks in order to leave the second order gradient untouched
    with composite.inactive():
        adv_grad, = torch.autograd.grad(loss, input)

imgify(relevance[0].detach().sum(0), cmap='coldnhot', symmetric=True).show()
imgify(target_heat[0].sum(0), cmap='coldnhot', symmetric=True).show()
imgify(adv_grad[0].sum(0), cmap='coldnhot', symmetric=True).show()

[Images: relevance heatmap, target heatmap, and adversarial gradient]

chr5tphr, Aug 24 '22 12:08

Thank you for the fast response and for your commitment. I will test it, maybe not today, but tomorrow. I will let you know whether it works.

HeinrichAD, Aug 24 '22 15:08

I added my comments/feedback to the pull request itself.

HeinrichAD, Aug 25 '22 08:08