
Generating batch outputs?

Open: ron-vnai opened this issue 2 years ago • 10 comments

Hi, I wonder whether it is possible to generate responses for a batch of input queries.

The scenario I have in mind is a batch of different input queries (say, all with the same number of tokens), for which I want to generate responses for the whole batch efficiently.

Naturally, I can modify the generate function in "generate/base.py" to support batch inputs naively. However, I'm looking for an efficient implementation.

Thanks!

ron-vnai, Aug 01 '23 15:08
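For equal-length prompts, the naive batched version is a small change: keep idx as a [B, T] tensor and take the logits at the last position of every row. A minimal sketch, assuming a hypothetical ToyCausalLM stand-in (not lit-gpt's model class) and greedy decoding; it recomputes the whole sequence on every step (no KV cache), which is exactly the inefficiency the question is about.

```python
import torch

class ToyCausalLM(torch.nn.Module):
    """Hypothetical stand-in for a causal LM: token ids [B, T] -> logits [B, T, V]."""

    def __init__(self, vocab_size: int = 50, dim: int = 16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        x = self.embed(idx)                                         # [B, T, D]
        steps = torch.arange(1, idx.size(1) + 1).view(1, -1, 1)
        x = x.cumsum(dim=1) / steps                                 # causal: position t only sees tokens <= t
        return self.head(x)                                         # [B, T, V]

@torch.no_grad()
def generate_batch_greedy(model: torch.nn.Module, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """Greedy decoding for a batch of equal-length prompts idx [B, T]; no KV cache."""
    for _ in range(max_new_tokens):
        logits = model(idx)                                         # re-runs the whole sequence every step
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)       # [B, 1]: last position of every row
        idx = torch.cat([idx, next_tok], dim=1)
    return idx

torch.manual_seed(0)
model = ToyCausalLM().eval()
prompts = torch.randint(0, 50, (4, 6))                              # B=4 prompts, all with T=6 tokens
out = generate_batch_greedy(model, prompts, max_new_tokens=5)
print(out.shape)  # torch.Size([4, 11])
```

Each row of the batched result should match running the same prompt alone; the batching only saves kernel launches, not the quadratic recomputation.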

Sorry, this is not implemented at the moment, to keep the generation code simple to understand (it's inherited from nanoGPT).

carmocca, Aug 14 '23 12:08

I think it's supported (by model.py). Try left padding the input sequences in the batch, e.g. idx of shape [B, T], where T is the longest length across all sequences and the shorter ones get padded up to it.

You should be able to pad with tokenizer.bos_id, I think.

yacineMTB, Sep 02 '23 04:09
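Left padding as described above can be sketched like this, assuming prompts are already tokenized into lists of ids; left_pad is a hypothetical helper and pad_id=1 is only a placeholder for whatever tokenizer.bos_id is. It also returns a boolean mask of real tokens, because the pad value by itself does not stop the model from attending to those slots.

```python
import torch

def left_pad(seqs: list[list[int]], pad_id: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Left-pad variable-length token id lists into idx [B, T] plus a padding mask.

    T is the longest prompt; shorter prompts get pad_id prepended so that every
    row's last real token sits at index T - 1.
    """
    T = max(len(s) for s in seqs)
    idx = torch.full((len(seqs), T), pad_id, dtype=torch.long)
    mask = torch.zeros((len(seqs), T), dtype=torch.bool)  # True = real token
    for i, s in enumerate(seqs):
        idx[i, T - len(s):] = torch.tensor(s, dtype=torch.long)
        mask[i, T - len(s):] = True
    return idx, mask

idx, mask = left_pad([[5, 6, 7], [8, 9], [10]], pad_id=1)
print(idx.tolist())   # [[5, 6, 7], [1, 8, 9], [1, 1, 10]]
print(mask.tolist())  # [[True, True, True], [False, True, True], [False, False, True]]
```

Left padding keeps the next-token logits of every row at index -1, which is why it fits a generate loop that reads the last position.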

Try left padding the input sequences across each sequence in the batch.

Just remember to implement an attention mask so the model ignores the padding tokens; otherwise your outputs will be conditioned on the padding tokens, which may give you strange outputs. I don't think using the BOS token would remedy that, or at least the result would not be identical to running inference without padding.

MF-FOOM, Sep 02 '23 04:09
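A sketch of such a mask, assuming PyTorch 2.x's F.scaled_dot_product_attention (boolean attn_mask, True = may attend) and a [B, T] boolean padding mask (True = real token). build_attn_mask is a hypothetical helper; letting pad query rows attend to themselves is this sketch's own choice to avoid fully masked rows. The check at the end compares a left-padded row against the same tokens run without padding.

```python
import torch
import torch.nn.functional as F

def build_attn_mask(pad_mask: torch.Tensor) -> torch.Tensor:
    """Combine a causal mask with a per-row padding mask.

    pad_mask: [B, T] bool, True where the token is real.
    Returns [B, 1, T, T] bool, True where query t may attend to key s
    (s <= t and key s is real); the singleton dim broadcasts over heads.
    """
    T = pad_mask.size(1)
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=pad_mask.device))
    allowed = causal[None, None] & pad_mask[:, None, None, :]
    # Left-padded query rows would otherwise have no allowed key at all, and an
    # all-masked softmax row can turn into NaN. Letting pad queries attend to
    # themselves keeps them finite; real query rows already allow their own diagonal.
    eye = torch.eye(T, dtype=torch.bool, device=pad_mask.device)
    return allowed | eye[None, None]

torch.manual_seed(0)
B, H, T, D = 2, 2, 4, 8
pad_mask = torch.tensor([[True, True, True, True],
                         [False, False, True, True]])  # row 1 is left-padded by 2
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=build_attn_mask(pad_mask))

# Row 1's real tokens should match causal attention over just those 2 tokens.
ref = F.scaled_dot_product_attention(q[1:, :, 2:], k[1:, :, 2:], v[1:, :, 2:], is_causal=True)
print(torch.allclose(out[1:, :, 2:], ref, atol=1e-5))
```

This only fixes what the real tokens attend to; it does not change the positions they are assigned.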

Also, wait: wouldn't left padding screw up the positional encoding, even with an attention mask?

MF-FOOM, Sep 02 '23 04:09
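With left padding and positions 0..T-1 shared across the batch, a row padded by two tokens starts its real text at position 2 instead of 0. One common fix, sketched here under the assumption of a [B, T] boolean padding mask (True = real token), is to derive per-row positions from a cumulative sum over the mask; position_ids_from_mask is a hypothetical name.

```python
import torch

def position_ids_from_mask(pad_mask: torch.Tensor) -> torch.Tensor:
    """Per-row positions that start at 0 on each row's first real token.

    pad_mask: [B, T] bool, True where the token is real.
    Pad slots get a clamped placeholder position; its value doesn't matter
    as long as those slots are masked out of attention.
    """
    pos = pad_mask.long().cumsum(dim=-1) - 1  # first real token -> 0, next -> 1, ...
    return pos.clamp(min=0)                   # leading pads would be -1; clamp to a valid index

pad_mask = torch.tensor([[True, True, True, True],
                         [False, False, True, True]])  # row 1 is left-padded by 2
print(position_ids_from_mask(pad_mask).tolist())  # [[0, 1, 2, 3], [0, 0, 0, 1]]
```

For right-padded rows the same function gives 0, 1, ... on the real tokens, matching a plain arange.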

Ah, cheers!

Attention masking should be straightforward; two things need to be done:

  • Change build_mask_cache to build the mask from the padded input, with shape (B, T, block_size, block_size). This should carry through.
  • Use a large negative number instead of a boolean mask in scaled_dot_product_attention, to work around a bug in PyTorch's implementation (see https://github.com/pytorch/pytorch/issues/103749).

I think you're right that left padding needs to be accounted for in the positional encoding: after adding the mask, the outputs got thrown off, which makes sense. I think this is being handled here? And also here, in Hugging Face's Llama implementation.

yacineMTB, Sep 03 '23 02:09
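The second bullet amounts to converting the boolean mask into an additive float mask. A minimal sketch, assuming F.scaled_dot_product_attention adds a float attn_mask to the attention scores; to_additive_mask is a hypothetical helper, and torch.finfo(dtype).min is just one choice of "large negative number". The linked PyTorch issue has the details of the bug; this sketch only checks that the float mask gives the same result as the boolean one on an ordinary causal mask.

```python
import torch
import torch.nn.functional as F

def to_additive_mask(allowed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean 'may attend' mask into an additive float mask.

    Allowed positions become 0.0 and blocked positions become the most negative
    finite value of `dtype` (a large negative number rather than -inf).
    """
    mask = torch.zeros(allowed.shape, dtype=dtype, device=allowed.device)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
allowed = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # plain causal mask
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=to_additive_mask(allowed, q.dtype))
print(torch.allclose(out_bool, out_float, atol=1e-5))
```

The mask dtype should match the query dtype, since finfo(float16).min and finfo(float32).min differ.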

not shilling, but this does batching :)

https://github.com/yacineMTB/just-large-models/blob/master/llama.py

yacineMTB, Sep 07 '23 01:09

I'm not 100% familiar with the advantages of left vs. right padding, so if one of you has a good resource on this, I'd appreciate it if you could share it.

carmocca, Sep 28 '23 18:09

From what I understand, right padding does not require creating an attention mask (so you can keep using flash attention), but then one cannot simply take the last position with -1 here: https://github.com/Lightning-AI/lit-gpt/blob/main/generate/base.py#L62

carmocca, Sep 28 '23 19:09
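With right padding, logits[:, -1] lands on a pad slot for every row shorter than T, so the next-token logits have to be gathered at each row's own last real position. A sketch, assuming [B, T, V] logits and a lengths tensor holding each row's number of real tokens; last_real_logits is a hypothetical helper.

```python
import torch

def last_real_logits(logits: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Pick each row's logits at its last real (non-pad) position.

    logits:  [B, T, V] for right-padded inputs.
    lengths: [B] number of real tokens per row.
    """
    rows = torch.arange(logits.size(0), device=logits.device)
    return logits[rows, lengths - 1]  # [B, V]

logits = torch.arange(12, dtype=torch.float32).view(2, 3, 2)  # B=2, T=3, V=2
lengths = torch.tensor([3, 1])                                # row 1 has 2 pad slots
print(last_real_logits(logits, lengths).tolist())  # [[4.0, 5.0], [6.0, 7.0]]
```

Picking the logits is only half of it: presumably the sampled token for row i then has to be written at index lengths[i] rather than appended at T, which is part of what makes right padding awkward in a simple decoding loop.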

Another advantage of right padding is that it is (more) compatible with models that use absolute positional embeddings.

rasbt, Sep 28 '23 19:09

Hello there, I have implemented a very naive version of batched inference compatible with lit_gpt. The most critical problem is that it does not support kv_cache. I inspected the KV cache code in lit_gpt/base.py, but had no idea how to handle variable-length prompts neatly, since cos and sin are 1D tensors. Maybe expanding them into 2D could be a solution...

Paste this gist (generate_batch()) into generate/base.py: https://gist.github.com/jinulee-v/d7dcfe9ad1280acc69c279fcbfcdfd22

  • Can: generate from variable-length prompts right-padded with eos_id:
      from torch.nn.utils.rnn import pad_sequence
      encoded = pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.eos_id).to(device)
      y = generate_batch(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
  • Can't: use kv_cache, so it is very slow and needs further optimization (13B model, single A100 (80GB), batch_size=16: ~12 s/step (~1.3 tok/s)).

jinulee-v, Nov 29 '23 02:11
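The "expand into 2D" idea can be sketched as indexing the cos/sin tables with a [B, T] position tensor instead of a shared 1D one, so each row gets its own positions. This is not lit-gpt's implementation: it assumes a half-split rotary convention, rope_cache and apply_rope are hypothetical names, and the KV cache itself (which would also need per-row write positions) is left out.

```python
import torch

def rope_cache(seq_len: int, head_dim: int, base: float = 10000.0) -> tuple[torch.Tensor, torch.Tensor]:
    """cos/sin tables of shape [seq_len, head_dim // 2], one row per absolute position."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    return angles.cos(), angles.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Rotate x [B, H, T, D] using per-row positions pos [B, T] (2D, not a shared 1D index)."""
    c = cos[pos].unsqueeze(1)  # [B, 1, T, D/2], broadcast over heads
    s = sin[pos].unsqueeze(1)
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]  # assumed half-split rotary convention
    return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)

torch.manual_seed(0)
H, D = 2, 8
cos, sin = rope_cache(16, D)
x_real = torch.randn(1, H, 2, D)                              # an unpadded 2-token row
x_pad = torch.cat([torch.randn(1, H, 2, D), x_real], dim=2)   # same row, left-padded by 2
out_real = apply_rope(x_real, cos, sin, torch.tensor([[0, 1]]))
out_pad = apply_rope(x_pad, cos, sin, torch.tensor([[0, 0, 0, 1]]))  # pads get a dummy position 0
print(torch.allclose(out_pad[:, :, 2:], out_real))
```

During cached decoding, pos would then be a [B, 1] tensor holding each row's own next position rather than a single shared input_pos.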