
Error using DeepLIFT with 1 CNN channel dimension

berkuva opened this issue 3 years ago • 2 comments

My CNN model takes input of size (batch, 1, 40000). When I pass data of that size as the inputs argument of DeepLift(model).attribute(), I get TypeError: chunk(): argument 'input' (position 1) must be Tensor, not tuple. I think this is caused by the singleton channel dimension in the middle. How can I get around this?

berkuva avatar Nov 15 '22 05:11 berkuva

@berkuva could you paste the stack trace of the error and the code for context?

aobo-y avatar Nov 16 '22 19:11 aobo-y

@aobo-y Thanks for helping.

Yes, here's a snippet.

ConvNet(
  (layer1): Sequential(
    (0): Conv1d(1, 20, kernel_size=(4,), stride=(2,))
    (1): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
)

model = ConvNet()

print(baselines1.shape, data.reshape(-1, 1, 40000).shape)
torch.Size([1, 1, 40000]) torch.Size([108, 1, 40000])

from captum.attr import DeepLift
dl = DeepLift(model)

dl.attribute(data.reshape(-1, 1, 40000), baselines=baselines1, target=1)

Error trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-58-5cbf0376db1b> in <module>
----> 1 dl.attribute(data.reshape(-1, 1, 40000), baselines=baselines1, target=1)

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
     33             @wraps(func)
     34             def wrapper(*args, **kwargs):
---> 35                 return func(*args, **kwargs)
     36 
     37             return wrapper

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, custom_attribution_func)
    347                 additional_forward_args,
    348             )
--> 349             gradients = self.gradient_func(wrapped_forward_func, inputs)
    350             if custom_attribution_func is None:
    351                 if self.multiplies_by_inputs:

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
    116     with torch.autograd.set_grad_enabled(True):
    117         # runs forward pass
--> 118         outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
    119         assert outputs[0].numel() == 1, (
    120             "Target not provided when necessary, cannot"

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    426     forward_func_args = signature(forward_func).parameters
    427     if len(forward_func_args) == 0:
--> 428         output = forward_func()
    429         return output if target is None else _select_targets(output, target)
    430 

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in forward_fn()
    387         def forward_fn():
    388             model_out = _run_forward(
--> 389                 forward_func, inputs, None, additional_forward_args
    390             )
    391             return _select_targets(

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    437         *(*inputs, *additional_forward_args)
    438         if additional_forward_args is not None
--> 439         else inputs
    440     )
    441     return _select_targets(output, target)

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1121         if _global_forward_hooks or self._forward_hooks:
   1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1123                 hook_result = hook(self, input, result)
   1124                 if hook_result is not None:
   1125                     result = hook_result

~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in forward_hook(module, inputs, outputs)
    566 
    567         def forward_hook(module: Module, inputs: Tuple, outputs: Tensor):
--> 568             return torch.stack(torch.chunk(outputs, 2), dim=1)
    569 
    570         if isinstance(

TypeError: chunk(): argument 'input' (position 1) must be Tensor, not tuple
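
[Editor's note] The failing frame is Captum's internal forward_hook calling torch.chunk on the hooked outputs. torch.chunk accepts only a single Tensor, so the same TypeError appears whenever the hooked module's output is a tuple of Tensors rather than one Tensor. A minimal plain-PyTorch sketch of that distinction (no Captum involved):

```python
import torch

t = torch.arange(8.0)

# torch.chunk works on a Tensor...
halves = torch.chunk(t, 2)
print(len(halves), halves[0].shape)  # 2 torch.Size([4])

# ...but raises the exact TypeError from the trace when handed a tuple,
# which is what Captum's forward_hook sees if a module returns a tuple.
try:
    torch.chunk((t,), 2)
except TypeError as e:
    print(e)
```

This suggests checking whether the model's forward (or one of its submodules) returns a tuple instead of a single Tensor.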

berkuva avatar Nov 16 '22 19:11 berkuva