Issues with BERT-type model
I am working with the LayoutLMv2 model in Hugging Face Transformers (https://huggingface.co/transformers/model_doc/layoutlm.html). A plain forward pass works fine, but I get a dimensionality error related to the embeddings when I try to use the model in Captum for explainability. Note that LayoutLM (the first version of the model) causes no issues in the same context. I also realize that this model needs to be fine-tuned; this is just meant as a proof-of-concept.
Here is my code:
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv2FeatureExtractor, LayoutLMv2TokenizerFast, LayoutLMv2Processor, LayoutLMv2ForSequenceClassification
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
from captum.attr import LayerIntegratedGradients, TokenReferenceBase
import torch
import torchvision
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_rgb = Image.open("IMAGE.jpg").convert("RGB")
processor = LayoutLMv2Processor.from_pretrained('microsoft/layoutlmv2-base-uncased')
model = LayoutLMv2ForSequenceClassification.from_pretrained('microsoft/layoutlmv2-base-uncased')
tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
encoding = processor(image_rgb, return_tensors="pt")
input_ids = encoding['input_ids']
token_type_ids = encoding['token_type_ids']
attention_mask = encoding['attention_mask']
bbox = encoding['bbox']
model_layered = ModelInputWrapper(model)
outputs = model_layered(**encoding)
pred, answer_idx = F.softmax(outputs.logits, dim=1).data.cpu().max(dim=1)
def batch_predict(input_ids, image, bbox, attention_mask, token_type_ids):
    model_layered.eval()
    outputs = model_layered(input_ids=input_ids,
                            image=image,
                            bbox=bbox,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
    logits = outputs.logits
    probs = F.softmax(logits, dim=1)
    return probs
attr = LayerIntegratedGradients(batch_predict,
                                [model_layered.module.layoutlmv2.embeddings.word_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.x_position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.y_position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.h_position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.w_position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.token_type_embeddings])
# Generate reference for tokens
token_reference = TokenReferenceBase(reference_token_idx=tokenizer.pad_token_id)
text_reference_indices = token_reference.generate_reference(len(encoding['input_ids'][0]), device=device).unsqueeze(0)
baselines = text_reference_indices
attributions = attr.attribute(inputs=encoding['input_ids'],
                              additional_forward_args=(encoding['image'],
                                                       encoding['bbox'],
                                                       encoding['attention_mask'],
                                                       encoding['token_type_ids']),
                              baselines=baselines,
                              target=answer_idx,
                              n_steps=5)
And the error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-00bdf4f97de3> in <module>()
60 baselines=baselines,
61 target=answer_idx,
---> 62 n_steps=5)
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
33 @wraps(func)
34 def wrapper(*args, **kwargs):
---> 35 return func(*args, **kwargs)
36
37 return wrapper
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/layer/layer_integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta, attribute_to_layer_input)
496 method=method,
497 internal_batch_size=internal_batch_size,
--> 498 return_convergence_delta=False,
499 )
500
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
290 additional_forward_args=additional_forward_args,
291 n_steps=n_steps,
--> 292 method=method,
293 )
294
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/integrated_gradients.py in _attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
353 inputs=scaled_features_tpl,
354 target_ind=expanded_target,
--> 355 additional_forward_args=input_additional_args,
356 )
357
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/layer/layer_integrated_gradients.py in gradient_func(forward_fn, inputs, target_ind, additional_forward_args)
464
465 output = _run_forward(
--> 466 self.forward_func, tuple(), target_ind, additional_forward_args
467 )
468 finally:
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
451 *(*inputs, *additional_forward_args)
452 if additional_forward_args is not None
--> 453 else inputs
454 )
455 return _select_targets(output, target)
<ipython-input-1-00bdf4f97de3> in batch_predict(input_ids, image, bbox, attention_mask, token_type_ids)
34 bbox=bbox,
35 attention_mask=attention_mask,
---> 36 token_type_ids=token_type_ids)
37 logits = outputs.logits
38 probs = F.softmax(logits, dim=1)
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_utils/input_layer_wrapper.py in forward(self, *args, **kwargs)
74 kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name])
75
---> 76 return self.module(*tuple(args), **kwargs)
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py in forward(self, input_ids, bbox, image, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)
1053 output_attentions=output_attentions,
1054 output_hidden_states=output_hidden_states,
-> 1055 return_dict=return_dict,
1056 )
1057 if input_ids is not None:
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py in forward(self, input_ids, bbox, image, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
893 token_type_ids=token_type_ids,
894 position_ids=position_ids,
--> 895 inputs_embeds=inputs_embeds,
896 )
897
/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py in _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds)
754 token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
755
--> 756 embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings
757 embeddings = self.embeddings.LayerNorm(embeddings)
758 embeddings = self.embeddings.dropout(embeddings)
RuntimeError: The size of tensor a (44) must match the size of tensor b (49) at non-singleton dimension 1
These are the versions of the packages I am using:
transformers==4.11.2
captum==0.4.0
torch==1.7.0
torchvision==0.8.1
Hi @nataliebarcickikas - it seems like the error occurs during the forward call to batch_predict. What happens if you call batch_predict directly, without using LayerIntegratedGradients?
Directly calling batch_predict causes no issues:
batch_predict(encoding['input_ids'],
              encoding['image'],
              encoding['bbox'],
              encoding['attention_mask'],
              encoding['token_type_ids'])
Output:
tensor([[0.4553, 0.5447]], grad_fn=<SoftmaxBackward>)
Could you print out the dimensions of all the embeddings in that line 756, both for the direct forward call to batch_predict and for the attr call?
I printed the dimensions of the summed embeddings along with their four individual components:
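(For reference, one way to capture these shapes without editing modeling_layoutlmv2.py is to register forward hooks on the embedding submodules; the helper below is just illustrative, not necessarily the exact code I used.)
emb = model_layered.module.layoutlmv2.embeddings

def shape_hook(label):
    # Prints the output shape of a hooked submodule on every forward call.
    def hook(module, inputs, output):
        print(label, tuple(output.shape))
    return hook

handles = [
    emb.word_embeddings.register_forward_hook(shape_hook("Input embeddings:")),
    emb.position_embeddings.register_forward_hook(shape_hook("Position embeddings:")),
    emb.token_type_embeddings.register_forward_hook(shape_hook("Token type embeddings:")),
    # The spatial term is the sum of the x/y/h/w position embeddings and can be
    # hooked on those submodules in the same way.
]
# ... run batch_predict(...) or attr.attribute(...), then remove the hooks:
for h in handles:
    h.remove()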
In the forward call:
batch_predict(encoding['input_ids'],
              encoding['image'],
              encoding['bbox'],
              encoding['attention_mask'],
              encoding['token_type_ids'])
Input embeddings: torch.Size([1, 44, 768])
Position embeddings: torch.Size([1, 44, 768])
Spatial position embeddings: torch.Size([1, 44, 768])
Token type embeddings: torch.Size([1, 44, 768])
torch.Size([1, 44, 768])
tensor([[0.4942, 0.5058]], grad_fn=<SoftmaxBackward>)
In the attr call:
attributions = attr.attribute(inputs=encoding['input_ids'],
                              additional_forward_args=(encoding['image'],
                                                       encoding['bbox'],
                                                       encoding['attention_mask'],
                                                       encoding['token_type_ids']),
                              baselines=baselines,
                              target=answer_idx,
                              n_steps=1)
Input embeddings: torch.Size([1, 44, 768])
Position embeddings: torch.Size([1, 44, 768])
Spatial position embeddings: torch.Size([1, 44, 768])
Token type embeddings: torch.Size([1, 44, 768])
torch.Size([1, 44, 768])
Input embeddings: torch.Size([1, 44, 768])
Position embeddings: torch.Size([1, 44, 768])
Spatial position embeddings: torch.Size([1, 44, 768])
Token type embeddings: torch.Size([1, 44, 768])
torch.Size([1, 44, 768])
Input embeddings: torch.Size([1, 44, 768])
Position embeddings: torch.Size([1, 49, 768])
Spatial position embeddings: torch.Size([1, 49, 768])
Token type embeddings: torch.Size([1, 44, 768])
And then the runtime error occurs:
RuntimeError: The size of tensor a (44) must match the size of tensor b (49) at non-singleton dimension 1
(I left this comment originally on #904)
I have this exact same issue with a custom BERT-based model, and I've traced it back to the Captum hook being called in line 1072 of the source code for torch.nn.Module (i.e. during the forward call of your model). The hook that causes this issue is layer_integrated_gradients.layer_forward_hook. It appears that the cached value in scattered_inputs_dict is being returned no matter what, because the hook is called at the start of the wrapper module's forward method and is not reset mid-call if weights are shared.
# num_current_tokens = 50, num_prev_tokens = 71
input_ids.shape # torch.Size([1, 50])
self.word_embeddings(input_ids).shape # torch.Size([1, 71, 768]) instead of torch.Size([1, 50, 768])
For context, my model shares weights and we call the same model twice within forward():
def _forward(...):
    ...
    outputs_text = self.bert(input_ids=input_ids_text, attention_mask=attention_mask_text, **kwargs)
    outputs_context = self.bert(input_ids=input_ids_context, attention_mask=attention_mask_context, **kwargs)
    ...
    return outputs
@99warriors do you guys have a time estimate? If not, I'm happy to fork and work from a starting point.
@chrisdoyleIE Thank you for investigating this. We have had discussions about how to fix this problem (perhaps expanding scattered_inputs_dict to cache the results of multiple forward calls of the same module), but this discussion is still ongoing. Any suggestions you have would be most helpful / welcome!
For now, as a work-around, we would recommend avoiding calling the same module multiple times within the same forward pass (and instead creating copies of the module, which can all share weights); adding a warning related to this is an immediate task we can tackle.
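For illustration, here is a minimal sketch of that work-around for the two-branch case above. The helper and class names are hypothetical (not part of Captum), and the weight tying relies on the private _parameters dict as a common idiom:
import copy
import torch.nn as nn

def make_weight_sharing_copy(module: nn.Module) -> nn.Module:
    # Deep-copy the module, then re-point every parameter of the copy at the
    # original Parameter objects: both module objects share weights, but each
    # one gets its own Captum forward hook.
    clone = copy.deepcopy(module)
    originals = dict(module.named_modules())
    for name, sub in clone.named_modules():
        for pname in list(sub._parameters):
            sub._parameters[pname] = originals[name]._parameters[pname]
    return clone

class TwoBranchModel(nn.Module):
    # Hypothetical wrapper: instead of calling self.bert twice inside forward(),
    # keep two weight-sharing module objects and call each exactly once.
    def __init__(self, bert: nn.Module):
        super().__init__()
        self.bert_text = bert
        self.bert_context = make_weight_sharing_copy(bert)

    def forward(self, input_ids_text, attention_mask_text,
                input_ids_context, attention_mask_context):
        outputs_text = self.bert_text(input_ids=input_ids_text,
                                      attention_mask=attention_mask_text)
        outputs_context = self.bert_context(input_ids=input_ids_context,
                                            attention_mask=attention_mask_context)
        return outputs_text, outputs_context
Each hooked module is then called exactly once per forward pass, which avoids the cached-input mix-up described above.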
That'll do, many thanks!
Is there any update regarding this topic?
I have a similar issue: I call the embedding function in my model multiple times to split an input into several chunks (it's a hierarchical model), where I have to get [CLS]/[PAD]/[SEP] embeddings in between:
...
sep_embed = self.bert.embeddings(torch.tensor([[4]], dtype=torch.long, device=self.device))[0][0]
pad_embed = self.bert.embeddings(torch.tensor([[0]], dtype=torch.long, device=self.device))[0][0]
cls_embed = self.bert.embeddings(torch.tensor([[3]], dtype=torch.long, device=self.device))[0][0]
...
Then, of course, I get wrong shapes from scattered_inputs_dict (or saved_layer?).