captum
How to use LayerGradCam for a multi-label application?
I'm working with a multi-label image dataset. Each input has shape torch.Size([3, 224, 224]), and each target is a 33-element multi-hot tensor (one entry per class, several of which can be 1), as in the following example:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
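The example target has three active entries (classes 8, 9 and 21), so it is multi-hot rather than one-hot, and each active entry is a separate class that can be explained on its own. Below is a plain-Python sketch of pulling out those class indices. `positive_classes` is a hypothetical helper; with a PyTorch tensor, the equivalent would be `torch.nonzero(target).flatten().tolist()`.

```python
# Multi-hot target from the example above: 33 classes, with classes 8, 9 and 21 active.
target = [0.0] * 33
for i in (8, 9, 21):
    target[i] = 1.0


def positive_classes(multi_hot):
    """Return the indices of the entries equal to 1 in a multi-hot vector."""
    return [i for i, v in enumerate(multi_hot) if v == 1.0]


print(positive_classes(target))  # [8, 9, 21]
```
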
I'm trying to use LayerGradCam as follows:
from captum.attr import LayerGradCam

layer_gradcam = LayerGradCam(linear_classifier, linear_classifier.sigm)
for batch_inp, batch_target in test_loader:
    batch_inp = batch_inp.cuda(non_blocking=True)
    batch_target = batch_target.cuda(non_blocking=True)
    for inp, target in zip(batch_inp, batch_target):
        attributions_lgc = layer_gradcam.attribute(inputs=inp, target=target)
and I keep getting the following error:
RuntimeError: mat1 dim 1 must match mat2 dim 0
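This message comes from a matrix multiplication, typically inside an `nn.Linear`: the input's feature dimension (mat1 dim 1) differs from the layer's `in_features` (mat2 dim 0 of the transposed weight). The sketch below mimics that check in plain Python to show one way the mismatch can arise, namely a missing batch dimension. It assumes a hypothetical head built for flattened 3x224x224 images, which may not match `linear_classifier`. If the classifier was trained on backbone features instead of raw images, the size it expects will be different.

```python
def linear_output_shape(mat1_shape, in_features, out_features):
    """Mimic the shape check a torch.nn.Linear forward performs on a 2-D input.

    mat1 is the (batch, features) input and mat2 is the transposed weight with
    shape (in_features, out_features); the matmul needs mat1 dim 1 == mat2 dim 0.
    """
    batch, features = mat1_shape
    if features != in_features:
        raise RuntimeError(
            f"mat1 dim 1 must match mat2 dim 0 (got {features} and {in_features})"
        )
    return (batch, out_features)


def flatten_from_dim1(shape):
    """Shape after torch.flatten(x, start_dim=1): keep dim 0, multiply the rest."""
    rest = 1
    for d in shape[1:]:
        rest *= d
    return (shape[0], rest)


IN_FEATURES = 3 * 224 * 224  # hypothetical head that expects flattened RGB images

# Without a batch dimension, the 3 channels are treated as 3 separate examples.
try:
    linear_output_shape(flatten_from_dim1((3, 224, 224)), IN_FEATURES, 33)
except RuntimeError as exc:
    print(exc)

# After inp.unsqueeze(0), the shapes line up.
print(linear_output_shape(flatten_from_dim1((1, 3, 224, 224)), IN_FEATURES, 33))  # (1, 33)
```
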
I'm currently trying to implement the solution proposed in https://github.com/pytorch/captum/issues/171, but I get a Python error when trying to create the tuple:
import torch
from captum.attr import LayerGradCam

layer_gradcam = LayerGradCam(linear_classifier, linear_classifier.sigm)
for batch_inp, batch_target in test_loader:
    batch_inp = batch_inp.cuda(non_blocking=True)
    batch_target = batch_target.cuda(non_blocking=True)
    for inp, target in zip(batch_inp, batch_target):
        targets_idx = torch.nonzero(target)
        inp = inp.unsqueeze(0)
        for idx in targets_idx:
            target = tuple(target.cpu().numpy(), idx)
            attributions_lgc = layer_gradcam.attribute(inputs=inp, target=target)
The error says that a tuple cannot be created from 2 arguments, only 1.
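Python's `tuple()` constructor takes a single iterable, so `tuple(a, b)` raises `TypeError`; a literal such as `(a, b)` builds a two-element tuple instead. According to Captum's documentation, tuple targets are only needed when the model output has more than two dimensions; for a `(batch, 33)` output, a plain int class index should be enough. Here is a minimal sketch of the difference:

```python
# tuple() accepts at most one argument: an iterable to copy.
try:
    tuple([0.0, 1.0], 8)
    two_arg_ok = True
except TypeError as exc:
    two_arg_ok = False
    print("TypeError:", exc)

# Tuple literals build tuples from separate values (trailing comma for one element).
pair = ([0.0, 1.0], 8)
single = (8,)

# For a (batch, 33) output, an int class index per call should suffice.
# If idx is a 1-element tensor (e.g. from iterating torch.nonzero), idx.item() yields the int.
class_index = 8
print(pair, single, class_index)
```
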
I also tried passing the index to attribute():
attributions_lgc = layer_gradcam.attribute(inputs=inp, target=idx)
In this case, inp has shape torch.Size([1, 3, 224, 224]) and the target has shape torch.Size([1]), but I get the same error as with the first snippet.
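Instead of making one `attribute()` call per image and class, you could repeat each image once per positive class and pair it with a list of int targets, one per row. Captum accepts this form of target for 2-D outputs. Below is a plain-Python sketch of building those lists. `expand_multilabel` is a hypothetical helper. The commented Captum call shows assumed usage and also presumes that `linear_classifier` accepts these inputs in a forward pass.

```python
def expand_multilabel(batch_targets):
    """Pair each example with each of its positive classes.

    Returns (rows, classes): rows[k] is the index of the example to repeat and
    classes[k] is the class to explain for that copy.
    """
    rows, classes = [], []
    for i, multi_hot in enumerate(batch_targets):
        for c, v in enumerate(multi_hot):
            if v == 1:
                rows.append(i)
                classes.append(c)
    return rows, classes


# Two hypothetical 5-class examples: the first has classes 1 and 3, the second class 4.
batch_targets = [[0, 1, 0, 1, 0], [0, 0, 0, 0, 1]]
rows, classes = expand_multilabel(batch_targets)
print(rows, classes)  # [0, 0, 1] [1, 3, 4]

# Assumed PyTorch/Captum usage (not executed here):
#   inputs = batch_inp[torch.tensor(rows)]
#   attributions = layer_gradcam.attribute(inputs, target=classes)
```
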
Any suggestions on how to solve this?
Thanks!
Hi @amandalucasp, thank you for the question! Do you get an error if you do a forward pass, i.e. linear_classifier(inp)?
Hi @99warriors! I don't get any errors when doing a forward pass; I was able to train my classifier successfully.