zennit
Draft: Core: Second Order Gradients
- support computing the gradient of the modified gradient via second order gradients
- for this, the second gradient pass must be done without hooks
TODO
- finalize
- write tests
- write docs
Implements #142 and fixes #125
Hi @chr5tphr
I will provide my feedback here instead of the issue itself.
I didn't check the attribution maps in detail, but your code example (here) seems to work fine. (My examples below are based on it.)
There are some points I would like to mention:
- What if we do not use composites?
- Why not add the possibility to activate/deactivate the rule hooks in general? (Sometimes you need to do many other things before you calculate the second order gradient, and you do not want to have to keep this in mind the whole time.)
- Why does `composite.inactive` only work inside the original `composite.context`?
Code Example 1
```python
import torch
from zennit.composites import EpsilonGammaBox

canonizers = None
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)

def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    with composite.context(model) as modified_model:
        outputs = modified_model(input)
        relevance, = torch.autograd.grad(outputs, input, target, create_graph=True)
    return outputs, relevance

outputs, relevance = explain_LRP(model, input, target)

# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# deactivate the rule hooks in order to leave the second order gradient untouched
# version 1
with composite.inactive():
    adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True

# version 2
with composite.context(model):
    with composite.inactive():
        adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True
```
- Attributors are currently not supported?
Code Example 2
```python
from zennit.attribution import Gradient

canonizers = None
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)

with Gradient(model=model, composite=composite) as attributor:
    outputs, relevance = attributor(inputs, torch.eye(1000)[targets])

# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# deactivate the rule hooks in order to leave the second order gradient untouched
with attributor.composite.inactive():
    adv_grad, = torch.autograd.grad(loss, inputs)  # <<-- Error
    # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
- What do you think should be the default behaviour? (Maybe the most important question 😄) Would it not be more plausible if the default were inactive, and only active once you enter an attributor or composite context? In my opinion, the following should be possible by default:
Code Example 3
```python
def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    canonizers = None
    composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
    with Gradient(model=model, composite=composite) as attributor:
        outputs, attributions = attributor(input, target)
    return outputs, attributions

outputs, relevance = explain_LRP(model, input, target)

# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# now the gradient calculation should be possible by default without any further deactivation etc.
adv_grad, = torch.autograd.grad(loss, input)  # <<-- this should work by default
```
Code Example 4
```python
def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    canonizers = None
    composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
    with composite.context(model) as modified_model:
        outputs = modified_model(input)
        relevance, = torch.autograd.grad(outputs, input, target, create_graph=True)
    return outputs, relevance

outputs, relevance = explain_LRP(model, input, target)

# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# now the gradient calculation should be possible by default without any further deactivation etc.
adv_grad, = torch.autograd.grad(loss, input)  # <<-- this should work by default
```
Edit: To be honest I expected something like this to be the default:
Code Example 5
```python
def explain_LRP(model: torch.nn.Module, input: torch.Tensor, target: torch.Tensor):
    canonizers = None
    composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
    with composite.context(model) as modified_model:
        outputs = modified_model(input)
        relevance, = torch.autograd.grad(outputs, input, target, create_graph=True)
        for hook in composite.hook_refs:
            hook.active = False
    return outputs, relevance

outputs, relevance = explain_LRP(model, input, target)

# create a target heatmap, rolled 12 pixels south east
target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
loss = ((relevance - target_heat) ** 2).mean()

# now the gradient calculation should be possible by default without any further deactivation etc.
adv_grad, = torch.autograd.grad(loss, input)  # <<-- OK
```
Hey @HeinrichAD
thanks a lot for your feedback!
For clarification: hooks are only ever active after `composite.register` and before `composite.remove` has been called, which happens at the beginning and the end of the `with composite.context(...)` block, respectively. The new `composite.inactive` context adds a new attribute `hook.active`, which is only used while the hook exists. It is `True` by default, because its sole purpose is to temporarily deactivate a hook that is still alive; this is necessary to compute the second order gradient, with which the hooks would normally interfere.
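A minimal sketch of this life-cycle (assuming a torchvision VGG16 as the model; `composite.inactive` is the context introduced in this PR):

```python
import torch
from torchvision.models import vgg16
from zennit.composites import EpsilonGammaBox

model = vgg16()
composite = EpsilonGammaBox(low=-3., high=3.)

composite.register(model)    # hooks are created and attached; hook.active defaults to True
# ... gradients computed here are the modified (rule-based) gradients ...
with composite.inactive():   # temporarily sets hook.active = False on all live hooks
    pass                     # ... gradients computed here bypass the rule hooks ...
# ... hooks are active again here ...
composite.remove()           # hooks are destroyed; they no longer exist after this point
```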
- What if we do not use composites?
- If you are not going to use the hook again, you can simply remove it. Otherwise, you can manually set `hook.active = False` after creating the hook, and set it back to `True` after you are done, as in the sketch below. Since you normally use more than a single hook (one per layer), I think this is more convenient than using contexts for all hooks. Bundled hooks, after all, are supposed to be managed by composites.
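A minimal sketch of this single-hook usage (the `Epsilon` rule and the VGG-style `model.features[0]` layer are just placeholders; `hook.active` is the attribute added in this PR):

```python
from zennit.rules import Epsilon

hook = Epsilon()
handles = hook.register(model.features[0])  # attach the rule hook to a single layer
# ... compute the modified gradient with create_graph=True here ...
hook.active = False   # temporarily bypass the rule
# ... compute the second order gradient here ...
hook.active = True    # re-enable the rule for further passes
handles.remove()      # or simply remove the hook once it is no longer needed
```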
- Why not add the possibility to activate/deactivate the rule hooks in general? (Sometimes you need to do many other things before you calculate the second order gradient, and you do not want to have to keep this in mind the whole time.)
- While a global flag to deactivate all hooks would certainly be possible, I am not too fond of the idea. The main goal of this approach seems to me to be using Hooks without Composites. I would like to encourage the usage of Composites, and if they do not cover some use-case well, I would rather adapt them better to that use-case. I was thinking of a by-reference Composite before.
- Why does `composite.inactive` only work inside the original `composite.context`?
- Outside the `composite.context`, the hooks do not exist; they are created at the beginning of the context and destroyed at the end.
Code Example 1
```python
# deactivate the rule hooks in order to leave the second order gradient untouched
# version 1
with composite.inactive():
    adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True
# version 2
with composite.context(model):
    with composite.inactive():
        adv_grad, = torch.autograd.grad(loss, input)  # <<-- error because `hook.active` is still True
```
- Here, the composite is not registered anymore and the hooks do not exist anymore (intended behavior); re-entering the context will only create new hooks, so this example will not work. Code Example 5 is supposed to work; however, there is a bug in the code: I store the references to the (torch-)hooks registered to the tensors via weak references to their respective tensors. I assumed that, as long as the backward graph exists, the tensor would have a reference count > 1 and not be garbage-collected. However, this is not the case, and there does not seem to be a reliable way to keep a weak reference to the handles. I did this to avoid an infinitely growing list of handles when doing multiple backward passes in a single composite context, for example on a server that endlessly computes attributions. The current code simply does not clean them up correctly, thus your confusion. A small sketch of the weak-reference pitfall follows below.
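For illustration, a minimal sketch of the pitfall (the exact collection behavior can depend on the op and the PyTorch version, so treat this as an assumption-laden demonstration rather than a guarantee):

```python
import weakref
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2              # intermediate tensor
z = y.sum()            # z's graph references y's grad_fn, not the tensor y itself
ref = weakref.ref(y)
del y                  # drop the last Python reference
print(ref() is None)   # True here: y is collected although z's backward graph lives on
z.backward()           # the backward pass itself still works
```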
- Attributors are currently not supported?
Code Example 2
- This is an oversight on my part. I will add a way to specify graph creation in gradient-based attributors.
- What do you think should be the default behaviour? (Maybe the most important question 😄) Would it not be more plausible if the default were inactive, and only active once you enter an attributor or composite context? In my opinion, the following should be possible by default:
Code Example 3
Code Example 4
- Since the hooks do not exist before entering a composite context (i.e. before creating and registering them), they cannot be active. Their default behaviour, upon creation, is intended to be active (within their lifetime). `composite.inactive` is supposed to temporarily turn them off in order to compute second order derivatives without destroying the hooks, as in the sketch below. I agree the example should work.
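Concretely, the intended pattern is something like this sketch (reusing the names from your Code Example 1):

```python
with composite.context(model) as modified_model:
    outputs = modified_model(input)
    relevance, = torch.autograd.grad(outputs, input, target, create_graph=True)
    target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
    loss = ((relevance - target_heat) ** 2).mean()
    # the hooks stay alive, but are bypassed for the second order gradient
    with composite.inactive():
        adv_grad, = torch.autograd.grad(loss, input)
```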
Edit: To be honest I expected something like this to be the default: Code Example 5
- Again, I agree it should work like you expect. I was too focused on the case where the composite is not destroyed when computing the second order derivative.
To summarize, it should be possible to compute the second order gradient either after destroying the hooks, or while the hooks still exist, using `with composite.inactive():`, or `hook.active = False; compute_gradient(); hook.active = True` for single hooks. A bug in the code made the destroying-case impossible.
Edit:
Actually, for the intended behaviour, you do not need to loop over the hooks in Code Example 5; it should work without the loop, since you leave the context and destroy the hooks.
`composite.inactive` does exactly the loop you show, but it only makes sense if you set the hooks back to `True` later when re-using the composite, roughly as in the following sketch.
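A simplified sketch of what `composite.inactive` does (not the verbatim implementation):

```python
from contextlib import contextmanager

@contextmanager
def inactive(self):
    '''Temporarily deactivate all hooks registered by this composite.'''
    try:
        for hook in self.hook_refs:
            hook.active = False
        yield self
    finally:
        # hooks are re-activated even if an exception occurred inside the block
        for hook in self.hook_refs:
            hook.active = True
```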
Hi @chr5tphr,
It makes much more sense now. Thank you for the detailed clarification. If you think it makes sense at this point, I can start a test run; otherwise I will wait a little longer.
Hey @HeinrichAD
I'm currently working on the rest of the documentation for this, but functionality and tests are now finished (unless I find a bug or something missing).
If you would like to try it out, you can either do so now, or wait a little bit until I have also finished the documentation, at which point I will mark this PR ready and merge it in the following days.
Everything should work as expected, and as a bonus, Attributors now also have an `.inactive` function to compute second order gradients within the `with` block. There's also a new rule from Dombrowski et al. (2019) to change the gradient of ReLU to its smooth variant, in order to deal with its otherwise zero (and undefined at zero) second order gradient.
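For example, a second order gradient through an attributor could then be sketched like this (the `create_graph` argument and `.inactive` follow the description above; treat this as a sketch rather than the final API):

```python
with Gradient(model=model, composite=composite, create_graph=True) as attributor:
    outputs, relevance = attributor(input, torch.eye(1000)[target])
    target_heat = torch.roll(relevance.detach(), (12, 12), (2, 3))
    loss = ((relevance - target_heat) ** 2).mean()
    with attributor.inactive():  # bypass the rule hooks, like composite.inactive
        adv_grad, = torch.autograd.grad(loss, input)
```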
Hey @HeinrichAD
I am done with the PR and would like to merge. If you have not already, you can check whether it works for you as expected. A preview of the documentation is also available here. Otherwise I will just merge.
Hi @chr5tphr, I will try to find some time tomorrow.
Hi @chr5tphr,
First thank you for your effort!
Code
In general, the code looks good. It also works as I expected. Only one thing is confusing me: the output of your example code from issue #142 now generates a completely different output than before (for the 2nd derivative).
The same code now generates this output:
[image of the new output]
NOTE: I do not know which output is correct.
Typos
Since my IDE already points this out for me, here is a list of typos:
- CONTRIBUTING.md#50 numpy codestyle
- docs/source/index.rst#5 Propagation
- docs/source/getting-started.rst#149 instantiate
- docs/source/how-to/visualize-results.rst#444 accessed
- docs/source/how-to/write-custom-canonizer.rst#113 torch
- docs/source/tutorial/image-...ipynb section 3.2 cell 1 line 10 gamma
- src/zennit/core.py#165 lengths
- tests/test_attribution.py#91 preferred
- tests/test_attribution.py#140 SmoothGrad
- tests/test_canonizers.py#120 AttributeCanonizer
- tests/test_canonizers.py#141 whether
- shared/scripts/palette_fit.py#48 brightness
Also, sometimes it's "layer-wise relevance propagation" and sometimes "layerwise relevance propagation".
Hey @HeinrichAD
thanks a lot again for your feedback.
In general, the code looks good. It also works as I expected. Only one thing is confusing me: the output of your example code from issue #142 now generates a completely different output than before (for the 2nd derivative).
The same code now generates this output:
NOTE: I do not know which output is correct.
The first version was actually wrong. The problem was that `BasicHook`, within the backward function, did not set `create_graph=True` when using `torch.autograd.grad`, which meant that the gradient was computed with the contribution-weighting of the input handled like a constant. As a result, the second order gradient was not computed through the whole model, but just through the first layer. This is also why the gradient looked so clean: it was just a difference of the contribution in the first layer (divided by x). The difference is demonstrated below.
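A minimal demonstration of the `create_graph` difference in plain PyTorch:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# without create_graph, the returned gradient does not itself track gradients
g_const, = torch.autograd.grad(y, x, retain_graph=True)
print(g_const.requires_grad)  # False: differentiating through it would raise
                              # 'element 0 of tensors does not require grad'

# with create_graph, the gradient is itself part of the graph
g_graph, = torch.autograd.grad(y, x, create_graph=True)
second, = torch.autograd.grad(g_graph, x)
print(second)  # tensor(12.): d/dx 3x^2 = 6x, evaluated at x = 2
```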
Typos
Since my IDE already points this out for me, here is a list of typos:
- CONTRIBUTING.md#50 numpy codestyle
- docs/source/index.rst#5 Propagation
- docs/source/getting-started.rst#149 instantiate
- docs/source/how-to/visualize-results.rst#444 accessed
- docs/source/how-to/write-custom-canonizer.rst#113 torch
- docs/source/tutorial/image-...ipynb section 3.2 cell 1 line 10 gamma
- src/zennit/core.py#165 lengths
- tests/test_attribution.py#91 preferred
- tests/test_attribution.py#140 SmoothGrad
- tests/test_canonizers.py#120 AttributeCanonizer
- tests/test_canonizers.py#141 whether
- shared/scripts/palette_fit.py#48 brightness

Also, sometimes it's "layer-wise relevance propagation" and sometimes "layerwise relevance propagation".
I will add a quick follow-up PR to fix these, since many of these files were not touched in this PR, and I prefer not to touch files for typos etc. if there was no other change in that file.
Thank you for the explanation. In that case, the PR gets a ready-to-go from my side 😄.