Liger-Kernel
Qwen2 inference: output is garbled, and ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
### 🐛 Describe the bug
When I load the model with AutoLigerKernelForCausalLM, I get ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
When I load the model with the model-specific patching API (apply_liger_kernel_to_qwen2), the inference output is garbled.
### Reproduce
import os
os.environ['CUDA_VISIBLE_DEVICES']='0,1'
import torch
from transformers import AutoTokenizer,AutoModelForCausalLM
from liger_kernel.transformers import AutoLigerKernelForCausalLM
model_path="Qwen/Qwen2-7B-Instruct"
tokenizer=AutoTokenizer.from_pretrained(model_path)
model=AutoModelForCausalLM.from_pretrained(model_path,trust_remote_code=True,device_map='cuda:0')
from liger_kernel.transformers import apply_liger_kernel_to_qwen2
apply_liger_kernel_to_qwen2()
apply_liger_kernel_to_qwen2(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=False,
    model=model,
)
Liger_model=AutoLigerKernelForCausalLM.from_pretrained(model_path,trust_remote_code=True,device_map='cuda:1')
def generate(model, model_inputs, config):
    # device=next(model.model.parameters()).device
    # model_inputs.to(device)
    # Generate
    with torch.cuda.amp.autocast():
        generated_ids = model.generate(
            model_inputs.input_ids,
            min_new_tokens=config['min_new_tokens'],
            max_new_tokens=config['max_new_tokens'],
            # pad_token_id=tokenizer.eos_token_id
        )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(len(generated_ids[0]))
    print(response)
prompt = "Hey, are you conscious? Can you talk to me?"
model_inputs = tokenizer(prompt, return_tensors="pt").to('cuda:0')
config = {
    'min_new_tokens': 512,
    'max_new_tokens': 512,
}
generate(Liger_model,model_inputs,config)
# error with cpu and gpu
generate(model,model_inputs,config)
# error with
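In the script above, model_inputs is placed on cuda:0 while Liger_model is loaded on cuda:1, and the lines in generate() that would move the inputs are commented out. A minimal sketch of that move is below, on the assumption (not confirmed) that the device mismatch contributes to the Triton error. The helper names model_device and inputs_on_model_device are hypothetical. They rely only on model.parameters() and model_inputs.to(device), and they assume the whole model sits on a single device.

```python
def model_device(model):
    """Return the device of the model's first parameter.

    Assumes the whole model lives on one device, as with
    device_map='cuda:1' in the script above.
    """
    return next(model.parameters()).device


def inputs_on_model_device(model, model_inputs):
    """Return model_inputs moved to the device the model lives on.

    Relies only on model_inputs.to(device), which the tokenizer
    output (a BatchEncoding) provides.
    """
    return model_inputs.to(model_device(model))


# Hypothetical usage with the names from the script above:
#   generate(Liger_model, inputs_on_model_device(Liger_model, model_inputs), config)
#   generate(model, inputs_on_model_device(model, model_inputs), config)
```

If the error persists with matching devices, another thing to try may be running the call inside torch.cuda.device(...) for the model's device, since Triton kernels are launched on the current CUDA device. Whether that applies here is an assumption.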
### Versions
Environment Report:
-------------------
Operating System: Linux-5.15.0-52-generic-x86_64-with-glibc2.35
Python version: 3.10.14
PyTorch version: 2.3.1
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.43.4