How do models do batch inference when using the transformers pipeline?
I am a noob. Here is my code; how can I modify it to do batch inference?
import torch
import transformers

def load_model():
    model_id = 'llama3/Meta-Llama-3-70B-Instruct'
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    # return tokenizer, pipeline
    return pipeline

def get_response(pipeline, system_prompt, user_prompt):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Stop on EOS or Llama 3's end-of-turn token
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    outputs = pipeline(
        prompt,
        max_new_tokens=4096,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    return outputs
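For reference, a minimal usage sketch (the prompts are placeholder values; get_response returns the raw pipeline output, where "generated_text" holds the prompt followed by the completion):

pipeline = load_model()
outputs = get_response(pipeline, "You are a helpful assistant.", "Hello!")
# The text-generation pipeline returns a list of dicts, one per generated sequence.
print(outputs[0]["generated_text"])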
@ArthurZucker would passing a list of messages work?
It doesn't seem to work. Reasons:
- inference time is the same as running the prompts one at a time,
- the console warnings appear one by one, which suggests the prompts are processed sequentially rather than as a batch.
Here is the code for batch inference:
def load_model():
    model_id = '/home/pengwj/programs/llama3/Meta-Llama-3-70B-Instruct'
    # tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    return pipeline
# batch_system_prompt = [[], [], [], []]; sections = [[], [], [], []]
messages = [
    [{"role": "system", "content": system_prompt},
     {"role": "user", "content": user_prompt}]
    for system_prompt, user_prompt in zip(batch_system_prompt, sections)
]
# prompt is a list with one templated string per conversation
prompt = pipeline.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
outputs = pipeline(
    prompt,
    max_new_tokens=1024,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
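One more note: even with a list of prompts, the pipeline still generates one sequence at a time unless you also pass batch_size, and batching needs a pad token, which the Llama 3 tokenizer does not define by default. A sketch of the changes that should enable real batching (the batch_size value is arbitrary; pick one that fits your GPU memory):

# Llama 3 ships without a pad token; reuse EOS so inputs in a batch can be padded.
pipeline.tokenizer.pad_token_id = pipeline.tokenizer.eos_token_id
# Decoder-only models should be left-padded for generation.
pipeline.tokenizer.padding_side = "left"

outputs = pipeline(
    prompt,                    # the list of templated prompt strings from above
    batch_size=4,              # this is what actually enables batched generation
    max_new_tokens=1024,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
# One result per input prompt, in order; strip the echoed prompt to get the reply.
for prompt_text, out in zip(prompt, outputs):
    print(out[0]["generated_text"][len(prompt_text):])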