parallelformers
INT8 support
Describe a requested feature
I wonder if there's any plan to support 8bit inference in parallelformers. Right now, we can load 🤗 transformers models in 8bit like here, e.g.:
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
However, such a model can't be passed to parallelize(), since parallelformers only supports fp16 mode at the moment.
Expected behavior
If 8-bit inference could be supported, it would be good to add another argument analogous to fp16, e.g.:
from parallelformers import parallelize
model = AutoModelForCausalLM.from_pretrained(model_name)
parallelize(model, num_gpus=2, int8=True, verbose='detail')
# or one argument for precision mode, where dtype can be either "int8", "fp16", or "fp32" (default)
# parallelize(model, num_gpus=2, dtype='int8', verbose='detail')
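The two proposed styles could coexist by resolving the boolean flags and the dtype argument to a single internal precision mode. A minimal sketch of that resolution logic; resolve_precision is a hypothetical helper for illustration, not part of the parallelformers API:

```python
def resolve_precision(fp16=False, int8=False, dtype=None):
    """Map the proposed arguments onto one precision mode.

    Hypothetical helper: an explicit dtype takes priority; otherwise the
    boolean flags are consulted, defaulting to fp32.
    """
    if dtype is not None:
        if dtype not in ("fp32", "fp16", "int8"):
            raise ValueError(f"unsupported dtype: {dtype!r}")
        return dtype
    if fp16 and int8:
        raise ValueError("fp16 and int8 are mutually exclusive")
    if int8:
        return "int8"
    if fp16:
        return "fp16"
    return "fp32"
```

With this, parallelize(model, num_gpus=2, int8=True) and parallelize(model, num_gpus=2, dtype='int8') would both select the same int8 path.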