Extend backpack to deal with weighted sums
Hi,
I would like to apply the DiagGGNExact extension to a custom module that does not have any parameters. I attached a minimal example that reproduces the scenario. The model produces a set of weights with which to combine (via a weighted sum) the downstream predictions. However, when extending the nn.Module in charge of computing this weighted sum, I get an error because this module would also need to be extended for second-order operations.
Can anyone help me with extending this simple module?
Here is a minimal code example of what I'm trying to achieve:
import torch
from backpack import backpack, extend, extensions

torch.manual_seed(0)

class SumModule(torch.nn.Module):
    def forward(self, x, w):
        return torch.sum(x * w, dim=-2)

B = 5  # batch size
S = 10  # number of intermediate elements
I = 16  # input size
O = 4  # output size

x = torch.randn((B, S, I))  # some simple inputs
y = torch.ones((B, O)) * torch.arange(B)[..., None]
# some outputs; note the shape does not depend on the number of intermediate elements S

# base_model produces both the weight and an intermediate embedding for each input
base_model = torch.nn.Sequential(
    torch.nn.Linear(I, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 16),
)
# head_model produces the final prediction from the intermediate embedding
head_model = torch.nn.Sequential(
    torch.nn.Linear(15, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, O),
)

base_model = extend(base_model)
head_model = extend(head_model)
sum_module = extend(SumModule())
loss_fun = torch.nn.MSELoss()

with backpack(extensions.DiagGGNExact()):
    x = x.reshape(-1, I)
    intermediate_output = base_model(x)
    # split the base_model output into weights w and intermediate embedding h
    w, h = torch.split(
        intermediate_output,
        [1, 15], dim=-1,
    )
    w = w.reshape(B, S, 1)  # weights
    pred = head_model(h).reshape(B, S, O)  # predictions
    pred_y = sum_module(pred, w)
    loss = loss_fun(y, pred_y)
    loss.backward()
Here is the error I get when running this script:
NotImplementedError: Extension saving to diag_ggn_exact does not have an extension for Module <class '__main__.SumModule'>
Hi, thanks for your question!
Could you let me know w.r.t. which parameters you would like to compute the GGN diagonal? Is it only the head_model's parameters, or also the base_model's?
Both head_model and base_model parameters.
In that case you will have to write a module extension of DiagGGNExact for your SumModule, as well as a module that performs the torch.split. This is because the GGN diagonal is a second-order extension, meaning that your computation graph must consist entirely of nn.Modules for all of which BackPACK knows how to backpropagate the information for the GGN diagonal.
There is work on documenting how to write module extensions for new layers in #320, but I haven't had time to review and merge it yet. You could take a look at it and start from there. The steps would be
- Write a SplitModule layer such that w, h = split_module(intermediate_output) (a rough sketch follows below)
- Write a DiagGGNExactSplitModule extension which specifies how to backpropagate information through a SplitModule when computing the GGN diagonal (see #320)
- Repeat step 2, but for your SumModule.
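For the first step, a minimal sketch of what such a layer could look like (the name SplitModule and its interface are illustrative assumptions, not existing BackPACK API; the module extension from step 2 would still have to be written for it):
import torch

class SplitModule(torch.nn.Module):
    """Sketch: wraps torch.split in an nn.Module so BackPACK's hooks can see it."""

    def __init__(self, split_sizes, dim=-1):
        super().__init__()
        self.split_sizes = split_sizes
        self.dim = dim

    def forward(self, x):
        # returns one tensor per entry of split_sizes
        return torch.split(x, self.split_sizes, dim=self.dim)

# usage, matching the forward pass above:
# w, h = SplitModule([1, 15])(intermediate_output)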
Given that my SumModule has to deal with two inputs, not just a single one, should I also fork from the multiple-inputs branch referred to in issue #306?
Yes, that sounds right.
After some more thinking, I believe the split can be done using BackPACK's custom Slicing module, which already supports DiagGGNExact. This should fix one problem and leave you with step 3.
I have updated the previous example with my implementation of the weighted SumModule. I'm not entirely sure it is correct, in particular because I found it a bit confusing how MSELossDerivatives expands the decomposition of the Hessian of the loss. Is there any particular reason it is done this way instead of backpropagating a [B, C, C] matrix?
Also, as you suggested, I tried to use the Slicing module to perform the torch.split operation, however I got this error:
ValueError: Slicing the batch axis is not supported.
I think I may be using it in the wrong way, but I haven't found anything in the documentation, and from looking into the code it seems I should specify the whole shape (batch axis included) as the slice_info for the module. Do you have any suggestions on how to fix it?
Here is the updated version of the previous example, which is failing:
from typing import Any, List, Tuple

import einops
import torch
from torch import Tensor, nn

from backpack import BackpropExtension, backpack, extend, extensions
from backpack.custom_module.slicing import Slicing
from backpack.extensions.module_extension import ModuleExtension

torch.manual_seed(0)

class SumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w):
        return torch.sum(x * w, dim=-2)

class DiagGGNSumModule(ModuleExtension):
    def backpropagate(
        self,
        extension: BackpropExtension,
        module: nn.Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        bpQuantities: Any,
    ) -> Any:
        inputs = self.get_inputs(module)
        x = inputs[0]
        w = inputs[1]
        sqrt_ggn = bpQuantities
        # J_w = x.T
        JwTS = torch.einsum("bic, cbk -> kbi", x, sqrt_ggn)
        JwTS = einops.rearrange(JwTS, "k b s -> k (b s) 1")
        # J_x = [w .. w]
        JxT = einops.repeat(w, "b i 1 -> b i c1 c2", c1=x.shape[-1], c2=x.shape[-1])
        JxTS = torch.einsum("bsjc, cbk -> kbsj", JxT, sqrt_ggn)
        JxTS = einops.rearrange(JxTS, "k b s j -> k (b s) j")
        return tuple([JxTS, JwTS])

    @staticmethod
    def get_inputs(module: nn.Module) -> List[Tensor]:
        """Get all inputs of ``SumModule``'s forward pass."""
        layer_inputs = []
        i = 0
        while hasattr(module, f"input{i}"):
            layer_inputs.append(getattr(module, f"input{i}"))
            i += 1
        return layer_inputs

B = 5  # batch size
S = 10  # number of intermediate elements
I = 16  # input size
O = 4  # output size

x = torch.randn((B, S, I))  # some simple inputs
y = torch.ones((B, O)) * torch.arange(B)[..., None]
# some outputs; note the shape does not depend on the number of intermediate elements S

# base_model produces both the weight and an intermediate embedding for each input
base_model = torch.nn.Sequential(
    torch.nn.Linear(I, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 16),
)
# head_model produces the final prediction from the intermediate embedding
head_model = torch.nn.Sequential(
    torch.nn.Linear(15, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, O),
)

base_model = extend(base_model)
head_model = extend(head_model)
sum_module = extend(SumModule())
slice_0 = extend(Slicing((slice(0, B * S), 1)))
slice_1_16 = extend(Slicing((slice(0, B * S), slice(1, 16))))

ext = extensions.DiagGGNExact()
ext.set_module_extension(SumModule, DiagGGNSumModule())
loss_fun = extend(torch.nn.MSELoss())

with backpack(ext, debug=True):
    x = x.reshape(-1, I)
    intermediate_output = base_model(x)
    # split the base_model output into weights w and intermediate embedding h
    # (previously: w, h = torch.split(intermediate_output, [1, 15], dim=-1))
    w = slice_0(intermediate_output)
    h = slice_1_16(intermediate_output)
    pred = head_model(h)  # predictions
    w = w.reshape(B, S, 1)  # weights
    pred = pred.reshape(B, S, O)
    pred_y = sum_module(pred, w)
    loss = loss_fun(y, pred_y)
    loss.backward()
Hi, thanks for the update.
- Regarding the format of the backpropagated quantity for DiagGGN: the shape is [C, N, *], where the first axis indexes the columns of the loss function's Hessian and N is the batch size. This is why the quantity backpropagated by MSELoss is [C, N, C], and not [N, C, C] (this also helps internally, as the first C-dimensional axis can be treated as a vmap-ed one).
- Regarding the slicing module, I believe the current slicing_info is incorrect because in your initial code you are splitting along dim=-1, which is not the batch axis. You can check the tests of Slicing to see how the syntax for slicing_info works. For a 2d input, I believe something like
slicing_info_0 = (slice(None), slice(0, 1))
slicing_info_1_16 = (slice(None), slice(1, 16))
should work.
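For intuition about the [C, N, C] format, here is a small self-contained sketch (an illustration, assuming MSELoss(reduction="mean") on an output of shape [N, C]) of such a symmetric factorization of the loss Hessian:
import torch

# For MSELoss(reduction="mean") on predictions of shape [N, C], the per-sample
# Hessian w.r.t. the prediction is 2 / (N * C) * I_C, so one symmetric square
# root is sqrt(2 / (N * C)) * I_C. Stacked in the [C, N, C] convention:
N, C = 5, 4
sqrt_h = (2.0 / (N * C)) ** 0.5 * torch.eye(C)  # [C, C]
sqrt_ggn = sqrt_h.unsqueeze(1).expand(C, N, C)  # first axis indexes Hessian columns
# contracting over the first axis recovers the per-sample Hessian
hessian = torch.einsum("cnk,cnl->nkl", sqrt_ggn, sqrt_ggn)
assert torch.allclose(hessian, 2.0 / (N * C) * torch.eye(C).expand(N, C, C))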
Thanks for the answer!
I used the slicing_info_* you suggested; however, it only runs without crashing when I do the slicing in this order
h = slice_1_16(intermediate_output)
w = slice_0(intermediate_output)
while if I swap these two lines I get a shape-mismatch error.
I believe the issue comes from how the multiple-inputs branch from #306 handles the saved bpQuantities. Do you know if this could be the reason, and how to fix it in that case?
Hi, I don't really understand why the order of slicing should matter. Do you get the error in the forward pass?
No, I get the error during the backpropagation step
Could you post a minimal example that reproduces your error and append the traceback?
Of course!
from typing import Any, List, Tuple

import einops
import torch
from torch import Tensor, nn

from backpack import BackpropExtension, backpack, extend, extensions
from backpack.custom_module.slicing import Slicing
from backpack.extensions.module_extension import ModuleExtension

torch.manual_seed(0)

class SumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w):
        return torch.sum(x * w, dim=-2)

class DiagGGNSumModule(ModuleExtension):
    def backpropagate(
        self,
        extension: BackpropExtension,
        module: nn.Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        bpQuantities: Any,
    ) -> Any:
        inputs = self.get_inputs(module)
        x = inputs[0]
        w = inputs[1]
        sqrt_ggn = bpQuantities
        # J_w = x.T
        JwTS = torch.einsum("bic, cbk -> kbi", x, sqrt_ggn)
        JwTS = einops.rearrange(JwTS, "k b s -> k (b s) 1")
        # J_x = [w .. w]
        JxT = einops.repeat(w, "b i 1 -> b i c1 c2", c1=x.shape[-1], c2=x.shape[-1])
        JxTS = torch.einsum("bsjc, cbk -> kbsj", JxT, sqrt_ggn)
        JxTS = einops.rearrange(JxTS, "k b s j -> k (b s) j")
        return tuple([JxTS, JwTS])

    @staticmethod
    def get_inputs(module: nn.Module) -> List[Tensor]:
        """Get all inputs of ``SumModule``'s forward pass."""
        layer_inputs = []
        i = 0
        while hasattr(module, f"input{i}"):
            layer_inputs.append(getattr(module, f"input{i}"))
            i += 1
        return layer_inputs

B = 5  # batch size
S = 10  # number of intermediate elements
I = 2  # input size
O = 2  # output size

x = torch.randn((B, S, I))  # some simple inputs
y = torch.ones((B, O)) * torch.arange(B)[..., None]
# some outputs; note the shape does not depend on the number of intermediate elements S

# base_model produces both the weight and an intermediate embedding for each input
base_model = torch.nn.Sequential(
    torch.nn.Linear(I, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 16),
)
# head_model produces the final prediction from the intermediate embedding
head_model = torch.nn.Sequential(
    torch.nn.Linear(15, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, O),
)

base_model = extend(base_model)
head_model = extend(head_model)
sum_module = extend(SumModule())
slice_0 = extend(Slicing((slice(None), slice(0, 1))))
slice_1_16 = extend(Slicing((slice(None), slice(1, 16))))

ext = extensions.DiagGGNExact()
ext.set_module_extension(SumModule, DiagGGNSumModule())
loss_fun = extend(torch.nn.MSELoss())

with backpack(ext, debug=True):
    x = x.reshape(-1, I)
    intermediate_output = base_model(x)
    # split the base_model output into weights w and intermediate embedding h
    w = slice_0(intermediate_output)
    h = slice_1_16(intermediate_output)
    pred = head_model(h)  # predictions
    w = w.reshape(B, S, 1)  # weights
    pred = pred.reshape(B, S, O)
    pred_y = sum_module(pred, w)
    loss = loss_fun(y, pred_y)
    loss.backward()
And here is the error message:
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on MSELoss()
[DEBUG] Running extension hook on MSELoss()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on SumModule()
[DEBUG] Running extension hook on SumModule()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Linear(in_features=64, out_features=2, bias=True)
[DEBUG] Running extension hook on Linear(in_features=64, out_features=2, bias=True)
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on ReLU()
[DEBUG] Running extension hook on ReLU()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Linear(in_features=15, out_features=64, bias=True)
[DEBUG] Running extension hook on Linear(in_features=15, out_features=64, bias=True)
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Slicing()
[DEBUG] Running extension hook on Slicing()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Slicing()
Traceback (most recent call last):
File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/home/projects/scratch/minimal_example.py", line 111, in <module>
loss.backward()
File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/site-packages/torch/utils/hooks.py", line 137, in hook
out = hook(self.module, res, self.grad_outputs)
File "/home/projects/scratch/backpack/backpack/__init__.py", line 209, in hook_run_extensions
backpack_extension(module, g_inp, g_out)
File "/home/projects/scratch/backpack/backpack/extensions/backprop_extension.py", line 131, in __call__
module_extension(self, module, g_inp, g_out)
File "/home/projects/scratch/backpack/backpack/extensions/module_extension.py", line 125, in __call__
bp_quantity = self.backpropagate(
File "/home/projects/scratch/backpack/backpack/extensions/mat_to_mat_jac_base.py", line 55, in backpropagate
return self.derivatives.jac_t_mat_prod(
File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 133, in _wrapped_mat_prod_accept_vectors
mat_out = mat_prod(self, module, g_inp, g_out, mat_in, *args, **kwargs)
File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 189, in wrapped_mat_prod_check_shapes
in_check(mat, module, *args, **kwargs)
File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 80, in _check_like
return check_shape(mat, compare, diff=diff)
File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 49, in check_shape
raise RuntimeError(
RuntimeError: ('Compared shapes [50, 16] and [50, 1] do not match. ', 'Got [2, 50, 16] and [50, 1]')
One thing you definitely have to fix is the current reshapes in the forward pass. These reshapes are functionals, but BackPACK's second-order extensions require all operations of the forward pass to be performed through nn.Modules. Otherwise, the backpropagation mechanism will break. I believe what you're currently seeing is that one of BackPACK's internal shape checks fails because you are modifying tensors with functionals, which are 'invisible' to BackPACK.
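To make this concrete, a minimal sketch (an illustration, not an existing BackPACK module) of wrapping the functional reshape in an nn.Module; a DiagGGNExact module extension would still have to be written and registered for it:
import torch

class ReshapeModule(torch.nn.Module):
    """Sketch: makes the reshape visible to BackPACK's module hooks."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(*self.shape)

# e.g., pred = pred.reshape(B, S, O) would become
# reshape_pred = extend(ReshapeModule(B, S, O)); pred = reshape_pred(pred)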
I propose simplifying your current example, because you are trying to solve two different problems:
- Adding support for your custom SumModule, which has multiple inputs
- Slicing the same tensor twice (accumulating backpropagated quantities)
Maybe you can start with the following simpler scenario, which does not suffer from aspect 2:
linear1 = Linear(...)
linear2 = Linear(...)
linear3 = Linear(...)
X1, X2 = rand(...), rand(...)
# (...) extend both
# both should already have correct shapes (no reshape)
w = linear1(X1)
h = linear2(X2)
pred = linear3(h)
pred_y = sum_module(pred, w)
loss = loss_fun(y, pred_y)
loss.backward()
Then check whether you get the correct GGN diagonals for the parameters of linear1, linear2, and linear3.
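As a rough sketch of that check (assuming MSELoss(reduction="mean"); the helper ggn_diag_mse is illustrative, not part of BackPACK), one can compute the GGN diagonal by brute force with autograd and compare it to the diag_ggn_exact attribute BackPACK stores on each parameter:
import torch

# Brute-force GGN diagonal under MSELoss(reduction="mean"): the loss Hessian
# w.r.t. an output with n_elems entries is 2 / n_elems * I, so
# diag(GGN) w.r.t. a parameter p is 2 / n_elems * sum_i (d out_i / d p)**2.
def ggn_diag_mse(output, params):
    n_elems = output.numel()
    diags = [torch.zeros_like(p) for p in params]
    for out_elem in output.flatten():
        grads = torch.autograd.grad(out_elem, params, retain_graph=True)
        for diag, grad in zip(diags, grads):
            diag += grad ** 2
    return [2.0 / n_elems * diag for diag in diags]

# comparison against BackPACK, e.g.:
# params = [p for lin in (linear1, linear2, linear3) for p in lin.parameters()]
# for p, diag in zip(params, ggn_diag_mse(pred_y, params)):
#     assert torch.allclose(p.diag_ggn_exact, diag, atol=1e-6)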