long_llama
0-shot long-context summarization / QA inference
Hi,
Thank you for this great effort. I am trying to evaluate your 3B instruct-v1_1 model on my custom long-context QA dataset, with context lengths up to 200k tokens.
I have a question: I couldn't find keywords like 256k in your Colab / .py examples; instead there are several mentions of 1024, 2048, etc., just as in normal LLaMA. So this model does support long context, right? In that case I should not be using the "drop-in" replacement example.
Thank you very much.
Thank you for your question!
Yes, this Colab contains a demo where the model loads our paper and is asked questions about it; the paper is far longer than 2K tokens. This other Colab contains the passkey retrieval demo for the base (non-instruction-tuned) model. The models in those Colabs are loaded in a way that enables the long-context functionality (the trust_remote_code option).
Here is a brief explanation of how LongLLaMA handles long inputs and why you do not see any fixed context limitation.
I would not expect striking performance from the 3B model. Also, the choice of the last_context_length parameter can have a significant impact on the result (in general, for the 3B model a good choice is 2048 minus the expected length of the output).
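To make that rule of thumb concrete, here is a small sketch (the 2048 is the local context window of the 3B model mentioned above; the helper itself is just an illustration, not part of the codebase):

```python
# Leave room in the last context window for the tokens you expect to generate.
LOCAL_WINDOW = 2048  # local context window of the 3B model

def suggested_last_context_length(expected_output_tokens: int) -> int:
    return max(LOCAL_WINDOW - expected_output_tokens, 1)

print(suggested_last_context_length(256))  # prints 1792
```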
Thank you. I'm taking a look and will ask follow-up questions.
As I understand it, the Colab QA demo (which uses TextStreamer) should also work in a standard HF pipeline, right? I'm trying to do inference on a .jsonl dataset. According to your README:
Loading model
```python
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b_v1_1")
model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b_v1_1",
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
```
Input handling and generation
LongLLaMA uses the Hugging Face interface; the long input given to the model will be split into context windows and loaded into the memory cache.
prompt = "My name is Julien and I like to"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model(input_ids=input_ids)
During the model call, one can provide the parameter last_context_length (default 1024), which specifies the number of tokens left in the last context window. Tuning this parameter can improve generation, as the first layers do not have access to memory. See details in [How LongLLaMA handles long inputs](#How-LongLLaMA-handles-long-inputs).
```python
generation_output = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    num_beams=1,
    last_context_length=1792,
    do_sample=True,
    temperature=1.0,
)
print(tokenizer.decode(generation_output[0]))
```
So I am doing something like this:

```python
def generate_answers(model, tokenizer, data, output_path):
    generations = []
    for item in data:
        input_text = item['input']
        question = f"\nAnswer the following question briefly using information from the text above.\nQuestion: {item['output']}\nAnswer: "
        # construct the prompt
        prompt = f"{input_text}{question}"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        generation_output = model.generate(
            input_ids=input_ids,
            max_new_tokens=256,
            num_beams=4,
            last_context_length=1792,
            temperature=0.7,
            top_p=1.0,
        )
        answer = tokenizer.decode(generation_output[0], skip_special_tokens=True)
        # store the generated answer
        generations.append({
            "id": item["id"],
            "input": input_text,
            "output": answer,
        })
    # save generations (see the save_generations sketch below)
    save_generations(output_path, generations)
```
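(save_generations above is just my own helper for writing the results out as JSONL; a minimal version, for completeness:)

```python
import json

def save_generations(output_path, generations):
    # one JSON record per line (JSONL), matching the .jsonl input format
    with open(output_path, "w", encoding="utf-8") as f:
        for record in generations:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```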
Does this look good, or have I misunderstood something?
Thank you!
By the way, can I ask what a good sampling strategy is, if you've tried any?
I'm doing:

```python
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    mem_attention_grouping=(1, 2048),
    use_cache=True,
)
output = model.generate(
    input_ids=input_ids,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    last_context_length=1536,
    num_beams=2,
    do_sample=True,
)
```
And currently the generation quality is very poor. Thank you very much for the help! By the way, in the README you provide an example with {num_beams=1, do_sample=True}; I thought that would be greedy decoding anyway?
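To check this, I'm going to run the same prompt under a few decoding settings and compare the outputs. A small sketch of what I have in mind (it reuses model, tokenizer, and input_ids from above; the labels in the comments are just my reading of the transformers docs, so please correct me if they're off):

```python
decoding_configs = {
    # plain greedy decoding
    "greedy": dict(do_sample=False, num_beams=1),
    # the README setting: a single beam, but sampling from the distribution
    "readme_sampling": dict(do_sample=True, num_beams=1, temperature=1.0),
    # my current setting: beam search combined with sampling
    "beam_sampling": dict(do_sample=True, num_beams=2, temperature=0.7, top_p=0.9),
}

for name, cfg in decoding_configs.items():
    out = model.generate(
        input_ids=input_ids,
        max_new_tokens=256,
        last_context_length=1792,
        **cfg,
    )
    # decode only the newly generated tokens
    print(name, tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True))
```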