text-generation-inference
Add option to specify model datatype
Feature request
It would be great if we could specify the model's datatype as a command-line argument.
Motivation
For most PyTorch models, torch_dtype is currently set to torch.float16 or torch.float32, depending on whether CUDA is available (see code). However, some models, such as mosaicml/mpt-30b-instruct, specify a different datatype, such as torch.bfloat16 (see config). If we could load this model with its intended datatype, we should be able to deploy it on a single A100 40GB GPU. Trying to load it with torch.float16, as text-generation-inference currently does, causes us to run out of GPU memory.
Your contribution
I'm happy to write a PR for this if others would find this feature useful.
+1. Could you also add an option to disable conversion to safetensors?
My custom gpt-neox model (trained with bf16) degrades heavily in performance when loaded with fp16. Would also love this option.
Created a PR for it.