[`bnb`] Let's make serialization of int8 models possible
What does this PR do?
Before this PR, it was not possible to save an 8-bit model or load one from the Hub. This PR makes that possible: once merged, users can upload 8-bit models to the Hub and/or load 8-bit models from the Hub, saving 2x memory compared to half-precision models.
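# load BLOOM-560m in 8-bit directly from the Hub and run a quick generation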
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to(0)
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0]))
>>> Hello my name is Nate, I am a professional photographer and I am a member of the
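# serialize the int8 weights to disk, then reload them and check the generation matches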
model.save_pretrained("./saved_int8")
model = AutoModelForCausalLM.from_pretrained("./saved_int8", device_map="auto", load_in_8bit=True)
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0]))
>>> Hello my name is Nate, I am a professional photographer and I am a member of the
Depends on https://github.com/TimDettmers/bitsandbytes/pull/159
Let's keep this as a draft until I address the last TODOs and open questions, and until https://github.com/TimDettmers/bitsandbytes/pull/159 gets merged.
TODOs and open questions:
- ability to push `BitsAndBytesConfig`
- Do we want to save the serialized model under the name `pytorch_model.bin`? I would say yes for simplicity reasons, but we need to make sure that a user calls `from_pretrained` with `load_in_8bit`, hence add a warning if there is a `quantization_config.json` on the Hub repo and the user is not passing `load_in_8bit=True` (see the sketch after this list).
- Force `load_in_8bit=True` if there is a `quantization_config.json` on the Hub repo?
- Update docs
- Update warnings
- Safety checkers for `bnb` versions
- Add a test to check if it works using sharded fp16 weights
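A minimal sketch of the warning / forcing behaviour from the two items above; the `quantization_config` attribute and the helper name are illustrative assumptions, not the final implementation:

import warnings

def resolve_load_in_8bit(config, load_in_8bit):
    """Hypothetical helper: warn about or force 8-bit loading when the
    checkpoint carries a serialized quantization config."""
    # `quantization_config` is an assumed attribute holding the saved
    # BitsAndBytesConfig (e.g. read from quantization_config.json)
    has_quant_config = getattr(config, "quantization_config", None) is not None
    if has_quant_config and not load_in_8bit:
        warnings.warn(
            "This checkpoint appears to be serialized in int8 but `load_in_8bit=True` "
            "was not passed; forcing 8-bit loading."
        )
        load_in_8bit = True
    return load_in_8bit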
cc @sgugger I left a few open questions, would love to hear your thoughts on these!
The design is not easy enough to use: if a user saves a quantized model and pushes it to the Hub, it should work directly with from_pretrained. This is why I insisted that the quantization config be saved inside the model config. That way you don't need the user to pass load_in_8bit=True, since you can read it from the config.
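As an illustration of that suggestion, assuming the serialized quantization settings end up as a `quantization_config` entry inside the model's config.json (the field names here are assumptions, not the final format):

# illustrative config.json content for a checkpoint saved in 8-bit;
# the exact field names of the embedded quantization config are assumptions
config_json = {
    "architectures": ["BloomForCausalLM"],
    "model_type": "bloom",
    "quantization_config": {
        "load_in_8bit": True,
        "llm_int8_threshold": 6.0,
    },
}

# with the quantization config embedded, reloading would no longer require the
# user to pass `load_in_8bit=True` explicitly (proposed behaviour):
# model = AutoModelForCausalLM.from_pretrained("./saved_int8", device_map="auto")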
Awesome, ok, I'll work on that: if there is a quantization config on the repo, we should force device_map="auto" and load_in_8bit=True in that case.
The PR is ready for review @sgugger! Of course, this PR is not mergeable before the bnb release.
Thanks for the heads up! :D It should be much better now! For me the PR is now ready for review.