Spec batch based on Xiangyu's commit
1. Description
Batched speculative decoding ("spec batch") based on Xiangyu's commit. On PVC, only Llama is supported for now.
2. User API changes
Llama:
```python
import torch
from transformers import LlamaTokenizer

# Left padding keeps all prompts right-aligned for batched generation.
tokenizer = LlamaTokenizer.from_pretrained(model_path, padding_side='left')
tokenizer.pad_token_id = 0

with torch.inference_mode():
    input_ids = tokenizer([input1, input2, ...], return_tensors='pt',
                          padding=True).input_ids.to(model.device)
    input_ids_length = input_ids.shape[-1]
    print(f"input_id_len: {input_ids_length}")
    output = model.generate(input_ids,
                            max_new_tokens=args.n_predict,
                            th_stop_draft=args.th_stop_draft,
                            do_sample=False)
    output_str = tokenizer.batch_decode(output, skip_special_tokens=True)
```
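For completeness, a minimal sketch of how `model`, `model_path`, `args`, and the prompts used above might be prepared. The loading flags (notably `speculative=True`) are assumptions based on ipex-llm's speculative-decoding examples, not part of this change; the checkpoint name and prompts are placeholders.

```python
import argparse
import torch
from ipex_llm.transformers import AutoModelForCausalLM  # assumed import path

parser = argparse.ArgumentParser()
parser.add_argument('--n-predict', type=int, default=128)
parser.add_argument('--th-stop-draft', type=float, default=0.8)
args = parser.parse_args()

model_path = "meta-llama/Llama-2-7b-hf"          # placeholder checkpoint
input1, input2 = "Once upon a time", "What is AI?"  # placeholder prompts

# Assumed loading flags, following ipex-llm speculative-decoding samples.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             optimize_model=True,
                                             torch_dtype=torch.float16,
                                             load_in_low_bit='fp16',
                                             speculative=True,
                                             trust_remote_code=True)
model = model.to('xpu')  # PVC is an Intel GPU, exposed as the 'xpu' device
```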
3. Summary of the change
4. How to test?
- [x] Local test of Llama on PVC with transformers 4.31
- [x] Local test of Llama on PVC with transformers 4.36
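As a rough sanity check (a sketch, not the exact test that was run; the prompts and `th_stop_draft=0.8` are placeholder values), batched greedy output can be compared against per-prompt output, once under transformers 4.31 and once under 4.36, reusing the `tokenizer` and `model` from above:

```python
prompts = ["Once upon a time", "The quick brown fox jumps over"]
batch_ids = tokenizer(prompts, return_tensors='pt',
                      padding=True).input_ids.to(model.device)
batch_out = model.generate(batch_ids, max_new_tokens=32,
                           th_stop_draft=0.8, do_sample=False)
batch_texts = tokenizer.batch_decode(batch_out, skip_special_tokens=True)

for prompt, batched in zip(prompts, batch_texts):
    single_ids = tokenizer([prompt], return_tensors='pt').input_ids.to(model.device)
    single_out = model.generate(single_ids, max_new_tokens=32,
                                th_stop_draft=0.8, do_sample=False)
    single_text = tokenizer.batch_decode(single_out, skip_special_tokens=True)[0]
    # Greedy decoding, so batched and per-prompt outputs are expected to match.
    print(batched == single_text, repr(batched))
```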