Making LLMAttribution work with BertForMultipleChoice models
🚀 Feature
Allow LLMAttribution goodness to be applied to BERT models for multiple choice tasks
Motivation
Following up on suggestions from aobo-y.
Pitch
Integrated-gradient attribution techniques already work with BertForMultipleChoice; it would be great if FeatureAblation / LLMAttribution did, too.
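For reference, the integrated-gradients setup that already works follows the layer-attribution recipe from the Captum BERT tutorials. A minimal sketch, reusing model and tst and the choice-forward idea from the second approach below (names are placeholders, not tested here):

import torch
from captum.attr import LayerIntegratedGradients

# Wrapper returning the log-probability of one answer choice;
# BertForMultipleChoice logits have shape [batch, num_choices].
def choice_forward(input_ids, token_type_ids=None, attention_mask=None, target=0):
    out = model(input_ids, token_type_ids=token_type_ids,
                attention_mask=attention_mask)
    return torch.log_softmax(out.logits, dim=1)[:, target]

# Attribute against the embedding layer, since integrated gradients cannot
# interpolate over discrete token ids directly.
lig = LayerIntegratedGradients(choice_forward, model.bert.embeddings)
attributions_ig = lig.attribute(
    tst['input_ids'],
    additional_forward_args=(tst['token_type_ids'], tst['attention_mask'], 0),
)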
Alternatives
Two suggestions were made
First approach:
- code
fa = FeatureAblation(model)
llm_attr = LLMAttribution(fa, tokenizer)
inp = TextTokenInput(promptTxt, tokenizer)
attributions_fa = llm_attr.attribute(
    inp,
    target=targetIdxTensor,
    additional_forward_args=dict(
        token_type_ids=tst['token_type_ids'],
        # position_ids=position_ids,
        attention_mask=tst['attention_mask'],
    )
)
-
throws error:
File ".../captumPerturb_min.py", line 160, in captumPerturbOne attributions_fa = llm_attr.attribute( ^^^^^^^^^^^^^^^^^^^ File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute cur_attr = self.attr_method.attribute( ^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: captum.attr._core.feature_ablation.FeatureAblation.attribute() got multiple values for keyword argument 'additional_forward_args'
-
Dropping the additional_forward_args parameter (evidently LLMAttribution already passes its own additional_forward_args to the inner attribute() call, hence the keyword collision) gets farther, but throws:

File ".../captumPerturb_min.py", line 160, in captumPerturbOne
    attributions_fa = llm_attr.attribute(
File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute
    cur_attr = self.attr_method.attribute(
File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper
    return func(*args, **kwargs)
File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
    initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward
    output = forward_func(
File ".../site-packages/captum/attr/_core/llm_attr.py", line 574, in _forward_func
    model_inputs = prep_inputs_for_generation(  # type: ignore
File ".../site-packages/transformers/generation/utils.py", line 376, in prepare_inputs_for_generation
    raise NotImplementedError(
NotImplementedError: A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`.
- Looking at the variables inside _forward_func:
self.model.prepare_inputs_for_generation
<bound method GenerationMixin.prepare_inputs_for_generation of BertForMultipleChoice(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSdpaSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=1, bias=True)
)>
-
model_inp: a Tensor of torch.Size([1, 112])
-
model_kwargs.keys(): dict_keys(['attention_mask', 'cache_position', 'use_cache'])
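That explains the failure: LLMAttribution's _forward_func drives the model through the text-generation machinery (prepare_inputs_for_generation / .generate()), and BertForMultipleChoice is an encoder-only classifier that only inherits GenerationMixin's stub. A quick sanity check (assuming a reasonably recent transformers version; not verified here):

# BertForMultipleChoice inherits GenerationMixin but defines no generation
# logic of its own, so generation-based code paths are expected to fail.
print(model.can_generate())  # expected: False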
Second approach:
- code
def multChoice_forward(inputs, token_type_ids=None, position_ids=None,
                       attention_mask=None, target=None):
    output = model(inputs, token_type_ids=token_type_ids,
                   position_ids=position_ids, attention_mask=attention_mask)
    log_probs = torch.log_softmax(output.logits, 1)
    # specify which choice's prob
    return log_probs[target]

fa = FeatureAblation(multChoice_forward)
attributions_fa = fa.attribute(
    tst['input_ids'],
    additional_forward_args=dict(
        token_type_ids=tst['token_type_ids'],
        attention_mask=tst['attention_mask'],
        target=targetIdxTensor
    )
)
-
throws:
File ".../captumPerturb_min.py", line 294, in main captumPerturbOne(model,tokenizer,tstDict,tstTarget) File ".../captumPerturb_min.py", line 184, in captumPerturbOne attributions_fa = fa.attribute( ^^^^^^^^^^^^^ File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute initial_eval: Union[Tensor, Future[Tensor]] = _run_forward( ^^^^^^^^^^^^^ File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward output = forward_func( ^^^^^^^^^^^^^ File ".../captumPerturb_min.py", line 175, in multChoice_forward output = model(inputs, token_type_ids=token_type_ids, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ".../site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ".../site-packages/torch/nn/modules/module.py", line 1747, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ".../site-packages/transformers/models/bert/modeling_bert.py", line 1799, in forward token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None ^^^^^^^^^^^^^^^^^^^ AttributeError: 'dict' object has no attribute 'view'
This is too far into Transformer API-land for me to follow.
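For what it's worth: Captum unpacks additional_forward_args positionally after the inputs, so the dict above reaches multChoice_forward as its single second positional argument (token_type_ids), which is exactly the AttributeError shown. Passing a tuple in the forward function's parameter order should clear this particular error (untested sketch, same names as above):

attributions_fa = fa.attribute(
    tst['input_ids'],
    additional_forward_args=(
        tst['token_type_ids'],  # token_type_ids
        None,                   # position_ids
        tst['attention_mask'],  # attention_mask
        targetIdxTensor,        # target
    ),
)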
Additional context
Additional details are in the original issue: https://github.com/pytorch/captum/issues/1523