
[BUG] Incorrect Model Outputs When Using Beam Search

Open zelcookie opened this issue 1 year ago • 19 comments

Describe the bug
When I use kernel injection, I get worse generation results than when using transformers without DeepSpeed. I don't know whether the results should be identical, but they are not just different, they are noticeably worse.

I saw these two issues marked as closed: https://github.com/microsoft/DeepSpeed/issues/2048 https://github.com/microsoft/DeepSpeed/issues/2230

But I am using a version that includes this fix (https://github.com/microsoft/DeepSpeed/pull/2489) and the problem is still there.

To Reproduce

Huggingface:

import random
import os
import numpy as np
import torch

def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seed(42)

params = {
 'num_beams': 2,
 'do_sample': False,
 'max_new_tokens': 65,
 'use_cache': True,
 'no_repeat_ngram_size': 5,
 'num_return_sequences': 1}

from transformers import AutoTokenizer, AutoModelForCausalLM

DEVICE = torch.device("cuda:0")
name = "EleutherAI/gpt-j-6B"
model = AutoModelForCausalLM.from_pretrained(name).to(DEVICE).eval().half()
tokenizer = AutoTokenizer.from_pretrained(name)

prompt = "Quantum computers are"

inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        **params
    )

print(prompt)
print()
print(tokenizer.decode(outputs[0])[len(prompt):].strip())

output:


Quantum computers are

the holy grail of computing. They promise to solve problems that are intractable on today’s supercomputers, and they could be the key to solving some of the world’s most pressing problems.

But they’re not quite there yet.

Quantum computers are still in

DeepSpeed:

import random
import os
import numpy as np
import torch

def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seed(42)

params = {
 'num_beams': 2,
 'do_sample': False,
 'max_new_tokens': 65,
 'use_cache': True,
 'no_repeat_ngram_size': 5,
 'num_return_sequences': 1}

from transformers import AutoTokenizer, AutoModelForCausalLM
import deepspeed

DEVICE = torch.device("cuda:0")
name = "EleutherAI/gpt-j-6B"
model = AutoModelForCausalLM.from_pretrained(name).to(DEVICE).eval().half()
tokenizer = AutoTokenizer.from_pretrained(name)

model = deepspeed.init_inference(
    model=model,     
    mp_size=1,       
    dtype=torch.float16, 
    replace_method="auto", 
    replace_with_kernel_inject=True,
)

prompt = "Quantum computers are"

inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        **params
    )

print(prompt)
print()
print(tokenizer.decode(outputs[0])[len(prompt):].strip())

output:



Quantum computers are

the holy grail of modern science. Physics World War II, but they’s.


Quantum computers are a holy grail of quantum computers are a holy

of modern.
quantum-computers are a holy gra
of.

and

are.

.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
spatial_inference ...... [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/mikhail/venv/lib/python3.7/site-packages/torch']
torch version .................... 1.13.0+cu116
torch cuda version ............... 11.6
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed install path ........... ['/home/mikhail/venv/lib/python3.7/site-packages/deepspeed']
deepspeed info ................... 0.7.6+a4ceabb6, a4ceabb6, master
deepspeed wheel compiled w. ...... torch 1.13, cuda 11.6

System info (please complete the following information):

  • OS: Debian GNU/Linux 10 (buster) (GNU/Linux 4.19.0-21-cloud-amd64 x86_64)
  • CUDA Version: 11.6
  • 1x A100 40GB
  • DeepSpeed 0.7.6+a4ceabb6
  • Hugging Face Transformers 4.18.0
  • Python 3.7

zelcookie avatar Nov 14 '22 20:11 zelcookie

Hi @zelcookie, thanks for reporting this. I am able to reproduce with your scripts and will work on determining a root cause of this.

cmikeh2 avatar Nov 14 '22 21:11 cmikeh2

Hi @zelcookie, I've identified that the underlying reason for the low-quality outputs is that our KV-cache implementation is currently incompatible with num_beams > 1. When using beam search, the KV-cache associated with each batch isn't static and may be reordered and rewritten as tokens are generated. Due to how the module injection policy works, it's not really possible to use the DeepSpeed KV-cache implementation with beam search.

I am going to work on generalizing the limited support we currently have for external KV-caches so that it is compatible with the Hugging Face implementation, which should allow using DeepSpeed Inference alongside beam search. I don't have a concrete ETA yet, but I will post progress updates here as the timeline becomes clearer.

cmikeh2 avatar Nov 17 '22 01:11 cmikeh2

Hi @cmikeh2, thanks for the information. Is it possible to use kernel injection without the KV-cache, for example using only the attention and feed-forward blocks from DeepSpeed together with the Hugging Face QKV block? Or is there any other way?

zelcookie avatar Nov 17 '22 10:11 zelcookie

Hi, I have a small update on this issue. I locally lifted the restriction on num_beams > 1 in order to verify the malfunction, using num_beams=3. Here's what happened: I saw the problem on a 10G graphics card, which has about 23 GB of memory (GPT-J takes about 12-14 GB in fp16), whereas on an NVIDIA A100 it did not occur and the quality of the text was about the same as with pure PyTorch. I then played around with the parameters a bit and noticed that reducing max_out_tokens from 1024 to 512 or lower makes generation with num_beams > 1 quite coherent. Below are some examples with the prompt "Obama was born in Honolulu, Hawaii. After graduating from":

'Obama was born in Honolulu, Hawaii. After graduating from Punahou School and Occidental College, Obama went to Harvard Law School. He graduated in 1991.\n\nIn the late 1980s, Obama worked as a community organizer for ACORN (Association of Community Organizations for Reform Now) in Chicago. The organization is now under investigation by the Justice Department of Justice for voter registration fraud.\n\nObama’s family law school classmate Michelle Robinson has been with Obama when “Acadre” in Chicago where they lived together for two',
'Obama was born in Honolulu, Hawaii. After graduating from Occidental College of Physicians”s going away present from George W. Bush to go to the United States Military Academy at West Point. He graduated with honors and became a doctor.\n\nIn 2007). Obama attended Harvard Law School where he was editor-in-chief of the Harvard Law Review. While there he met Michelle Robinson who is also a graduate of Punaholism. They were married in 2008.\n\nAfter graduation Obama went to work for the Chicago law firm Sidley Austin',
"Obama was born in Honolulu, Hawaii. After graduating from Occidental College of Physicians & he went to work for the Harvard Law School at New York University. He passed the California Bar Exam and became a member of the State Bar of California. Obama moved to Chicago's law school (University of Chicago) in 1991.\n\nAfter graduation Obama settled in Chicago where he worked as a civil rights lawyer. In 2002-He joined the Illinois State Senate representing the South Side of Chicago. In 1998-2000 Obama served as chairman of the Illinois",

In general, generation almost always produces text of this quality, but there were several examples where something seemed to go wrong:

'Obama was born in Honolulu, Hawaii. After graduating from Punahou School, he attended Occidental College in Los Angeles, California, where he majored in political science and served aspired to become a writer. He then went on to earn a law degree from Harvard Law School. Obama is the first African-American president of the United States\nUniversity of Chicago | Barack Obama's Profile | Obamawaslidekennethuobamaprooksidentitynigopresidentbarackobamaobama_presidentob'
'Obama was born in Honolulu, Hawaii. After graduating from Punahou School and Occidentalized the University of Occidental College in 1979-1980 presidential campaign for Uchicago Chicago Law School, he won election to Illinois State Senate\n\n1981-1988) Presidentialsolyrhetired–2008) UnitedStatesmanSenatorIllinoisStateUnited States SenatorPresident Obam Barack Obama(Vice (1961st term as 44thirty-elect199344thpresident 199445thpresident 199546thpresident 199647thpresident 199748thpresident 1998'

I don't know, maybe these are just problems with the generation or dataset parameters, since the vast majority of texts produced with num_beams=3 do not show the artifacts described in this issue (repetitions, breakdowns). However, when generating with max_out_tokens=1024, the text becomes generally incoherent for every example.

I hope this helps. Shouldn't the restriction on num_beams > 1 be lifted?
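
For reference, a minimal sketch of how I shrank the cache, assuming this DeepSpeed build exposes the knob as a max_out_tokens keyword on init_inference (other builds may name it max_tokens, so treat the exact argument name as an assumption):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

name = "EleutherAI/gpt-j-6B"
model = AutoModelForCausalLM.from_pretrained(name).half().eval().to("cuda:0")

# Shrink the preallocated KV-cache: max_out_tokens bounds prompt + generated tokens.
# The keyword name is an assumption for this version; older builds may use max_tokens.
model = deepspeed.init_inference(
    model=model,
    mp_size=1,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
    max_out_tokens=512,
)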

hivaze avatar Jan 20 '23 07:01 hivaze

Hi @hivaze, these outputs are very intriguing.

Fundamentally, it's possible for us to produce reasonable outputs with num_beams > 1, but in general it's somewhat lucky when that happens. Currently, DeepSpeed-Inference manages its own KV-cache so that we can do a zero-copy append during inference; the benefit is avoiding the overhead of making large data copies for the KV-cache. The issue we encounter with beam search is that the KV-cache for a given row is not necessarily in the same location on subsequent forward passes of the model. The following are the input_ids passed to the model for num_beams=3 with GPT-J-6B without DeepSpeed.

tensor([[ 2061,   318, 10766, 22785,    30],
        [ 2061,   318, 10766, 22785,    30],
        [ 2061,   318, 10766, 22785,    30]], device='cuda:0')
tensor([[ 2061,   318, 10766, 22785,    30, 29744],
        [ 2061,   318, 10766, 22785,    30,   632],
        [ 2061,   318, 10766, 22785,    30,  1867]], device='cuda:0')
tensor([[ 2061,   318, 10766, 22785,    30, 29744, 22785],
        [ 2061,   318, 10766, 22785,    30,  1867,   857],
        [ 2061,   318, 10766, 22785,    30,  1867,   318]], device='cuda:0')
tensor([[ 2061,   318, 10766, 22785,    30, 29744, 22785,   318],
        [ 2061,   318, 10766, 22785,    30,  1867,   857,   340],
        [ 2061,   318, 10766, 22785,    30,  1867,   318,   340]],
       device='cuda:0')
tensor([[ 2061,   318, 10766, 22785,    30, 29744, 22785,   318,   257],
        [ 2061,   318, 10766, 22785,    30,  1867,   857,   340,   466],
        [ 2061,   318, 10766, 22785,    30, 29744, 22785,   318,   281]],
       device='cuda:0')
tensor([[ 2061,   318, 10766, 22785,    30,  1867,   857,   340,   466,    30],
        [ 2061,   318, 10766, 22785,    30, 29744, 22785,   318,   257,  3859],
        [ 2061,   318, 10766, 22785,    30, 29744, 22785,   318,   257,  3992]],
       device='cuda:0')

For the first input, DeepSpeed would initially populate 3 identical KV-caches. The first generative inference pass would concatenate 29744, 632, and 1867 onto the respective caches. The issue emerges when the previous history of a row changes. For example, in the third forward pass we can see that the input for the second batch is [ 2061, 318, 10766, 22785, 30, 1867, 857]. However, the entry in the KV-cache for batch=2 in DeepSpeed is actually [ 2061, 318, 10766, 22785, 30, 632]. We are now concatenating the new token 857 onto the wrong KV-cache and attending against it. Over the course of the generation sequence, all of the KV-caches are effectively corrupted in this manner.

It is possible to have a sequence in which the generation does not corrupt a KV-cache, which may be what you are observing above. Furthermore, if we don't inject kernels (i.e. just use DeepSpeed for injecting model parallelism across multiple GPUs), then we may still have some level of support since that should still use the HF cache.
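
To make the mismatch more concrete, here is a simplified sketch (not the DeepSpeed code itself) of the reordering Hugging Face beam search performs on its own cache between forward passes: the per-layer key/value tensors are permuted along the batch dimension with the selected beam indices. DeepSpeed's internally managed cache never gets reordered this way, which is why it ends up attending against stale histories.

import torch

def reorder_cache(past_key_values, beam_idx):
    # Mimics what transformers' GenerationMixin does for GPT-style models:
    # each cached key/value tensor has shape
    # [batch * num_beams, num_heads, seq_len, head_dim] and is permuted
    # along dim 0 so every beam continues from the history it was selected from.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )

# Example: after a step, beams 0 and 2 both continue from old beam 0,
# while beam 1 continues from old beam 2.
beam_idx = torch.tensor([0, 0, 2])
# past_key_values = reorder_cache(past_key_values, beam_idx)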

cmikeh2 avatar Jan 20 '23 18:01 cmikeh2

That's a pity, of course. As I understand it, this can't be fixed quickly? By the way, can the same problem occur when using num_return_sequences > 1? As far as I know, GenerationMixin also combines the sequences into a batch and sends them to the model, so the cache is batched as well.

hivaze avatar Jan 21 '23 11:01 hivaze

You also said that it is possible to disable kernel injection. Is it then possible to change the replace_policy to exclude the attention layers, and would that fix the problem?

hivaze avatar Jan 21 '23 12:01 hivaze

Hi! It would be great if beam search worked with DeepSpeed. I'm guessing it's probably the most common decoding algorithm used in production.

Are other generation strategies supported too?

  1. contrastive search: https://github.com/huggingface/transformers/pull/19963
  2. eta / epsilon sampling: https://github.com/huggingface/transformers/pull/21121

tokestermw avatar Jan 27 '23 17:01 tokestermw

@tokestermw is beam search still broken with DeepSpeed? You asked about contrastive search; judging by https://github.com/huggingface/transformers/issues/21151, contrastive search seems to be broken as well.

mallorbc avatar Feb 09 '23 07:02 mallorbc

@mallorbc I've run some benchmarks using gpt2 with fp16 precision on my own data (of course, YMMV).

System info

  • cuda version 11.7
  • A10G instance 24G
  • DeepSpeed 0.7.7
  • Transformers 4.25.1
  • Python 3.7
  • Torch 1.13.1

In summary, with and without DeepSpeed:

Top-P sampling (top_p = 0.6, temperature = 0.6)

  • Score ~1% degradation
  • Latency ~2x speedup

Beam Search (beam = 3)

  • Score ~14% degradation (w/ some poor generations mixed in)
  • Latency ~2.5x speedup

Contrastive search (top_k = 4, penalty_alpha = 0.6)

  • Score ~62% degradation
  • Latency ~2.8x speedup (partly due to shorter generations)

Eta sampling (eta_cutoff = 0.0005)

  • Score ~0.05% degradation
  • Latency ~2.2x speedup

So top-p and eta sampling work great, while beam search and contrastive search degrade significantly.

tokestermw avatar Feb 09 '23 17:02 tokestermw

@tokestermw Thanks so much for sharing your insights! I assume to get these results you did something like a string compare for results generated with and without DeepSpeed?

mallorbc avatar Feb 09 '23 18:02 mallorbc

You also said that it is possible to disable kernel injection. Is it then possible to change the replace_policy to exclude the attention layers, and would that fix the problem?

@hivaze Hi! Have you tried disabling KV-cache kernel injection, and did it solve the problem when using num_beams > 1?

PanQiWei avatar Feb 17 '23 05:02 PanQiWei

@tokestermw Thanks so much for sharing your insights! I assume to get these results you did something like a string compare for results generated with and without DeepSpeed?

@mallorbc yes, but also QA'd the generated results
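
For anyone who wants to run a similar rough comparison, a minimal sketch (my own scoring, not necessarily what was used for the numbers above) that reports the exact-match rate and a character-level similarity between the two runs:

from difflib import SequenceMatcher

def compare_generations(baseline_texts, deepspeed_texts):
    # Exact-match rate plus mean character-level similarity between two runs.
    assert len(baseline_texts) == len(deepspeed_texts)
    pairs = list(zip(baseline_texts, deepspeed_texts))
    exact = sum(a == b for a, b in pairs)
    sims = [SequenceMatcher(None, a, b).ratio() for a, b in pairs]
    return {
        "exact_match_rate": exact / len(pairs),
        "mean_similarity": sum(sims) / len(sims),
    }

# Example: hf_outputs / ds_outputs are lists of decoded generations
# from the two scripts at the top of this issue.
# print(compare_generations(hf_outputs, ds_outputs))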

tokestermw avatar Feb 17 '23 06:02 tokestermw

@hivaze Hi! Have you tried disabling KV-cache kernel injection, and did it solve the problem when using num_beams > 1?

I tried to understand the structure of this feature, but unfortunately there is no documentation for it. In addition, there was a large code refactor in this area not long ago, which confused me further. All in all, I found it quite time-consuming.

hivaze avatar Feb 23 '23 21:02 hivaze

@PanQiWei

Have you tried disabling KV-cache kernel injection, and did it solve the problem when using num_beams > 1?

Won't this remove the speed-up benefit of using DeepSpeed? I guess you still have tensor slicing, but without the speed improvement, the benefit is greatly diminished.

mallorbc avatar Feb 28 '23 20:02 mallorbc

@hivaze Hi! Have you tried disabling KV-cache kernel injection, and did it solve the problem when using num_beams > 1?

I tried to understand the structure of this feature, but unfortunately there is no documentation for it. In addition, there was a large code refactor in this area not long ago, which confused me further. All in all, I found it quite time-consuming.

@hivaze did you have any success with disabling KV-cache kernel injection? If so, what config did you use with init_inference?

nickmitchko avatar Apr 03 '23 15:04 nickmitchko

I'm hitting NotImplementedError: DeepSpeed does not support 'num_beams' > 1, if this is important to you please add your request to: https://github.com/microsoft/DeepSpeed/issues/2506. I really need this.

YarrDOpanas avatar Apr 28 '23 10:04 YarrDOpanas

We also need num_beams > 1 in order to actually use DeepSpeed.

brevity2021 avatar May 03 '23 23:05 brevity2021

+1 for num_beams > 1 support.

eggie5 avatar Jun 07 '23 11:06 eggie5

Is there any update, or at least a way to disable KV-cache kernel injection? I don't need the speedup, just DeepSpeed's ability to split my simple Hugging Face GPT-2 model over several GPUs.
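
In case it helps, a minimal sketch of what I believe is the kernel-injection-free path (tensor parallelism only, so generation still goes through the regular Hugging Face cache); launching with something like deepspeed --num_gpus 2 script.py is assumed:

import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).half().eval()

# replace_with_kernel_inject=False: DeepSpeed only shards the model across GPUs
# and does not swap in its fused kernels or its own KV-cache.
model = deepspeed.init_inference(
    model=model,
    mp_size=int(os.getenv("WORLD_SIZE", "1")),
    dtype=torch.float16,
    replace_with_kernel_inject=False,
)

inputs = tokenizer("Quantum computers are", return_tensors="pt").to(model.module.device)
with torch.inference_mode():
    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))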

PelzKo avatar Jul 18 '23 00:07 PelzKo

Any update on num_beams > 1?

seanxcwang avatar Aug 11 '23 08:08 seanxcwang