[Trainer] Use of inspect for model.forward with torch.compile
Issue
In Trainer, the inspect module is used to remove extraneous dataset columns:
https://github.com/huggingface/transformers/blob/60d51ef5123d949fd8c59cd4d3254e711541d278/src/transformers/trainer.py#L722-L728
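For context, the linked code inspects the forward signature and keeps only matching dataset columns. A minimal sketch of that logic (the DummyModel class and column names here are illustrative, not the actual Trainer code):

```python
import inspect

# Illustrative stand-in for a model; the forward parameter names mirror
# common transformers arguments but the class itself is hypothetical.
class DummyModel:
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return input_ids

# Roughly what _set_signature_columns_if_needed does: collect forward's
# parameter names, then treat any dataset column not in that list as unused.
signature = inspect.signature(DummyModel().forward)
signature_columns = list(signature.parameters.keys())

dataset_columns = ["input_ids", "attention_mask", "text", "idx"]
ignored = [c for c in dataset_columns if c not in signature_columns]
print(signature_columns)  # ['input_ids', 'attention_mask', 'labels']
print(ignored)            # ['text', 'idx']
```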
However, torch.compile modifies the signature of the original model's forward function, so inspect.signature is unable to correctly identify the input arguments.
Possible Solution
If there is a way to recover the original arguments, that would be the cleanest solution. Otherwise, we could check whether the model is compiled and adjust the logic of the _set_signature_columns_if_needed function accordingly, perhaps logging a warning to the user that columns won't be dropped because torch.compile is in use.
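On the first option: in PyTorch 2.0, torch.compile on an nn.Module appears to return an OptimizedModule that keeps the original module on a `_orig_mod` attribute, so the signature may be recoverable by unwrapping before inspecting. A sketch under that assumption (FakeOptimizedModule stands in for the real wrapper so the logic runs without torch):

```python
import inspect

class Model:
    def forward(self, input_ids=None, attention_mask=None):
        return input_ids

# Mimics torch.compile's OptimizedModule: the original module is kept on
# `_orig_mod` while forward is replaced by an opaque *args/**kwargs wrapper.
class FakeOptimizedModule:
    def __init__(self, mod):
        self._orig_mod = mod

    def forward(self, *args, **kwargs):
        return self._orig_mod.forward(*args, **kwargs)

def unwrap_for_signature(model):
    # Fall back to the model itself when it was never compiled.
    return getattr(model, "_orig_mod", model)

opt_model = FakeOptimizedModule(Model())
params = list(inspect.signature(unwrap_for_signature(opt_model).forward).parameters)
print(params)  # ['input_ids', 'attention_mask']
```

Whether `_orig_mod` is a stable, public attribute is an open question, which is why the fallback of detecting compilation and skipping column removal (with a warning) may be the safer fix.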
System Information
- Python 3.8
- PyTorch 2.0
- transformers 4.27.1
Who can help?
@stas00 @sgugger
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
>>> import inspect, torch; from transformers import AutoModel
>>> model = AutoModel.from_pretrained("roberta-base")
>>> inspect.signature(model.forward)
<Signature (input_ids: Union[torch.Tensor, NoneType] = None, attention_mask: Union[torch.Tensor, NoneType] = None, token_type_ids: Union[torch.Tensor, NoneType] = None, position_ids: Union[torch.Tensor, NoneType] = None, head_mask: Union[torch.Tensor, NoneType] = None, inputs_embeds: Union[torch.Tensor, NoneType] = None, encoder_hidden_states: Union[torch.Tensor, NoneType] = None, encoder_attention_mask: Union[torch.Tensor, NoneType] = None, past_key_values: Union[List[torch.FloatTensor], NoneType] = None, use_cache: Union[bool, NoneType] = None, output_attentions: Union[bool, NoneType] = None, output_hidden_states: Union[bool, NoneType] = None, return_dict: Union[bool, NoneType] = None) -> Union[Tuple[torch.Tensor], transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions]>
>>> opt_model = torch.compile(model)
>>> inspect.signature(opt_model.forward)
<Signature (*args, **kwargs)>
Expected behavior
The trainer should only drop unused columns, not all of them (which is what happens when it incorrectly registers args and kwargs as input arguments).
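To spell out why every column gets dropped: a `(*args, **kwargs)` signature yields only the parameter names "args" and "kwargs", which no real dataset column matches. A self-contained illustration (the wrapper function here just imitates the compiled forward's signature):

```python
import inspect

# A wrapper like the one torch.compile produces: the original parameter
# names are hidden behind *args/**kwargs.
def wrapped_forward(*args, **kwargs):
    pass

sig_cols = list(inspect.signature(wrapped_forward).parameters)
print(sig_cols)  # ['args', 'kwargs']

# No dataset column is named 'args' or 'kwargs', so the remove-unused-
# columns step keeps nothing.
dataset_columns = ["input_ids", "attention_mask", "labels"]
kept = [c for c in dataset_columns if c in sig_cols]
print(kept)  # []
```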