
Output `past_key_values` from `TextGenerationPipeline`.

Open · gilljon opened this issue on Jan 24, 2023 · 10 comments

Feature request

Currently, TextGenerationPipeline does not allow users to extract the past_key_values object from its output. It would be nice to be able to do so, so that we could stream the intermediate text in chunks without having to recalculate past_key_values after every yield.

Motivation

Runtime seems to skyrocket when streaming the results from the pipeline in chunks. I believe this is because we waste time recalculating past_key_values every time we make a call to pipeline().

Your contribution

Would be happy to help review code!

gilljon · Jan 24 '23

That's interesting; the current pipeline indeed does not support chunking. However, I think adding this would not be very hard (cc @Narsil); it would go in the generate_kwargs. The only issue is that past_key_values is not currently returned in the output.

ArthurZucker · Jan 25 '23

That would be nice, but it requires pretty much turning generate upside down and inside out.

This is what we have done here: https://github.com/huggingface/text-generation-inference, which was required to get maximum performance out of BLOOM.

However, this is a pretty large endeavor: it would mean the pipeline basically redoing generate's entire job. Since generate is already quite complex, I'm hesitant to start such a thing.

Runtime seems to skyrocket when streaming the results from the pipeline in chunks. I believe this is because we waste time recalculating past_key_values every time we make a call to pipeline().

When you're generating, you shouldn't have to care about the leftmost part of the text: it will be ignored anyway, and text generation setups usually just drop the leftmost part of the text once it no longer fits.

Isn't that doable in your case? Do you mind sharing a script of what you're attempting to do? That would help me better understand what you're trying to achieve and what the possible options are.
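
To make the "drop the leftmost part" idea concrete, here is a rough sketch at the tokenizer level; gpt2 and its 1024-token limit are just example choices:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.truncation_side = "left"  # drop the oldest tokens first, keep the most recent context

long_prompt = "some very long conversation history " * 500
inputs = tokenizer(
    long_prompt,
    truncation=True,
    max_length=1024,  # gpt2's context window
    return_tensors="pt",
)
# Only the rightmost 1024 tokens remain, which is all the model can attend to anyway.
print(inputs["input_ids"].shape)  # torch.Size([1, 1024])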

Narsil · Jan 25 '23

@Narsil thanks for the response! Here is an example of what I'd like to be able to do:

# Here `pipeline` is an already-instantiated TextGenerationPipeline, e.g.
# pipeline = transformers.pipeline("text-generation", model="gpt2")
def stream_inference(input_dict):
    text = input_dict["text_inputs"]
    chunk_size = input_dict.pop("chunk_size", 10)
    for _ in range(10):
        # return_full_text=False so only the newly generated chunk comes back
        generated_text = pipeline(text, max_new_tokens=chunk_size, use_cache=True, return_full_text=False)[0]["generated_text"]
        yield generated_text
        text += generated_text

What I've observed is that although we set use_cache=True, there is still the overhead of recalculating past_key_values on every call to pipeline(), since the cache is discarded once the call returns. Ideally, if we could extract past_key_values from the output of pipeline, we could feed it back into successive calls to address this.

Thoughts?

gilljon · Jan 25 '23

Pipelines are stateless, so they cannot keep the past_key_values, and having you send it back and forth kind of defeats the purpose of a pipeline IMO (you can't batch anymore, for starters, and in general you're introducing some kind of state).

I can provide a script which kind of mimics what you want to do. It is pretty hacky, but the "clean" version is exactly what I said: it would need a major rewrite of some components.

https://github.com/huggingface/transformers/issues/17365#issuecomment-1152192715

Here is the adapted version without threading (threading is best avoided if possible):

from transformers import pipeline
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList


pipe = pipeline(model="gpt2", task="text-generation", device=0)


class Stream(StoppingCriteria):
    def __init__(self):
        self.prev_string = ""

    def __call__(self, input_ids, scores) -> bool:
        # Called on every newly generated token: decode everything so far and print only the new part.
        string = pipe.tokenizer.decode(input_ids[0])
        # print(f"Total: {repr(string)}")
        print(f"New: {repr(string[len(self.prev_string):])}")
        self.prev_string = string
        # Always return False: this criterion only observes tokens, it never stops generation.
        return False


for out in pipe("My initial text", max_new_tokens=10, stopping_criteria=StoppingCriteriaList([Stream()])):
    print("Final result", out)

Does this work for you?

Narsil · Jan 26 '23

@OlivierDehaene Tagging just because we were talking about the stream process in text-generation-inference :)

Narsil · Jan 26 '23

@Narsil Hmm, this does not address the issue of having to recalculate past_key_values between successive calls to pipe(), though, no?

gilljon · Jan 26 '23

Oh no, that cannot change. But the idea is that you can call it for a very long range (like max_new_tokens=100), which will reuse the past_key_values over and over without you having to deal with it. And you can still capture tokens as they are produced and send them to a viewer (here, stdout).

Doing anything with past_key_values at the pipeline level is IMO too advanced for what pipelines are supposed to be, as it would break batching (which you most likely don't care about since you seem to be generating things live, but it's still a constraint on the pipeline itself). The main goal of pipelines is to be usable by non-ML software engineers, and past_key_values requires understanding in quite a lot of detail how things work internally. That's why IMO it's out of scope for pipelines.

If you really want full control, for instance to get resumable inference, you have to go a level lower than the pipeline IMO. The code is not going to be so bad if you don't have batching to deal with. A gist:


input_ids = tokenizer("initial string", return_tensors="pt").input_ids
stopping_criteria = StoppingCriteriaList([EOSToken, etc...])
logits_processor = LogitsProcessorList([...])  # <--- For both of these, check out `generate` for what those options are and how to create them.
past_key_values = None
scores = None

while not stopping_criteria(input_ids, scores):
    # Once a cache exists, only the newest token needs to be fed to the model.
    new_tokens = input_ids if past_key_values is None else input_ids[:, -1:]
    outputs = model(new_tokens, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values  # <--- carried over, so nothing is recomputed
    scores = logits_processor(input_ids, outputs.logits[:, -1, :].softmax(dim=-1))
    next_token = scores.argmax(dim=-1, keepdim=True)  # <---- choose whatever sampling strategy makes most sense
    input_ids = torch.cat([input_ids, next_token], dim=-1)

The code is not meant to be functional as-is, but the end result should look something like this.

Since your problem space is likely simpler than the general transformers one, you can probably drop a sizeable chunk of the complexity we have to deal with (beam search, model-specific quirks, legacy code, batching), which doesn't really matter as much for you.
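
To tie this back to your streaming use case, here is a rough sketch of wrapping that loop in a generator that yields each newly decoded chunk while keeping past_key_values alive between steps. Take it as an illustration only: gpt2, greedy sampling, and the simple max-length stop are example choices, not the only way to do it.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList


def stream_generate(prompt, model, tokenizer, max_length=50):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
    past_key_values = None
    scores = None
    while not stopping_criteria(input_ids, scores):
        # Feed only the newest token once a cache exists, so nothing is recomputed.
        new_tokens = input_ids if past_key_values is None else input_ids[:, -1:]
        with torch.no_grad():  # inference only, no gradients needed
            outputs = model(new_tokens, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        scores = outputs.logits[:, -1, :]
        next_token = scores.argmax(dim=-1, keepdim=True)  # greedy sampling, for simplicity
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        yield tokenizer.decode(next_token[0])  # emit each token as soon as it is produced


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
for chunk in stream_generate("My initial text", model, tokenizer):
    print(chunk, end="", flush=True)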

Narsil · Jan 26 '23

@Narsil Nice, I see what you're saying. Just for my own understanding -- is the StoppingCriteria called once per generated token?

gilljon · Jan 26 '23

Yes, its intended goal is to decide when to stop generating tokens (hence the return type: False means continue generating, True means stop, and iteration stops when ANY criterion wants to stop).
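
For instance, a minimal criterion that stops generation once a given substring shows up in the decoded output; the class name, the substring, and the gpt2 pipeline below are only an illustration:

from transformers import pipeline
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList

pipe = pipeline(model="gpt2", task="text-generation")


class StopOnSubstring(StoppingCriteria):
    def __init__(self, substring):
        self.substring = substring

    def __call__(self, input_ids, scores) -> bool:
        # Called once per generated token; note that the decoded text includes the prompt.
        return self.substring in pipe.tokenizer.decode(input_ids[0])


out = pipe(
    "My initial text",
    max_new_tokens=50,
    stopping_criteria=StoppingCriteriaList([StopOnSubstring("\n")]),
)
print(out[0]["generated_text"])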

Narsil · Jan 26 '23

@Narsil Thanks so much!

gilljon · Jan 26 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Feb 23 '23