Make Prefix-tuning work for other Seq2Seq models like Whisper ASR model
Hi, for seq2seq models, I found that the prefix-tuning method in peft currently only supports seq2seq language models. This means it requires the input_ids field as the input to the encoder. However, if I want to fine-tune Whisper, a speech-to-text seq2seq model that takes the input_features field as the encoder input, using prefix-tuning, it raises an error because no input_ids is provided. In principle, prefix-tuning prepends several learnable virtual tokens to each transformer layer. If I understand correctly, this is implemented by providing the past_key_values to the base model, as shown in the following lines:
https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py#L896-L900
if peft_config.peft_type == PeftType.PREFIX_TUNING:
    past_key_values = self.get_prompt(batch_size)
    return self.base_model(
        input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
    )
If that is the case, the mechanism does not seem to depend on the form of the encoder input, no matter whether it is input_ids or input_features.
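To make the point concrete, here is a small self-contained sketch (not the actual peft code; all sizes are made-up example values) of how a learnable prefix ends up as ordinary per-layer key/value tensors, with nothing tied to input_ids or input_features:

import torch

# Made-up example sizes for illustration only.
batch_size = 2
num_layers = 4           # decoder layers of the base model
num_heads = 8
head_dim = 64
num_virtual_tokens = 16  # length of the learnable prefix

# One learnable key/value tensor pair per layer; in peft this would come
# out of the prompt encoder via get_prompt(batch_size).
prefix = torch.nn.Parameter(
    torch.randn(num_layers, 2, num_heads, num_virtual_tokens, head_dim)
)

# Expand to the batch and arrange it as the usual past_key_values structure:
# a tuple with one (key, value) pair per layer, each of shape
# (batch_size, num_heads, num_virtual_tokens, head_dim).
past_key_values = tuple(
    (
        prefix[layer, 0].unsqueeze(0).expand(batch_size, -1, -1, -1),
        prefix[layer, 1].unsqueeze(0).expand(batch_size, -1, -1, -1),
    )
    for layer in range(num_layers)
)

print(len(past_key_values), past_key_values[0][0].shape)
# 4 torch.Size([2, 8, 16, 64]) -- nothing here references the encoder input.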
I would like to know whether I can use prefix-tuning on top of the Whisper model as expected by modifying the above code as follows:
if peft_config.peft_type == PeftType.PREFIX_TUNING:
    past_key_values = self.get_prompt(batch_size)
    return self.base_model(
        input_features=input_features, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
    )
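For concreteness, this is the rough kind of subclass I have in mind. It is only a sketch under my assumptions: the class name PeftWhisperSketch is made up, the import path for PeftType and the active_peft_config attribute are my guesses about the internals, and only the prefix-tuning branch of the linked forward() is reproduced with input_features in place of input_ids.

from peft import PeftModelForSeq2SeqLM
from peft.utils import PeftType

class PeftWhisperSketch(PeftModelForSeq2SeqLM):
    # Sketch only: route input_features instead of input_ids to the base model
    # for the prefix-tuning branch; all other branches are left out here.
    def forward(self, input_features=None, decoder_input_ids=None, **kwargs):
        peft_config = self.active_peft_config  # self.peft_config on older peft versions
        batch_size = input_features.shape[0]
        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_features=input_features,
                decoder_input_ids=decoder_input_ids,
                past_key_values=past_key_values,
                **kwargs,
            )
        raise NotImplementedError("This sketch only covers the prefix-tuning branch.")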
Or am I still missing something else that needs to be modified accordingly? 😂 Thanks in advance.