
How to generate text?


Could someone please share some example code for how to generate text using a model with a soft prompt?

I have fine-tuned a soft prompt model (as implemented in this repo). However, when I try to use the .generate(...) method from the Hugging Face transformers library, I get an error in the model's forward pass about mismatched tensor sizes.

luke-thorburn avatar Dec 09 '21 15:12 luke-thorburn

Hey, I tried out text generation and there is an issue with caching during generation because of how the learned embedding is coded. You could probably fix it properly by checking whether the tensor shape is 0 and, if so, assuming you are in a caching step.

The whole solution is a bit hacky, but I didn't really want to fork Hugging Face and change the code, so this did the job for me.

Anyway, you can work around it by turning caching off:

import torch

# assumes tokenizer, model (with the soft embedding set as its input embedding)
# and n_tokens (the number of learned prompt tokens) are set up as in the repo example
inputs = tokenizer("may the force", return_tensors="pt")

# need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens,
# even though it does not matter what you pad input_ids with; it's just to make HF happy
inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1, n_tokens), 1), inputs['attention_mask']], 1)

tokens_to_generate = 10

outputs = model.generate(**inputs, max_length=inputs['input_ids'].size(1) + tokens_to_generate, use_cache=False)
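
For reference, the cache-aware fix mentioned above could look roughly like the sketch below. It is untested and not code from this repo: it assumes a SoftEmbedding-style wrapper where self.wte is the frozen token embedding and self.learned_embedding is the soft prompt, and it treats a single-token input as a cached decoding step.

import torch
import torch.nn as nn

class SoftEmbeddingWithCache(nn.Module):
    def __init__(self, wte: nn.Embedding, n_tokens: int = 20):
        super().__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        # initialise the soft prompt from existing vocabulary embeddings (one common choice)
        self.learned_embedding = nn.Parameter(wte.weight[:n_tokens].clone().detach())

    def forward(self, tokens):
        # during cached generation HF only feeds the most recently generated token,
        # so there is nothing to prepend -- just embed it normally
        if tokens.size(1) == 1:
            return self.wte(tokens)
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)

With something like that, use_cache=True should work during generation, though the input_ids/attention_mask padding above is still needed so the sequence lengths line up on the first step.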

kipgparker avatar Dec 12 '21 18:12 kipgparker

Works well enough for me! Thanks for looking into it.

luke-thorburn avatar Dec 13 '21 20:12 luke-thorburn

Hi, any thoughts on how to use this with a BART model? BART automatically right-shifts the labels to create decoder_input_ids, which makes the soft embedding available only to the encoder and not to the decoder. How would I proceed to make the soft embeddings available to the decoder input as well? I modified the forward call like this so that soft tokens are added only when the placeholder token is present in the input, which lets the decoder_input_ids bypass the soft embedding:

n_tokens = torch.sum(tokens[0] == 50256).item()  # 50256 is the id I'm using to represent the prompt tokens
input_embedding = self.wte(tokens[:, n_tokens:])
learned_embedding = self.learned_embedding[:n_tokens, :].repeat(
    input_embedding.size(0), 1, 1
)
return torch.cat([learned_embedding, input_embedding], 1)

However, doing this, the decoder never gets to "see" the soft embeddings (since the labels are used as input to the decoder). Would you recommend padding the labels with the special tokens too? If so, wouldn't the decoder collapse during generation?
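
One possible direction, purely as an untested sketch rather than anything from this repo (soft_dec below is a hypothetical decoder-side soft prompt, and the encoder-side prompt is assumed to be handled by the wrapper as above): build decoder_inputs_embeds yourself so BART never derives decoder_input_ids from the labels, and pad the labels with -100 so the loss ignores the soft positions.

import torch
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

n_dec_tokens = 10
# hypothetical decoder-side soft prompt; in practice register it in an nn.Module so it gets trained
soft_dec = torch.nn.Parameter(torch.randn(n_dec_tokens, model.config.d_model) * 0.02)

src = tokenizer(["some source text"], return_tensors="pt")
labels = tokenizer(["some target text"], return_tensors="pt")["input_ids"]

# embed the right-shifted labels ourselves and prepend the decoder soft prompt
decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id)
token_embeds = model.get_input_embeddings()(decoder_input_ids)
decoder_inputs_embeds = torch.cat(
    [soft_dec.unsqueeze(0).expand(token_embeds.size(0), -1, -1), token_embeds], dim=1)

# pad the labels with -100 so the soft positions are ignored by the loss
ignore = torch.full((labels.size(0), n_dec_tokens), -100, dtype=labels.dtype)
outputs = model(input_ids=src["input_ids"], attention_mask=src["attention_mask"],
                decoder_inputs_embeds=decoder_inputs_embeds,
                labels=torch.cat([ignore, labels], dim=1))
print(outputs.loss)

Generation is the trickier part, since generate() builds decoder inputs token by token; that would probably need a custom prepare_inputs_for_generation or a manual decoding loop, which is beyond this sketch.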

koustuvsinha avatar Apr 03 '22 19:04 koustuvsinha

@koustuvsinha Great question. Did you ever figure this out?

JosephGatto avatar May 25 '22 18:05 JosephGatto

Hi everyone here. I'm trying to generate the current response according to the dialogue context using a soft prompt. How can I use this code to generate? Thanks for the code @kipgparker provided, but it seems I cannot train and save the learnable soft prompt weights. How can I save the weights for generating text and train the model successfully?
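
If the goal is just to persist the soft prompt between training and generation, one minimal sketch (assuming the prompt lives in the wrapper's learned_embedding parameter, as in this repo) is to save and load only that tensor, and to give the optimiser only that parameter so the base model stays frozen:

import torch

# training: optimise only the soft prompt, keeping the base model frozen
soft = model.get_input_embeddings()                 # the SoftEmbedding wrapper
optimizer = torch.optim.AdamW([soft.learned_embedding], lr=1e-3)

# after training: save just the learned prompt
torch.save(soft.learned_embedding.detach().cpu(), "soft_prompt.pt")

# before generation: load it back into a freshly wrapped model
soft = model.get_input_embeddings()
soft.learned_embedding.data = torch.load("soft_prompt.pt").to(soft.learned_embedding.device)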

huangfu170 avatar Apr 17 '23 06:04 huangfu170

I don't have time to provide detailed help, but the complete code I used is in this repository:

https://github.com/Hunt-Laboratory/language-model-optimization

It might point you in the right direction.

luke-thorburn avatar Apr 17 '23 07:04 luke-thorburn

Thank you for your kind help. Wish you all the best.


huangfu170 avatar Apr 17 '23 07:04 huangfu170

Hi, I have seen your code in https://github.com/Hunt-Laboratory/language-model-optimization and it helps me a lot, but I still don't understand why I should pad the input to len(input) + n_tokens when calling model.generate(...). In the training step, the input_ids and attention_mask are sent to the model without the n_tokens padding, but there is padding at the generation step. I want to understand the reason.

Any help would be greatly appreciated. Wish you all the best. yxhuangfu
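
For what it's worth, the padding appears to come down to shapes inside the soft-embedding forward, which slices off the first n_tokens ids before prepending the learned prompt. A toy sketch (made-up numbers, not code from either repo) of what happens with and without the placeholder padding:

import torch
import torch.nn as nn

n_tokens = 4
wte = nn.Embedding(100, 8)                      # toy vocabulary, 8-dim embeddings
learned = torch.randn(n_tokens, 8)              # stand-in for the learned soft prompt

def soft_forward(tokens):
    # same pattern as the wrapper's forward: drop the first n_tokens ids, prepend the prompt
    input_embedding = wte(tokens[:, n_tokens:])
    return torch.cat([learned.expand(tokens.size(0), -1, -1), input_embedding], dim=1)

prompt = torch.tensor([[11, 22, 33]])           # 3 real tokens, no placeholder padding
print(soft_forward(prompt).shape)               # torch.Size([1, 4, 8]) -- the 3 real tokens were sliced away

padded = torch.cat([torch.full((1, n_tokens), 0), prompt], dim=1)
print(soft_forward(padded).shape)               # torch.Size([1, 7, 8]) -- soft prompt plus all 3 real tokens

So generate() needs input_ids (and attention_mask) to already contain those n_tokens placeholder positions for the lengths to line up with what the model actually sees.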


huangfu170 avatar Apr 19 '23 10:04 huangfu170

I still get this error in the forward pass even after setting use_cache=False. I am using a T5 model for summarization (a text generation task). Has anyone tried this with T5 models?
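
One guess, purely as an untested sketch (nothing here is confirmed to fix the T5 error): wrap only the encoder's embedding so the decoder and its cache are left alone, then pad the encoder inputs as in the GPT-2 example above. SoftEmbedding stands for the wrapper class from this repo, assumed to take the original embedding plus the number of soft tokens.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from soft_embedding import SoftEmbedding        # the wrapper from this repo; adjust the import path as needed

n_tokens = 20
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# wrap only the encoder embedding; the decoder keeps the plain shared embedding
soft = SoftEmbedding(model.get_input_embeddings(), n_tokens=n_tokens)
model.encoder.set_input_embeddings(soft)

long_article = "Some long article text to summarize ..."
inputs = tokenizer("summarize: " + long_article, return_tensors="pt")

# pad with placeholder ids and mask entries for the soft prompt positions, as in the GPT-2 example
inputs["input_ids"] = torch.cat(
    [torch.full((1, n_tokens), tokenizer.pad_token_id), inputs["input_ids"]], dim=1)
inputs["attention_mask"] = torch.cat(
    [torch.ones(1, n_tokens, dtype=torch.long), inputs["attention_mask"]], dim=1)

summary_ids = model.generate(**inputs, max_new_tokens=60)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))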

brihat9135 avatar May 17 '23 20:05 brihat9135

@koustuvsinha have you tried applying this to a BART model? How did it work? Thanks!

EveningLin avatar Oct 20 '23 11:10 EveningLin