tensorrtllm_backend
problem: lora_weights data type
I built my model with bfloat16 as the LoRA data type, but in https://github.com/triton-inference-server/tensorrtllm_backend/blob/v0.16.0/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py the `lora_weights` tensor ends up as float16.
I have to change line 271 in this file from:

```python
if lora_weights is not None:
    kwargs["weights"] = from_numpy(lora_weights).squeeze()
```

to:

```python
if lora_weights is not None:
    kwargs["weights"] = from_numpy(lora_weights).squeeze().to(torch.bfloat16)
```

Can you change this code so that it works for both float16 and bfloat16?
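One dtype-agnostic approach is to cast the weights to whatever dtype the engine was built with, instead of hard-coding either float16 or bfloat16. Below is a minimal sketch; the helper name `convert_lora_weights` and the idea of passing the target dtype in from the model config are my own assumptions, not part of the backend's API:

```python
import numpy as np
import torch


def convert_lora_weights(lora_weights: np.ndarray,
                         target_dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: convert LoRA weights to the engine's dtype.

    numpy has no bfloat16, so the weights arrive as float16/float32;
    casting after from_numpy lets both float16 and bfloat16 engines work.
    """
    t = torch.from_numpy(lora_weights).squeeze()
    if t.dtype != target_dtype:
        t = t.to(target_dtype)
    return t


# Usage: the same call site serves both engine dtypes.
w = np.random.rand(1, 4, 8).astype(np.float16)
print(convert_lora_weights(w, torch.float16).dtype)   # torch.float16
print(convert_lora_weights(w, torch.bfloat16).dtype)  # torch.bfloat16
```

In `model.py` this would mean replacing the hard-coded conversion with something like `kwargs["weights"] = convert_lora_weights(lora_weights, model_dtype)`, where `model_dtype` is read from the model configuration (again, an assumed plumbing detail).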