Laplace
Help for Running Laplace on Image Segmentation Tasks
Hello,
I am using a U-Net variant (specifically: https://github.com/juntang-zhuang/LadderNet) to perform segmentation of hands. To be specific, I am classifying each pixel of an image into one of five classes (no hand, my right hand, my left hand, your right hand, your left hand).
This requires my probs tensor (fisher.py, line 446) to have shape [batch_size, img_h, img_w, n_classes], i.e. [8, 32, 32, 5] (snippet below):
def __fisher_exact(loss_and_backward, model, probs):
    _, n_classes = probs.shape
Because of this dimensionality, this line of code fails: it unpacks probs.shape into exactly two values, but my shape has four. I assume it's because the code expects data in the format of the CIFAR example (https://github.com/AlexImmer/Laplace/blob/main/examples/calibration_example.py), where the CIFAR dataset object returns tuples of (img_as_tensor: [3,32,32], label_as_int: 0, 1, 2, 3, etc.), one per example.
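The unpacking failure can be reproduced without the library. The following is a minimal sketch using only a plain tuple as a stand-in for probs.shape; the variable names are illustrative and not taken from fisher.py.

```python
# Stand-in for probs.shape in the segmentation case: [batch_size, img_h, img_w, n_classes].
probs_shape = (8, 32, 32, 5)

try:
    # Same pattern as `_, n_classes = probs.shape` in the snippet above.
    _, n_classes = probs_shape
    unpack_error = None
except ValueError as err:
    # Four values cannot be unpacked into two names.
    unpack_error = str(err)

print(unpack_error)

# Reading only the last axis works for any number of leading dimensions.
n_classes = probs_shape[-1]
print(n_classes)
```

This only shows why the line raises. It does not imply that changing this one line would make the rest of the Fisher computation handle per-pixel outputs.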
I can always reshape my training data to be a tuple of that format, but because I am classifying by pixel, my associated label would not be a single integer. It would have to be in the same shape as the image tensor [32,32].
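One way to picture the reshaping option is to treat every pixel as its own classification example. The sketch below uses plain PyTorch with random stand-in tensors in the shapes from this post. It is an assumption about how the data could be laid out, and it has not been checked against laplace-torch or asdfghjkl.

```python
import torch

# Shapes from the post; the tensors themselves are random stand-ins.
batch_size, img_h, img_w, n_classes = 8, 32, 32, 5
logits = torch.randn(batch_size, img_h, img_w, n_classes)         # per-pixel class scores
labels = torch.randint(0, n_classes, (batch_size, img_h, img_w))  # per-pixel integer labels

# Treat every pixel as one classification example.
flat_logits = logits.reshape(-1, n_classes)  # [8 * 32 * 32, 5]
flat_labels = labels.reshape(-1)             # [8 * 32 * 32]

# The two-value unpacking pattern now works on the flattened probabilities.
flat_probs = torch.softmax(flat_logits, dim=-1)
_, n_classes_seen = flat_probs.shape

# The standard cross-entropy loss accepts this layout directly.
loss = torch.nn.functional.cross_entropy(flat_logits, flat_labels)
print(flat_logits.shape, flat_labels.shape, n_classes_seen)
```

Flattening treats the pixels of one image as independent examples, so the effective number of data points becomes batch_size * img_h * img_w. Whether that is acceptable depends on the model and on how the library counts data points.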
So I'm asking this question to the community to see if anyone has attempted this kind of segmentation task using the laplace-torch package before I try to force a solution within the asdfghjkl/fisher.py file.
Update: I have gotten the following lines of code to run by modifying my LadderNet model to follow this architecture:
LadderNetv6(
  (initial_block): Initial_LadderBlock(
    (inconv): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_module_list): ModuleList(
      (0): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (1): BasicBlock(
        (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (2): BasicBlock(
        (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (3): BasicBlock(
        (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
    )
    (down_conv_list): ModuleList(
      (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): Conv2d(20, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (bottom): BasicBlock(
      (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU()
      (drop): Dropout2d(p=0.25, inplace=False)
    )
    (up_conv_list): ModuleList(
      (0): ConvTranspose2d(160, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ConvTranspose2d(80, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (2): ConvTranspose2d(40, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (3): ConvTranspose2d(20, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    )
    (up_dense_list): ModuleList(
      (0): BasicBlock(
        (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (1): BasicBlock(
        (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (2): BasicBlock(
        (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (3): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
    )
  )
  (final_block): Final_LadderBlock(
    (block): LadderBlock(
      (inconv): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (down_module_list): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (1): BasicBlock(
          (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (2): BasicBlock(
          (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (3): BasicBlock(
          (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
      )
      (down_conv_list): ModuleList(
        (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): Conv2d(20, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (2): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (bottom): BasicBlock(
        (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (up_conv_list): ModuleList(
        (0): ConvTranspose2d(160, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (1): ConvTranspose2d(80, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (2): ConvTranspose2d(40, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (3): ConvTranspose2d(20, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      )
      (up_dense_list): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (1): BasicBlock(
          (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (2): BasicBlock(
          (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (3): BasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
      )
    )
  )
  (final_fc): Final_Layer(
    (layer): Linear(in_features=10, out_features=5, bias=False)
  )
)
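# Imports assumed from the laplace-torch package layout.
from laplace import Laplace
from laplace.curvature import AsdlGGN

la = Laplace(net, 'classification', subset_of_weights='last_layer', hessian_structure='kron', backend=AsdlGGN)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')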
However, although I can train my model, run predictions with it, and run .fit() for the post-hoc Laplace approximation, I cannot run the following code without an error:
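import torch

@torch.no_grad()
def predict(dataloader, model, laplace=False):
    """
    This code was taken from calibration_example.py.
    """
    py = []
    for x, _ in dataloader:
        if laplace:
            # A fitted Laplace object returns predictive probabilities directly.
            py.append(model(x.cuda()))
        else:
            # A plain network returns logits, so apply softmax.
            py.append(torch.softmax(model(x.cuda()), dim=-1))
    return torch.cat(py).cpu()  # .cpu is a method and must be called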
probs_laplace = predict(test_loader, la, laplace=True) # this line fails
The following is the trace when running the predict() code:
RuntimeError Traceback (most recent call last)
Input In [12], in <cell line: 65>()
50 # # TODO: specify val_loader
51 # # From API docs page
52 # # post-hoc update:
(...)
61
62 # From GitHub CIFAR example:
63 la.optimize_prior_precision(method='marglik') #, val_loader=test_loader_copy)
---> 65 probs_laplace = predict(test_loader_copy, la, laplace=True) # in future, replace w/test set: test_loader
67 acc_laplace = (probs_laplace.argmax(-1) == targets).float().mean()
69 # ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
70
71 # nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()
72
73
74 # print(f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')')
File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
Input In [12], in predict(dataloader, model, laplace)
12 print(x.shape)
13 if laplace:
---> 14 py.append(model(x.cuda()))
15 else:
16 py.append(torch.softmax(model(x.cuda()), dim=-1))
File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/baselaplace.py:536, in ParametricLaplace.__call__(self, x, pred_type, link_approx, n_samples)
533 raise ValueError(f'Unsupported link approximation {link_approx}.')
535 if pred_type == 'glm':
--> 536 f_mu, f_var = self._glm_predictive_distribution(x)
537 # regression
538 if self.likelihood == 'regression':
File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/lllaplace.py:124, in LLLaplace._glm_predictive_distribution(self, X)
122 print(Js.shape)
123 print(f_mu.shape)
--> 124 f_var = self.functional_variance(Js)
125 print('shape of f_var, which is variance(Js)')
126 print(f_var.shape)
File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/baselaplace.py:841, in KronLaplace.functional_variance(self, Js)
840 def functional_variance(self, Js):
--> 841 return self.posterior_precision.inv_square_form(Js)
File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/utils/matrix.py:411, in KronDecomposed.inv_square_form(self, W)
409 print('from laplace/utils inv_square_form')
410 print(W.shape)
--> 411 SW = self._bmm(W, exponent=-1)
412 return torch.bmm(W, SW.transpose(1, 2))
File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/utils/matrix.py:404, in KronDecomposed._bmm(self, W, exponent)
402 print('length of SW')
403 print(len(SW))
--> 404 SW = torch.cat(SW, dim=1).reshape(B, K, P)
405 return SW
RuntimeError: shape '[1024, 32, 320]' is invalid for input of size 1638400
I'm somewhat at a loss because I did not expect prediction to fail when the .fit() function ran without error. Any help would be greatly appreciated.
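The numbers in the RuntimeError can be checked by hand. The sketch below only redoes the arithmetic from the message. Reading the factors as "last-layer weights" or "number of classes" is a guess based on the printed architecture (Linear(in_features=10, out_features=5, bias=False)), not something the trace confirms.

```python
import math

# Numbers copied from the RuntimeError message in the trace above.
requested_shape = (1024, 32, 320)  # (B, K, P) passed to reshape
actual_numel = 1_638_400           # elements in the concatenated tensor

requested_numel = math.prod(requested_shape)
print(requested_numel)             # 10485760

# Elements available per leading index versus what reshape needs (K * P).
B, K, P = requested_shape
per_item_actual = actual_numel // B
per_item_needed = K * P
print(per_item_actual, per_item_needed)  # 1600 10240

# Two factorizations of 1600 that match numbers in the post (interpretation is a guess):
# 32 * 50, where 50 = 10 * 5 weights of the bias-free final Linear layer, or
# 5 * 320, where 5 is the number of classes.
assert per_item_actual == 32 * (10 * 5) == 5 * 320
```

Either reading suggests that the spatial dimensions of the per-pixel output end up in an axis where the code expects the number of outputs or parameters, which would be consistent with .fit() running while prediction fails at the reshape.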
Hi @SouLeo, multi-output models are indeed still in our backlog. For now, I think this paper (https://arxiv.org/abs/2206.15078), along with its code (https://github.com/FrederikWarburg/LaplaceAE), could be very useful for you.
Oh, I see. So this framework is not suited for multiclass labels for a single image?
I'll review the items you have linked. Thank you very much!
I am still somewhat confused that I was able to perform the following lines of code without error:
la = Laplace(net, 'classification', subset_of_weights='last_layer', hessian_structure='kron', backend=AsdlGGN)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')
but cannot run the model prediction. Do you have any thoughts on this?
@SouLeo Were you able to use this library successfully for image segmentation?