
StoppingCriteria for individual samples in batched input

Muennighoff opened this issue 2 years ago • 4 comments

Feature request

IIUC, if I'm running batched generation and one sample in the batch has hit the stopping criteria but others have not, there is no way to stop generation for only that sample. I.e. either I stop generating for all samples, or the model keeps generating for all samples until all of them hit my stopping criteria.

It would be nice if, to speed up generation, the model could instead keep generating only for the samples that have not yet hit the criteria. To keep tensor shapes consistent, it could e.g. append the padding token to the finished ones.

A workaround is probably to stop as soon as a single sample hits the criteria, then filter my batch down to the samples that have not yet hit it and relaunch with only those. Lmk if there's a better workaround :)
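
A minimal sketch of this workaround (the model name, prompts, and stop strings below are placeholders, and the stop-on-any criterion mirrors the EndOfFunctionCriteria shown further down the thread):

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)


class StopOnAny(StoppingCriteria):
    """Stops generation as soon as ANY sequence in the batch contains a stop string."""

    def __init__(self, start_length, stop_strings, tokenizer):
        self.start_length = start_length
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        decoded = self.tokenizer.batch_decode(input_ids[:, self.start_length:])
        return any(any(s in d for s in self.stop_strings) for d in decoded)


tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")  # placeholder model
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_strings = ["\n\n"]  # placeholder stop strings
prompts = ["def add(a, b):", "def hello():", "def square(x):"]
pending = {i: p for i, p in enumerate(prompts)}  # index -> current text
finished = {}

for _ in range(8):  # bound the number of relaunches
    if not pending:
        break
    idxs = list(pending)
    inputs = tok([pending[i] for i in idxs], return_tensors="pt", padding=True)
    criteria = StoppingCriteriaList(
        [StopOnAny(inputs["input_ids"].shape[1], stop_strings, tok)]
    )
    out = model.generate(
        **inputs,
        max_new_tokens=64,
        stopping_criteria=criteria,
        pad_token_id=tok.eos_token_id,
    )
    for i, text in zip(idxs, tok.batch_decode(out, skip_special_tokens=True)):
        # Compare only the newly generated part against the stop strings.
        if any(s in text[len(prompts[i]):] for s in stop_strings):
            finished[i] = text  # this sample is done, drop it from the batch
            del pending[i]
        else:
            pending[i] = text  # not done yet, continue from the partial text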

Motivation

Faster generation

Your contribution

/

Muennighoff avatar Mar 23 '23 14:03 Muennighoff

cc @gante

sgugger avatar Mar 23 '23 14:03 sgugger

Hey @Muennighoff 👋

If I'm reading right, the sole purpose of the proposal is faster generation. In that case, implementing what you suggested is probably possible, but actually low impact. This is because the bottleneck in .generate() is the memory bandwidth associated with pulling the model weights all the way down to the compute cores, which is independent of the batch size 😢

Consider the script below:

from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import torch
import time

tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
print("Loading the model...")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16).to("cuda")

batch_size = 1
inputs = tok(["This cat is"] * batch_size, return_tensors="pt").to("cuda")

all_times = []
for i in tqdm(range(20)):
    start = time.time()
    gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=128, pad_token_id=model.config.eos_token_id)
    end = time.time()
    if i > 1:  # skip the first two iterations as warmup
        all_times.append(end - start)

print(f"Average time (batch_size={batch_size}): {sum(all_times) / len(all_times):.2f} seconds")

batch_size = 16
inputs = tok(["This cat is"] * batch_size, return_tensors="pt").to("cuda")

all_times = []
for i in tqdm(range(20)):
    start = time.time()
    gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=128, pad_token_id=model.config.eos_token_id)
    end = time.time()
    if i > 1:  # skip the first two iterations as warmup
        all_times.append(end - start)

print(f"Average time (batch_size={batch_size}): {sum(all_times) / len(all_times):.2f} seconds")

Running on my nvidia 3090:

  • batch_size=1 -> 4.19s
  • batch_size=16 -> 4.59s

Considering the philosophy of transformers, the potential speedup doesn't seem worth the implementation effort. Nevertheless, thank you for suggesting it! 🤗

gante avatar Mar 23 '23 15:03 gante

Hey @gante, thanks for getting back! I'm not sure what you mean by "pulling the model weights all the way down to the compute cores"?

In your example, all samples stop at the same time (i.e. after 128 new tokens), I think. I'm referring to cases where some samples may stop after e.g. 1 new token but others after e.g. 2000. In my case, generating the additional tokens for samples that "have already stopped" increases my inference time from 1 hour to 10 hours, i.e. 9 hours are wasted on tokens that are not needed. Due to this, I'm better off using batch_size=1.

For example, consider the StoppingCriteria below, which checks whether the eof_strings have been generated. I can implement it either as stopping when all samples in the batch of input_ids contain any of the eof_strings, or as stopping when any sample contains them. In the former case, samples that have already hit a stop word in eof_strings will continue to be fed through the model and new tokens will be generated for them, since other samples have not yet hit a stop word. This causes unnecessary inference time. Instead, one could save time (9 hours, i.e. 90% in my case) by only continuing to generate for the samples that have not yet hit the StoppingCriteria. Let me know if I'm being unclear!

from transformers import StoppingCriteria


class EndOfFunctionCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""

    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns True if all generated sequences contain any of the end-of-function strings."""
        decoded_generations = self.tokenizer.batch_decode(
            input_ids[:, self.start_length :]
        )
        done = [
            any(stop_string in decoded_generation for stop_string in self.eof_strings)
            for decoded_generation in decoded_generations
        ]
        return all(done)  # Stop when ALL sequences hit the stopping criteria
        # return any(done)  # Stop when ANY sequence hits the stopping criteria

Muennighoff avatar Mar 23 '23 16:03 Muennighoff

@Muennighoff Gotcha -- I now understand why you suggested this feature.

Before diving into solutions, let me understand the problem better. Normally, the generation time doesn't change much with the batch size (as I wrote above), meaning that generating the additional tokens is harmless. However, you are seeing a 10x difference 👀 This means I have a gap in my knowledge that I'd like to fill.

What is your hardware, and how are you using .generate()?

gante avatar Mar 23 '23 17:03 gante

@gante @Muennighoff +1 for this

ChatGPT use case: I would like to generate until <|im_end|>, but it is not in the vocabulary as a single token, so I need to generate until the sequence ends with that substring.

Prompt (from https://github.com/openai/openai-python/blob/main/chatml.md):

<|im_start|>system
You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible.
Knowledge cutoff: 2021-09-01
Current date: 2023-03-01<|im_end|>
<|im_start|>user
How are you<|im_end|>
<|im_start|>assistant
I am doing well!<|im_end|>
<|im_start|>user
How are you now?<|im_end|>
<|im_start|>assistant

I assume all the magic is right here: https://github.com/huggingface/transformers/blob/15641892985b1d77acc74c9065c332cd7c3f7d7f/src/transformers/generation/utils.py#L2045. I believe a quick fix is to run every criterion on each sample in the batch, so current users of stopping criteria would not be affected by this update.
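
As a sketch of what evaluating the condition per sample could look like on the user side (not the library API; the stop-string check and the exposed done_mask are assumptions), one could track a per-row mask while keeping the batch-level return value:

import torch
from transformers import StoppingCriteria


class PerRowStopStringCriteria(StoppingCriteria):
    """Evaluates the stop condition for every row and keeps a per-row mask."""

    def __init__(self, start_length, stop_strings, tokenizer):
        self.start_length = start_length
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        self.done_mask = None  # per-row flags, inspectable after generate()

    def __call__(self, input_ids, scores, **kwargs):
        decoded = self.tokenizer.batch_decode(input_ids[:, self.start_length:])
        self.done_mask = torch.tensor(
            [any(s in d for s in self.stop_strings) for d in decoded],
            dtype=torch.bool,
        )
        # The batch-level return value keeps today's all-or-nothing semantics.
        return bool(self.done_mask.all())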

Let me know if I can help with this 🤗

AlekseyKorshuk avatar Apr 06 '23 07:04 AlekseyKorshuk

@AlekseyKorshuk Currently, you can craft custom stopping criteria and pass them to the .generate() call. See this file for examples. After a given input row hits the criteria, .generate() will only append pad tokens to that row, which you can easily filter out.
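
For example, a minimal sketch with gpt2 as a placeholder model; rows that finish early are padded with pad_token_id by .generate():

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")  # placeholder model
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok(["The cat is", "My favourite colour is"], return_tensors="pt", padding=True)
out = model.generate(**inputs, max_new_tokens=32, pad_token_id=tok.pad_token_id)

# Option 1: let the tokenizer drop the pad/eos tokens while decoding.
texts = tok.batch_decode(out, skip_special_tokens=True)

# Option 2: keep the token ids but drop the pad tokens row by row
# (note: with pad == eos, as here, this also drops genuine eos tokens).
trimmed = [row[row != tok.pad_token_id] for row in out]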

What is being requested (not running inference at all on the rows where the stopping criteria match) is relatively expensive to build while maintaining backward compatibility. Please note that, even if it is built, the output will still contain the pad tokens (as described above). I haven't seen any proof that the speedups are worth the engineering effort of our small team 🤗

If anyone can show me a clear case where the generation time grows quickly with the batch size, I'll gladly bump its priority. I am unaware of a situation where this applies (except for beam search on pytorch, but that's due to an issue in the beam search implementation).

gante avatar Apr 06 '23 09:04 gante

@gante Thank you, I checked the examples, but it looks like it returns True/False for a complete batch. A quick test showed the following:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

# `tokenizer` and `model` are assumed to be loaded already, with a pad token set.

class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = [stop for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        print(input_ids)
        for stop in self.stops:
            # Note: only the first row of the batch is checked here.
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


stop_words = ["<human>:", "<bot>:"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

inputs = tokenizer(["<human>: How are you?\n<bot>:", "<human>: Why?\n<bot>:"], return_tensors='pt', padding=True)
model.generate(**inputs, stopping_criteria=stopping_criteria, max_new_tokens=32)

And the print returns the following:

tensor([[   27, 10734, 31175,  1374,   389,   345,    30,   198,    27, 13645,
         31175,   314],
        [   27, 10734, 31175,  4162,    30,   198,    27, 13645, 31175, 50256,
         50256,   464]])

So my question is: how can I make sure that in the end all samples from the batch will have a substring from stop_words (excluding special tokens)?

AlekseyKorshuk avatar Apr 07 '23 21:04 AlekseyKorshuk

@AlekseyKorshuk

but it looks like it returns True/False for a complete batch

Correct, the stopping conditions operate at the whole-batch level. Changing them to the row level is not in our short-term plans (and is, in essence, what the original issue here is about :) )

So my question is: how can I make sure that in the end all samples from the batch will have a substring from stop_words (excluding special tokens)?

I'm not sure if I got your question -- would you like to ensure that all rows in the batch generate stop_words at least once?

gante avatar Apr 13 '23 18:04 gante

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar May 08 '23 15:05 github-actions[bot]

@gante

It seems the behaviour is similar when you run beam search with stopping criteria: you cannot reject some of the beams and accept others. Would there be a workaround to achieve this?

Praful932 avatar Aug 19 '23 11:08 Praful932

@Praful932 it depends on your exact use case, but you may be able to write a custom logits processor that behaves as a soft stopping criterion for beam methods, by setting all next-token scores to a large negative value (if you want to discard the beam) OR forcing an EOS token (if you want to accept the beam as finalized) when your condition triggers :)
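
A minimal sketch of that idea, assuming a hypothetical stop string as the trigger and covering the "accept the beam" branch (this is not a ready-made transformers class):

import torch
from transformers import LogitsProcessor


class ForceEosOnStopString(LogitsProcessor):
    """Soft stopping criterion for beam methods: forces EOS on rows whose text
    contains the stop string, so those beams get finalized while others continue."""

    def __init__(self, stop_string, tokenizer, eos_token_id):
        self.stop_string = stop_string
        self.tokenizer = tokenizer
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # During beam search, rows are (batch_size * num_beams) candidate sequences.
        decoded = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        for row, text in enumerate(decoded):
            if self.stop_string in text:
                # Accept this beam as finished: make EOS the only viable next token.
                scores[row, :] = -float("inf")
                scores[row, self.eos_token_id] = 0.0
        return scores

It can then be passed to .generate() via logits_processor=LogitsProcessorList([ForceEosOnStopString(...)]).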

gante avatar Aug 21 '23 18:08 gante

Thank you, this helps :)

Praful932 avatar Aug 21 '23 19:08 Praful932

@Muennighoff @gante

Thank you Muennighoff, your implementation of the stopping criteria is very useful. Thank you gante, forcing an EOS token is a good approach.

To illustrate the issue in beam search: assuming a beam size of 2, there are two beams A and B. If beam A has already met the stopping criteria but beam B is still generating, generation will continue for beam A even though its stopping criteria has been met.

An approach is to leverage the model's automatic behavior of adding <pad> after <eos>: after either <eos> or <pad>, the model will add <pad>.

After beam A has met the stopping criteria, you can directly append an <eos> token to the end of beam A. This way, the model will automatically add <pad> to the content following beam A, while it continues to generate beam B, until both A and B meet the stopping criteria. P.S. I am using gemma-2b-it.

RiverTre avatar Mar 05 '24 13:03 RiverTre

@RiverTre We are working on per-row stopping criteria at the moment (cc @zucchini-nlp) :)

Depending on our future bandwidth, we may work on dynamically reducing the batch size according to the stopping criteria, to save time on compute-constrained cases.

gante avatar Mar 05 '24 13:03 gante