
Make Prefix-tuning work for other Seq2Seq models like Whisper ASR model


Hi, for seq2seq models, I found that the prefix-tuning method in peft currently only supports seq2seq language models. This means it requires the input_ids field as the input to the encoder. However, if I want to fine-tune Whisper with prefix-tuning, a speech-to-text seq2seq model that takes input_features as the encoder input, it raises an error because no input_ids is provided. In principle, prefix-tuning prepends several learnable virtual tokens to each transformer layer. If I understand correctly, it is implemented by passing past_key_values to the base model, as shown in the following lines: https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py#L896-L900

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
            )
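As far as I can tell, for prefix-tuning get_prompt just returns one learned (key, value) pair per layer, independent of what the encoder consumes. A rough illustration of the shapes I have in mind (the exact layout inside PEFT may differ, this is only my mental model with made-up sizes):

    import torch

    batch_size, num_layers, num_heads, head_dim = 2, 4, 8, 64
    num_virtual_tokens = 16

    # One learned (key, value) pair per layer, each of shape
    # (batch_size, num_heads, num_virtual_tokens, head_dim).
    past_key_values = tuple(
        (
            torch.randn(batch_size, num_heads, num_virtual_tokens, head_dim),
            torch.randn(batch_size, num_heads, num_virtual_tokens, head_dim),
        )
        for _ in range(num_layers)
    )

    # Nothing above depends on whether the encoder input is input_ids
    # (text tokens) or input_features (log-mel spectrograms).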

If that is the case, it seems not to depend on the form of the encoder input, whether it is input_ids or input_features. I would like to know whether I can use prefix-tuning on top of the Whisper model as expected by modifying the above code as follows:

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_features=input_features, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
            )

Or am I still missing something that needs to be modified accordingly? 😂 Thanks in advance. For reference, a rough end-to-end sketch of what I have in mind is below.
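This is untested and the class name is made up. In particular, I am not sure whether the PeftModel constructor sets up the prompt encoder correctly for WhisperConfig (it may expect attributes like num_hidden_layers), nor whether get_prompt produces caches that Whisper's decoder accepts as-is without extra handling of attention masks:

    from transformers import WhisperForConditionalGeneration
    from peft import PeftModelForSeq2SeqLM, PrefixTuningConfig, TaskType


    class WhisperPrefixTuningModel(PeftModelForSeq2SeqLM):
        """Hypothetical wrapper that feeds input_features instead of input_ids."""

        def forward(self, input_features=None, decoder_input_ids=None, **kwargs):
            # Assumes the attached config is a PrefixTuningConfig, so the
            # peft_type dispatch in PeftModelForSeq2SeqLM.forward is skipped.
            batch_size = input_features.shape[0]
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_features=input_features,
                decoder_input_ids=decoder_input_ids,
                past_key_values=past_key_values,
                **kwargs,
            )


    base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=16)
    model = WhisperPrefixTuningModel(base, peft_config)

If the shapes work out, calling model(input_features=features, decoder_input_ids=ids) should then go through the same prefix-tuning path as the text seq2seq case.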

louislau1129 · Apr 23 '23