
[BUG] DeepSpeed Inference with GPT-J using batches with padding gives wrong outputs

Open tomerip opened this issue 2 years ago • 25 comments

Describe the bug Using DeepSpeed Inference (via deepspeed.init_inference) gives wrong outputs when using batch size > 1 with padded inputs.

I'll first state the problem with more detail and then explain what I tried in order to narrow it down.

The problem: I'm trying to run inference with GPT-J (EleutherAI/gpt-j-6B) on a very large dataset and therefore want to achieve the highest throughput possible for my setup. I'm using a p3.16xlarge instance with 8 V100 GPUs, so in theory I can fit a batch size of more than 1, since DeepSpeed shards the tensors across the GPUs. Since the inputs are of different lengths, I have to use padding. This is how I pad (let's assume batch_size=4, so len(input_texts) = 4):

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenized_inputs = tokenizer(
    list(input_texts),
    return_tensors='pt',
    padding=True,
    max_length=tokenizer.model_max_length - args.max_new_tokens,
    truncation=True,
)

Now to the problem. Assume these are the sequence lengths of each input (in number of tokens, of course):

  - idx0: 1452
  - idx1: 1588
  - idx2: 1055
  - idx3: 650

The outputs I get from the model will be exactly what I expect for idx1 (since it's the longest and has no padding), very close to what I expect for idx0, but terrible for idx3. What I "expect" is what I get when I run the exact same code with DeepSpeed with batch_size=1, or when I run the same code without DeepSpeed on CPU with batch_size=4. In both of these cases (DeepSpeed bsz=1 and CPU bsz=4) the outputs are identical, and they also make sense (it's an extraction task, so I can tell whether the output makes sense or not).

I tried to figure out what exactly causes this problem, and based on the evidence I've gathered, I think that sequences with long padding on the left side somehow accumulate a huge attention weight that is not correctly masked by the attention mask. My evidence is:

  1. If I run with DeepSpeed bsz=4 and with torch.float16, the outputs I get are: !!!!!!!!!!!!!!!!!!! (no matter the prompt). But if I run it with torch.float32 I get "normal" outputs, though as I said, they differ from what I expect (defined above). So this makes me think some tensor overflows with fp16 but not with fp32. I should also mention that running with DeepSpeed with fp16 and bsz=1 works perfectly.
  2. The longest input in the batch (which has no padding at all) gives the expected result. Those that are close to it in length have only a slightly weird output (small amount of padding tokens). Those that are much shorter (many padding tokens) have highly unrelated output.
  3. The way the GPT-J attention mechanism works (at least the HuggingFace implementation) is that you add -10,000 to the attention weight where the attention mask is 0 (see the sketch after this list). This might not be enough if the many padding tokens accumulate a large attention weight, although when I run it on CPU with the HuggingFace implementation everything is ok, so it might not be the reason.
  4. I'm pretty sure that the culprit is this function: https://github.com/microsoft/DeepSpeed/blob/a10e4811fe78b707289132c9695bade4715fe59b/csrc/transformer/inference/csrc/softmax.cu#L203 But unfortunately I don't speak CUDA, so it's very hard for me to follow it and pinpoint exactly what the problem is. For all I know, the HuggingFace implementation of attention works (https://github.com/huggingface/transformers/blob/2c2a31ffbcfe03339b1721348781aac4fc05bc5e/src/transformers/models/gptj/modeling_gptj.py#L72).
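
To make point 3 concrete, here is a minimal PyTorch sketch (not DeepSpeed's kernel and not the exact HuggingFace code) of the masking scheme described above; the tensor sizes and the example attention_mask are made up for illustration:

import torch

torch.manual_seed(0)
batch, heads, seq = 2, 1, 6

# attention_mask as produced by the tokenizer with left padding:
# 1 = real token, 0 = padding (the second sequence has three padding tokens on the left)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [0, 0, 0, 1, 1, 1]])

# raw attention scores (query x key), fp16 to mimic the failing configuration
scores = torch.randn(batch, heads, seq, seq, dtype=torch.float16)

# expand the mask to (batch, 1, 1, seq) and turn zeros into a large negative
# bias (the -10,000 mentioned above) added to the scores before the softmax
mask_bias = (1.0 - attention_mask[:, None, None, :].to(scores.dtype)) * -10000.0
probs = torch.softmax((scores + mask_bias).float(), dim=-1)

# padded key positions should receive ~zero probability for every query;
# if a kernel skips or mis-applies this bias, the padded positions leak into the output
print(probs[1, 0, -1, :3])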

To Reproduce

import os

import deepspeed
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

# input_texts and args are defined elsewhere in the script (omitted here)

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 2048

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
device = torch.device(f'cuda:{local_rank}')

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
model.config.pad_token_id = model.config.eos_token_id

model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    dtype=torch.float32,
    replace_method='auto',
    replace_with_kernel_inject=True,
)
model.device = device

tokenized_inputs = tokenizer(
    list(input_texts), 
    return_tensors='pt', 
    padding=True,
    max_length=tokenizer.model_max_length - args.max_new_tokens, 
    truncation=True,
).to(device)

with torch.inference_mode():
    batch_output_tokens = model.generate(
        input_ids=tokenized_inputs['input_ids'],
        attention_mask=tokenized_inputs['attention_mask'],
        do_sample=False,
        max_new_tokens=args.max_new_tokens,
        min_length=tokenized_inputs.input_ids.shape[1]+args.max_new_tokens,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )

batch_output_text = tokenizer.batch_decode(batch_output_tokens, skip_special_tokens=True)
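
For comparison, the CPU baseline mentioned above ("the same code without DeepSpeed on CPU with batch_size=4") looks roughly like the sketch below; input_texts and max_new_tokens are placeholders standing in for the real dataset and args.max_new_tokens:

import torch
from transformers import AutoTokenizer, GPTJForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 2048

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")  # stays on CPU, no init_inference
model.config.pad_token_id = model.config.eos_token_id

input_texts = ["a short prompt", "a much longer prompt ..."]  # placeholder inputs
max_new_tokens = 32                                           # placeholder for args.max_new_tokens

tokenized_inputs = tokenizer(
    list(input_texts),
    return_tensors='pt',
    padding=True,
    max_length=tokenizer.model_max_length - max_new_tokens,
    truncation=True,
)

with torch.inference_mode():
    batch_output_tokens = model.generate(
        input_ids=tokenized_inputs['input_ids'],
        attention_mask=tokenized_inputs['attention_mask'],
        do_sample=False,
        max_new_tokens=max_new_tokens,
        min_length=tokenized_inputs.input_ids.shape[1] + max_new_tokens,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.batch_decode(batch_output_tokens, skip_special_tokens=True))

Per the report above, this CPU run and the DeepSpeed bsz=1 run produce identical, sensible outputs, which is the reference the batched DeepSpeed outputs are compared against.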

Expected behavior Running DeepSpeed with batch_size=1 or batch_size=4 (or larger) should give the same outputs. Running DeepSpeed with fp16 and batch size > 1 should work and not give: "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!".

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [YES] ...... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/lib/python3.6/site-packages/torch']
torch version .................... 1.10.2+cu111
torch cuda version ............... 11.1
nvcc version ..................... 11.1
deepspeed install path ........... ['/opt/conda/lib/python3.6/site-packages/deepspeed']
deepspeed info ................... 0.6.0+2151c78, 2151c78, master
deepspeed wheel compiled w. ...... torch 1.10, cuda 11.1

System info (please complete the following information): SageMaker instance p3.16xlarge with SageMaker container 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.8.1-gpu-py36-cu111-ubuntu18.04

Launcher context Launching with deepspeed --num_gpus 8 run_inference.py

Docker context See above.

tomerip avatar Feb 27 '22 13:02 tomerip

Hi @tomerip

Thanks for bringing this interesting issue. I will definitely look into this and fix it soon.

Reza

RezaYazdaniAminabadi avatar Mar 01 '22 04:03 RezaYazdaniAminabadi

Hi @RezaYazdaniAminabadi, I wanted to give an update with additional information: I tested the new v0.6.0 release, and now even with batch_size=1, if using fp16 the output is "!!!!!!!!!!!" (whereas in v0.5.10 it works properly).

tomerip avatar Mar 13 '22 11:03 tomerip

Hey @RezaYazdaniAminabadi, I was curious if you've had time to look into this issue. Also, please let me know if I can help somehow (although if the problem is indeed in the CUDA files, I'm afraid I can't).

tomerip avatar Mar 30 '22 14:03 tomerip

Hey @tomerip,

Sorry for the long delay here. We have a deadline at the end of the week, and I can get more time for this issue next week. Hopefully, this should not take too much time. Thanks, Reza

RezaYazdaniAminabadi avatar Mar 30 '22 16:03 RezaYazdaniAminabadi

@RezaYazdaniAminabadi Hi! I've met the same issue with a GPT model that takes padded input_ids. Is there any update on this issue?

codertimo avatar Apr 26 '22 06:04 codertimo

Just wanted to update, issue still exists in v0.6.5

tomerip avatar Jun 29 '22 11:06 tomerip

Hey @tomerip - were you able to find a workaround? I am experiencing the same problem with gpt-models.

trianxy avatar Aug 19 '22 16:08 trianxy

Hi guys,

Sorry for my delay here! @codertimo Yes, you are right that the padding is not handled correctly for this model in the softmax kernel. This has been fixed very recently for the BLOOM model, and I am going to work on fixing it for the rest of the models too. I am going to focus on this more and send a PR with a fix soon. Thanks, Reza

RezaYazdaniAminabadi avatar Aug 19 '22 16:08 RezaYazdaniAminabadi

Great, thanks @RezaYazdaniAminabadi

trianxy avatar Aug 19 '22 18:08 trianxy

Hi @RezaYazdaniAminabadi , just checking whether you had the chance to work on that PR so far?

trianxy avatar Sep 08 '22 17:09 trianxy

Happy to help with testing any potential fixes!

If it will still take some time, it would be great to have a link to BLOOM's fix so that we can create a fix ourselves.

trianxy avatar Sep 08 '22 17:09 trianxy

@arashb is taking a look at it right now. @arashb , can you please comment on this a bit? Thanks

RezaYazdaniAminabadi avatar Sep 08 '22 17:09 RezaYazdaniAminabadi

encountered similar issue for RobertaModel

shuyingsunshine21 avatar Sep 09 '22 03:09 shuyingsunshine21

Hi @trianxy

Can you please try this PR to see if the issue is resolved? Thanks, Reza

RezaYazdaniAminabadi avatar Sep 10 '22 18:09 RezaYazdaniAminabadi

Hi @codertimo @tomerip @trianxy ,

I think this issue should be fixed with the same PR I mentioned above. I am seeing the same output between HF and DeepSpeed using the batched input. Could you please check if this is solved on your side? (Note: I have tested this with FP16; I am seeing different outputs when using FP32, which I am looking to solve too.)

HUGGINGFACE: ['DeepSpeed is a new company that has been working on the development of a new type of submarine. The DeepSea Challenger, as it’s called, will be able to dive down to depths of up to 3,000 meters (10,000 feet).', 'Today ia a nice day for a walk. I have been thinking about the weather and how it is going to be in the next few days. It has been raining all week, but today it is sunny and warm.\n\nI am not sure if you know this, but', "Never mind the fact that the only thing that could possibly make this a better movie is if it was set in the future, and not the past.\n\nThe problem with this film is that it's just too damn long. It's like they tried"]
[2022-09-12 02:46:51,829] [INFO] [logging.py:68:log_dist] [Rank -1] DeepSpeed info: version=0.7.3+aafba00c, git-hash=aafba00c, git-branch=cholmes/fix-long-seq-len-inference
[2022-09-12 02:46:51,829] [INFO] [logging.py:68:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
Using /home/reyazda/.cache/torch_extensions/py38_cu113 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/reyazda/.cache/torch_extensions/py38_cu113/transformer_inference/build.ninja...
Building extension module transformer_inference...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.510991096496582 seconds
[2022-09-12 02:46:53,862] [INFO] [logging.py:68:log_dist] [Rank -1] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 4096, 'intermediate_size': 16384, 'heads': 16, 'num_hidden_layers': -1, 'fp16': True, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': 64, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': False, 'mlp_act_func_type': <ActivationFuncType.GELU: 1>, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': False}
DEEPSPEED: ['DeepSpeed is a new company that has been working on the development of a new type of submarine. The DeepSea Challenger, as it’s called, will be able to dive down to depths of up to 3,000 meters (10,000 feet).', 'Today ia a nice day for a walk. I have been thinking about the weather and how it is going to be in the next few days. It has been raining all week, but today it is sunny and warm.\n\nI am not sure if you know this, but', "Never mind the fact that the only thing that could possibly make this a better movie is if it was set in the future, and not the past.\n\nThe problem with this film is that it's just too damn long. It's like they tried"]

Thanks, Reza

RezaYazdaniAminabadi avatar Sep 11 '22 21:09 RezaYazdaniAminabadi

Thanks @RezaYazdaniAminabadi for fixing this!

Commit 4abd455521965930d0e921de8afc0073ea7df9d1 from the PR you mentioned fixes the problem when I tested it using a Huggingface gpt2 model. By the way: The commit aafba00c81eaf29c0c2b209a94bc31f4de942936 before that still had the bug.

I wasn't able to test the PR on longer input sequences, though. The model seems to produce wrong/non-deterministic outputs there due to https://github.com/microsoft/DeepSpeed/issues/2243 . You mentioned that you might have a fix for that issue, too. Once you merge the fix for the latter issue, I will go ahead and test it on the longer input sequences as well.

trianxy avatar Sep 12 '22 21:09 trianxy

Hi @RezaYazdaniAminabadi, Is the issue now closed due to v0.7.3 being released?

tomerip avatar Sep 28 '22 13:09 tomerip

IIUC the summary is that https://github.com/microsoft/DeepSpeed/commit/4abd455521965930d0e921de8afc0073ea7df9d1 fixes the bug, but https://github.com/microsoft/DeepSpeed/pull/2212 was merged after the release of 0.7.3, so it's not fixed in a released version yet.

pqn avatar Sep 29 '22 18:09 pqn

Hey @RezaYazdaniAminabadi, I just tested v0.7.4 and unfortunately I still can't properly run the input mentioned in this issue (although now probably due to a different bug introduced in v0.7.4: I'm getting CUDA OOM, same as this issue's latest comments).

To summarize the current state of inferring long prompts with bsz larger than 1 with HF GPT-J and DeepSpeed:

  - v0.5.10: fp16 and bsz=1 - works perfectly
  - v0.5.10: fp16 and bsz>1 - getting "!!!!!!!!!!!" as output
  - v0.6.0 - v0.7.3: fp16 and bsz=1 - getting "!!!!!!!!!!" as output
  - v0.7.4: fp16 and bsz=1 - getting CUDA OOM

Please consider reopening this issue until the problem is resolved.

Also, if there are any specific commits/PRs I can test or any other way I can help, let me know.

tomerip avatar Oct 25 '22 13:10 tomerip

Thanks @tomerip for looking back into this. I think this does appear to be the same underlying issue as https://github.com/microsoft/DeepSpeed/issues/2357. A fix for this will likely come from https://github.com/microsoft/DeepSpeed/pull/2433, but we are still seeing this issue on that PR currently. When we have an updated PR to test, I'll update here.

cmikeh2 avatar Oct 25 '22 14:10 cmikeh2

Hi @tomerip,

Can you please see if this PR fixes the issue? Thanks, Reza

RezaYazdaniAminabadi avatar Nov 02 '22 16:11 RezaYazdaniAminabadi

Hey @RezaYazdaniAminabadi, sorry for the delay. I just tested the PR and it still outputs junk similar to v0.6.0-v0.7.3 (actually getting "." instead of "!").

tomerip avatar Nov 13 '22 14:11 tomerip

Hi @RezaYazdaniAminabadi, is there any estimated timeline for a fix? Thanks!

tomerip avatar Dec 07 '22 12:12 tomerip

I am facing a similar issue with OPT models. Is there a fix for it?

prabin525 avatar Feb 09 '23 23:02 prabin525

I'm also having an issue with OPT models, specifically Galactica. I will try to make an example for that

allanj avatar May 05 '23 04:05 allanj