Issue with "kv_cache" while using modified generate/lora.py for a list of inputs

Open mriganktiwari opened this issue 2 years ago • 5 comments

The dimensions of (k, v) get changed by the kv_cache somewhere between lines 195-228 of model.py.

This happens when line 65 is hit via this generate function call.

I've put this generate call inside a for loop to iterate over a list of inputs, for generation with LoRA weights.

I'm trying to figure out how to reset this kv_cache as soon as a new input is taken. Any help is appreciated.

More error logs:

Traceback (most recent call last):
  File "generate/lora_itn.py", line 130, in <module>
    CLI(main)
  File "/opt/conda/lib/python3.8/site-packages/jsonargparse/_cli.py", line 85, in CLI
    return _run_component(component, cfg_init)
  File "/opt/conda/lib/python3.8/site-packages/jsonargparse/_cli.py", line 147, in _run_component
    return component(**cfg)
  File "generate/lora_itn.py", line 96, in main
    output = generate(model,
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jovyan/mrigank-llm-datavol-1/lit-llama/generate.py", line 71, in generate
    logits = model(x, max_seq_length, input_pos)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/lightning/fabric/wrappers.py", line 116, in forward
    output = self._forward_module(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jovyan/mrigank-llm-datavol-1/lit-llama/lit_llama/model.py", line 114, in forward
    x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jovyan/mrigank-llm-datavol-1/lit-llama/lit_llama/model.py", line 163, in forward
    h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jovyan/mrigank-llm-datavol-1/lit-llama/lit_llama/model.py", line 230, in forward
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
RuntimeError: The size of tensor a (318) must match the size of tensor b (314) at non-singleton dimension 3

The file lora_itn.py is a modified lora.py, changed to take in a list of inputs instead.

mriganktiwari avatar Jul 07 '23 10:07 mriganktiwari

You can call reset_cache() after generation. Lit-GPT does it here: https://github.com/Lightning-AI/lit-gpt/blob/main/generate/base.py#L180
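
For reference, a minimal sketch of that pattern (encoded_prompts is a placeholder name here; generate, model, and tokenizer are the same objects used elsewhere in this thread):

  # iterate over pre-tokenized prompts, clearing the KV cache between them
  for encoded in encoded_prompts:
      output = generate(
          model,
          idx=encoded,
          max_new_tokens=max_new_tokens,
          temperature=temperature,
          top_k=top_k,
          eos_id=tokenizer.eos_id
      )
      model.reset_cache()  # drop the cached (k, v) tensors before the next prompt
      print(tokenizer.decode(output))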

carmocca avatar Jul 12 '23 14:07 carmocca

I ran into the same issue too. As @carmocca mentioned, I also solved it with reset_cache().

That said, I don't fully understand the purpose and usage of this cache mechanism. Could someone explain it briefly or point to an up-to-date source? Why do we need it, and why does it have to be reset when running inference on two consecutive, different inputs?

Thanks!

SagiPolaczek avatar Jul 13 '23 18:07 SagiPolaczek

You can read about the KV cache here: https://kipp.ly/transformer-inference-arithmetic/

It depends on the sequence length, so if it changes it needs to be reset.

When you run inference on a prompt for a maximum number of tokens, we know what the maximum sequence length will be, so the cache is sized to that value. If the next generation has a different length, the cache needs to be reset. You could also size it for the longest of your generations, but that would be wasteful for the generations that are shorter.
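
As a toy illustration of why the cache is tied to the sequence length (this is not lit-llama's exact implementation, just the general shape of the mechanism):

  import torch

  B, n_head, head_dim = 1, 4, 8
  max_seq_length = 16  # sized for prompt length + max_new_tokens of this generation

  # the caches are preallocated for a fixed max_seq_length ...
  k_cache = torch.zeros(B, n_head, max_seq_length, head_dim)
  v_cache = torch.zeros(B, n_head, max_seq_length, head_dim)

  def step(k_new, v_new, input_pos):
      # ... and each decoding step writes the new key/value at input_pos
      # instead of recomputing them for the whole prefix
      k_cache[:, :, input_pos] = k_new
      v_cache[:, :, input_pos] = v_new
      return k_cache, v_cache

  k, v = step(torch.randn(B, n_head, head_dim), torch.randn(B, n_head, head_dim), input_pos=0)

  # If the next prompt implies a different max_seq_length, these buffers no longer
  # match the new mask/rope shapes (hence the size-mismatch error above) and have
  # to be rebuilt, which is what reset_cache() triggers.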

carmocca avatar Jul 13 '23 18:07 carmocca

@carmocca thanks for the quick reply and your awesome work! ⚡

SagiPolaczek avatar Jul 13 '23 18:07 SagiPolaczek

When running generation for multiple consecutive inputs on a LoRA fine-tuned LLaMA, I noticed that calling reset_cache after each input's generation degrades the quality of the generation for the next input. If I reload the model after each generation instead, the quality stays good, but reloading takes a lot of time. Could you help explain why reset_cache decreases the quality of generation for the following inputs?

Code: I modified generate/lora.py to enable consecutive generation on multiple inputs. Basically, I just added a for loop and model.reset_cache():

  # support inference on multiple inputs
  outputs = []
  num_samples = len(input)
  for i in range(num_samples):
      sample = {"instruction": prompt[i], "input": input[i]}
      # keep the rendered prompt in its own variable so the `prompt` list
      # is not overwritten after the first iteration
      full_prompt = generate_prompt(sample)
      encoded = tokenizer.encode(full_prompt, bos=True, eos=False, device=model.device)

      t0 = time.perf_counter()
      output = generate(
          model,
          idx=encoded,
          max_new_tokens=max_new_tokens,
          temperature=temperature,
          top_k=top_k,
          eos_id=tokenizer.eos_id
      )
      t = time.perf_counter() - t0

      # clear the per-block KV caches so the next prompt starts from a fresh cache
      model.reset_cache()
      output = tokenizer.decode(output)
      output = output.split("### Response:")[1].strip()
      print(output)
      print(f"Time for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
      outputs.append(output)

HenryPengZou avatar Aug 07 '23 06:08 HenryPengZou