
Making LLMAttribution work with BertForMultipleChoice models

Open rbelew opened this issue 7 months ago • 3 comments

🚀 Feature

Allow the LLMAttribution goodness to be applied to BERT models for multiple-choice tasks

Motivation

Following up on suggestions from @aobo-y.

Pitch

Integrated-gradient attribution techniques already work with BertForMultipleChoice; it would be great if FeatureAblation / LLMAttribution did, too.
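
For reference, here is a minimal sketch of the kind of integrated-gradients baseline this alludes to, using LayerIntegratedGradients over the BERT embeddings; the checkpoint, prompt, and choices are illustrative, not taken from the original report:

    from captum.attr import LayerIntegratedGradients
    from transformers import AutoTokenizer, BertForMultipleChoice

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertForMultipleChoice.from_pretrained("bert-base-uncased")
    model.eval()

    prompt = "The cat sat on the"
    choices = ["mat.", "moon."]
    # encode (prompt, choice) pairs, then add the num_choices dimension:
    # BertForMultipleChoice expects [batch, num_choices, seq_len]
    enc = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
    input_ids = enc["input_ids"].unsqueeze(0)
    attention_mask = enc["attention_mask"].unsqueeze(0)

    def mc_forward(input_ids, attention_mask):
        return model(input_ids=input_ids, attention_mask=attention_mask).logits

    lig = LayerIntegratedGradients(mc_forward, model.bert.embeddings)
    # attribute the logit of choice 0 to the input tokens
    attrs = lig.attribute(input_ids,
                          additional_forward_args=(attention_mask,),
                          target=0)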

Alternatives

Two suggestions were made:

First approach:

  • code

    from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

    fa = FeatureAblation(model)
    llm_attr = LLMAttribution(fa, tokenizer)

    inp = TextTokenInput(promptTxt, tokenizer)

    attributions_fa = llm_attr.attribute(
                          inp,
                          target=targetIdxTensor,
                          additional_forward_args=dict(
                            token_type_ids=tst['token_type_ids'],
                            # position_ids=position_ids,
                            attention_mask=tst['attention_mask'],
                            )
                          )

  • throws error:

      File ".../captumPerturb_min.py", line 160, in captumPerturbOne
      attributions_fa = llm_attr.attribute(
      ^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute
      cur_attr = self.attr_method.attribute(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      TypeError: captum.attr._core.feature_ablation.FeatureAblation.attribute() got multiple values for keyword argument 'additional_forward_args'
    
  • dropping the additional_forward_args parameter (LLMAttribution.attribute appears to build its own additional_forward_args internally and forward the caller's kwargs on top of it, hence the keyword collision above) gets further, but throws (see the note after this list):

      File ".../captumPerturb_min.py", line 160, in captumPerturbOne
      attributions_fa = llm_attr.attribute(
      ^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute
      cur_attr = self.attr_method.attribute(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper
      return func(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
      initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward
      output = forward_func(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 574, in _forward_func
      model_inputs = prep_inputs_for_generation(  # type: ignore
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/transformers/generation/utils.py", line 376, in prepare_inputs_for_generation
      raise NotImplementedError(
      NotImplementedError: A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`.
    
  • looking at variables inside _forward_func:

    self.model.prepare_inputs_for_generation

      <bound method GenerationMixin.prepare_inputs_for_generation of BertForMultipleChoice(
        (bert): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(30522, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0-11): 12 x BertLayer(
                (attention): BertAttention(
                  (self): BertSdpaSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): BertIntermediate(
                  (dense): Linear(in_features=768, out_features=3072, bias=True)
                  (intermediate_act_fn): GELUActivation()
                )
                (output): BertOutput(
                  (dense): Linear(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
          (pooler): BertPooler(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (activation): Tanh()
          )
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (classifier): Linear(in_features=768, out_features=1, bias=True)
      )>
  • model_inp: tensor, torch.Size([1, 112])

  • model_kwargs.keys()

      dict_keys(['attention_mask', 'cache_position', 'use_cache'])
    
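Note: the NotImplementedError above appears to be the fundamental blocker for this approach. LLMAttribution's _forward_func calls self.model.prepare_inputs_for_generation, i.e. it assumes a decoder-style model that supports .generate(); BertForMultipleChoice only inherits the GenerationMixin stub, as the bound-method repr above shows. A small sanity check, assuming a transformers version that provides PreTrainedModel.can_generate:

    from transformers import BertForMultipleChoice

    model = BertForMultipleChoice.from_pretrained("bert-base-uncased")
    # BertForMultipleChoice never overrides prepare_inputs_for_generation,
    # so .generate() (and therefore LLMAttribution's forward pass)
    # cannot run against it.
    print(model.can_generate())  # expected: False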

Second approach:

  • code

    import torch

    def multChoice_forward(inputs, token_type_ids=None, position_ids=None,
                           attention_mask=None, target=None):
        output = model(inputs, token_type_ids=token_type_ids,
                       position_ids=position_ids, attention_mask=attention_mask)
        log_probs = torch.log_softmax(output.logits, 1)
        # specify which choice's prob
        return log_probs[target]

    fa = FeatureAblation(multChoice_forward)

    attributions_fa = fa.attribute(
                          tst['input_ids'],
                          additional_forward_args=dict(
                            token_type_ids=tst['token_type_ids'],
                            attention_mask=tst['attention_mask'],
                            target=targetIdxTensor
                          )
                        )

  • throws

      File ".../captumPerturb_min.py", line 294, in main
      captumPerturbOne(model,tokenizer,tstDict,tstTarget)
      File ".../captumPerturb_min.py", line 184, in captumPerturbOne
      attributions_fa = fa.attribute(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper
      return func(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
      initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward
      output = forward_func(
      ^^^^^^^^^^^^^
      File ".../captumPerturb_min.py", line 175, in multChoice_forward
      output = model(inputs, token_type_ids=token_type_ids,
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
      return forward_call(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/transformers/models/bert/modeling_bert.py", line 1799, in forward
      token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
      ^^^^^^^^^^^^^^^^^^^
      AttributeError: 'dict' object has no attribute 'view'
    

This is too far into Transformer API-land for me to follow.
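
For what it's worth, one plausible reading of this last traceback: Captum expects additional_forward_args to be a tuple of extra positional arguments to the forward function; anything else (including a dict) is wrapped as a single extra argument, so multChoice_forward receives the whole dict as token_type_ids, hence the AttributeError. An untested sketch of the call with a tuple matching the forward signature:

    attributions_fa = fa.attribute(
                          tst['input_ids'],
                          additional_forward_args=(
                              tst['token_type_ids'],   # token_type_ids
                              None,                    # position_ids
                              tst['attention_mask'],   # attention_mask
                              targetIdxTensor,         # target
                          )
                        )

Even then, log_probs has shape [batch, num_choices] after log_softmax(..., 1), so the return line may need to be log_probs[0, target] (or log_probs[:, target]) to index a choice rather than a batch row; that is a guess, since the shapes of tst and targetIdxTensor are not shown.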

Additional context

Additional details are in the original issue: https://github.com/pytorch/captum/issues/1523

rbelew · Mar 07 '25 23:03