captum
Error using DeepLift with a single-channel (1-channel) CNN input
My CNN model takes input of shape `(batch, 1, 40000)`. When I pass data of that shape as the `inputs` argument of `DeepLift(model).attribute()`, I get `TypeError: chunk(): argument 'input' (position 1) must be Tensor, not tuple`. I suspect the single channel dimension in the middle is the cause. How can I work around this?
@berkuva Could you paste the full stack trace of the error, along with the code that triggers it, for context?
@aobo-y Thanks for your help.
Sure, here is the model printout and a snippet of the code:
ConvNet(
(layer1): Sequential(
(0): Conv1d(1, 20, kernel_size=(4,), stride=(2,))
(1): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
model = ConvNet()
print(baselines1.shape, data.reshape(-1, 1, 40000).shape)
(torch.Size([1, 1, 40000]), torch.Size([108, 1, 40000]))
from captum.attr import DeepLift
dl = DeepLift(model)
dl.attribute(data.reshape(-1, 1, 40000), baselines=baselines, target=1)
Full stack trace:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-58-5cbf0376db1b> in <module>
----> 1 dl.attribute(data.reshape(-1, 1, 40000), baselines=baselines1, target=1)
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
33 @wraps(func)
34 def wrapper(*args, **kwargs):
---> 35 return func(*args, **kwargs)
36
37 return wrapper
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, custom_attribution_func)
347 additional_forward_args,
348 )
--> 349 gradients = self.gradient_func(wrapped_forward_func, inputs)
350 if custom_attribution_func is None:
351 if self.multiplies_by_inputs:
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
116 with torch.autograd.set_grad_enabled(True):
117 # runs forward pass
--> 118 outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
119 assert outputs[0].numel() == 1, (
120 "Target not provided when necessary, cannot"
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
426 forward_func_args = signature(forward_func).parameters
427 if len(forward_func_args) == 0:
--> 428 output = forward_func()
429 return output if target is None else _select_targets(output, target)
430
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in forward_fn()
387 def forward_fn():
388 model_out = _run_forward(
--> 389 forward_func, inputs, None, additional_forward_args
390 )
391 return _select_targets(
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
437 *(*inputs, *additional_forward_args)
438 if additional_forward_args is not None
--> 439 else inputs
440 )
441 return _select_targets(output, target)
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1123 hook_result = hook(self, input, result)
1124 if hook_result is not None:
1125 result = hook_result
~/opt/anaconda3/envs/py3.6/lib/python3.6/site-packages/captum/attr/_core/deep_lift.py in forward_hook(module, inputs, outputs)
566
567 def forward_hook(module: Module, inputs: Tuple, outputs: Tensor):
--> 568 return torch.stack(torch.chunk(outputs, 2), dim=1)
569
570 if isinstance(
TypeError: chunk(): argument 'input' (position 1) must be Tensor, not tuple