Output `past_key_values` from `TextGenerationPipeline`.
Feature request
Currently, TextGenerationPipeline does not allow users to extract the past_key_values object from its output. It would be nice to be able to do so, so that we could stream the intermediate text in chunks without having to recalculate the past_key_values every time we yield.
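Purely to illustrate the request, here is a hypothetical sketch of what this could look like (the return_past_key_values flag, the past_key_values output field, and the past_key_values kwarg below do not exist today; they are made up for this example):

from transformers import pipeline

pipe = pipeline(model="gpt2", task="text-generation")

# Hypothetical: ask the pipeline to also return the cache it built while generating.
out = pipe("My initial text", max_new_tokens=10, return_past_key_values=True)
text, cache = out[0]["generated_text"], out[0]["past_key_values"]

# Hypothetical: feed the cache back in so the prompt does not have to be re-processed.
out = pipe(text, max_new_tokens=10, past_key_values=cache)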
Motivation
Runtime seems to skyrocket when streaming results from the pipeline in chunks. I believe this is because we waste time recalculating past_key_values every time we call pipeline().
Your contribution
Would be happy to help review code!
That's interesting; the current pipeline indeed does not support chunking. However, I think adding this would not be too hard (cc @Narsil). It would go in the generate_kwargs; the only issue is that past_key_values is not part of the output.
That would be nice, but it requires turning generate pretty much upside down and inside out.
This is what we have done here: https://github.com/huggingface/text-generation-inference which was required to get max performance out of bloom.
However, this is a pretty large endeavor which would mean the pipeline would basically redo the entire job of generate.
Since generate is already quite complex, I'm hesitant to start such a thing.
Runtime seems to skyrocket when streaming results from the pipeline in chunks. I believe this is because we waste time recalculating past_key_values every time we call pipeline().
When you're generating, you shouldn't have to care about the leftmost part of the text; it will be ignored anyway. Usually, text generation models simply chunk off the leftmost part of the text.
Isn't that doable in your case? Do you mind sharing a script of what you're attempting to do? That would help us better understand what you're trying to achieve and what the possible options are.
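For instance, a rough sketch of that kind of left-chunking (the model choice and the max_context value here are just illustrative):

from transformers import pipeline

pipe = pipeline(model="gpt2", task="text-generation")

def truncate_left(text, max_context=512):
    # Drop the leftmost tokens so only the most recent `max_context` tokens
    # of the prompt are fed to the model.
    ids = pipe.tokenizer.encode(text)
    return pipe.tokenizer.decode(ids[-max_context:])

out = pipe(truncate_left("...some very long running transcript..."), max_new_tokens=10)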
@Narsil thanks for the response! Here is an example of what I'd like to be able to do:
def stream_inference(input_dict):
    # `pipeline` here is an already-constructed TextGenerationPipeline instance.
    text = input_dict["text_inputs"]
    chunk_size = input_dict.pop("chunk_size", 10)
    for _ in range(10):
        # Every call re-enters the pipeline from scratch, so the past_key_values
        # for the whole prompt are recomputed on each iteration.
        # return_full_text=False makes the pipeline return only the newly generated
        # chunk, so it can be yielded and appended without duplicating the prompt.
        generated_text = pipeline(text, max_new_tokens=chunk_size, use_cache=True, return_full_text=False)[0]["generated_text"]
        yield generated_text
        text += generated_text
What I've observed is that although we set use_cache=True, there is still the overhead of recalculating the past_key_values every time we call pipeline(), since the previous call has already exited. Ideally, if we could extract past_key_values from the output of pipeline, we could feed it back into successive calls to address this issue.
Thoughts?
Pipeline is stateless, so it cannot keep the past_key_values, and having you send it back again and again kind of defeats the purpose of a pipeline IMO (for starters, you can't batch anymore; in general, you're introducing some kind of state).
I can provide a script which kind of mimics what you want to do. It is pretty hacky, but the "clean" version is exactly as I said: it would need a major rewrite of some components.
https://github.com/huggingface/transformers/issues/17365#issuecomment-1152192715
Here is the adapted version without threading (which you should avoid if possible):
from transformers import pipeline
from transformers.generation.stopping_criteria import StoppingCriteria

pipe = pipeline(model="gpt2", task="text-generation", device=0)

class Stream(StoppingCriteria):
    def __init__(self):
        self.prev_string = ""

    def __call__(self, input_ids, scores) -> bool:
        # Decode the full sequence so far and print only the newly added part.
        string = pipe.tokenizer.decode(input_ids[0])
        # print(f"Total: {repr(string)}")
        print(f"New: {repr(string[len(self.prev_string):])}")
        self.prev_string = string
        # Never stop early; the criteria hook is only used to observe new tokens.
        return False

for out in pipe("My initial text", max_new_tokens=10, stopping_criteria=[Stream()]):
    print("Final result", out)
Does this work for you?
@OlivierDehaene Tagging just because we were talking about the stream process in text-generation-inference :)
@Narsil Hmm, this does not address the issue of having to recalculate past_key_values between successive calls of pipe() though, no?
Oh no, that cannot change. But the idea is that you can call it for a very long range (like max_new_tokens=100), which will reuse the past_key_values over and over without you having to deal with it. And you can still capture tokens as they are produced and send them to a viewer (here, stdout).
Doing anything with past_key_values at the pipeline level is IMO too advanced for what pipelines are supposed to be, as it would break batching (which you most likely don't care about since you seem to be generating things live, but it's still a constraint on the pipeline itself).
The main goal of pipelines is to be usable by non-ML software engineers, and past_key_values requires you to understand in quite a lot of detail how things work internally. That's why, IMO, it's out of scope for pipelines.
If you really want full control, for instance to get resumable inference, you have to go to a lower level than the pipeline IMO. The code is not going to be so bad if you don't have batching to deal with. A gist:
input_ids = tokenizer.encode("initial string", return_tensors="pt")
stopping_criteria = StoppingCriteriaList([...])  # e.g. EOS / max-length criteria
logits_processor = LogitsProcessorList([...])  # <--- For both of these, check out `generate` for what the options are and how to create them.
past_key_values = None
scores = None
next_input = input_ids
while not stopping_criteria(input_ids, scores):
    outputs = model(next_input, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values
    logits = outputs.logits[:, -1, :]  # logits for the last position only
    scores = logits_processor(input_ids, logits)
    next_input = scores.argmax(dim=-1, keepdim=True)  # <---- choose whatever sampling strategy makes the most sense
    input_ids = torch.cat([input_ids, next_input], dim=-1)
The code is not meant to be fully functional, but the end result should look something like this.
Since your problem space is likely to be simpler than the general transformers one, you can probably get rid of a sizeable chunk of the complexity we have to deal with (beam_search, specific models, legacy code, batching), which doesn't really matter as much for you.
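To make the gist concrete, here is a minimal self-contained sketch of such a loop, assuming GPT-2, greedy decoding, and a simple max-length stopping criterion (all of those choices are illustrative, not the only way to do it):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer.encode("My initial text", return_tensors="pt")
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=input_ids.shape[1] + 20)])
logits_processor = LogitsProcessorList([])  # add temperature / top-k processors here if needed

past_key_values = None
next_input = input_ids
with torch.no_grad():
    while not stopping_criteria(input_ids, None):
        outputs = model(next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values            # reused on the next step
        logits = outputs.logits[:, -1, :]                     # logits for the last position only
        scores = logits_processor(input_ids, logits)
        next_input = scores.argmax(dim=-1, keepdim=True)      # greedy decoding
        input_ids = torch.cat([input_ids, next_input], dim=-1)
        # Stream each new token as soon as it is produced.
        print(tokenizer.decode(next_input[0]), end="", flush=True)
print()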
@Narsil Nice, I see what you are saying. Just for my own understanding -- is StoppingCriteria called once per token produced?
Yes, its intended goal is to decide when to stop generating tokens (hence the return type: False means continue generating, True means stop; generation stops when ANY criterion wants to stop).
@Narsil Thanks so much!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.