[`bnb`] Let's make serialization of int8 models possible

Open younesbelkada opened this issue 2 years ago • 3 comments

What does this PR do?

Before this PR, it was not possible to save an 8-bit model or to load one from the Hub. This PR adds both. Once it is merged, users can upload 8-bit models to the Hub and load them back, using about half the weight memory of half-precision models.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the fp16 checkpoint and quantize it to int8 on the fly
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to(0)  # move inputs to GPU 0
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0]))
# Output: Hello my name is Nate, I am a professional photographer and I am a member of the

# Save the int8 model, then reload it from the local directory
model.save_pretrained("./saved_int8")

model = AutoModelForCausalLM.from_pretrained("./saved_int8", device_map="auto", load_in_8bit=True)

outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0]))
# Output: Hello my name is Nate, I am a professional photographer and I am a member of the
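As a rough check on the memory claim in the description, here is a back-of-the-envelope estimate of weight storage. It is only a sketch: the parameter count is assumed from the "560m" in the model name, and it pretends every weight is stored in int8, which overstates the saving if some layers stay in half precision.

```python
# Rough weight-memory estimate. Ignores activations, buffers, and any
# layers that remain in fp16 after quantization.
NUM_PARAMS = 560_000_000  # assumption: taken from the "560m" in the model name


def weight_bytes(num_params: int, bytes_per_param: int) -> int:
    """Return the storage needed for the weights alone."""
    return num_params * bytes_per_param


fp16_bytes = weight_bytes(NUM_PARAMS, 2)  # half precision: 2 bytes per weight
int8_bytes = weight_bytes(NUM_PARAMS, 1)  # int8: 1 byte per weight

print(f"fp16: {fp16_bytes / 1e9:.2f} GB")
print(f"int8: {int8_bytes / 1e9:.2f} GB")
print(f"ratio: {fp16_bytes / int8_bytes:.1f}x")
```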

Depends on https://github.com/TimDettmers/bitsandbytes/pull/159

Let's keep this as a draft until I address the last TODOs and open questions, and until https://github.com/TimDettmers/bitsandbytes/pull/159 gets merged.

TODOs and open questions:

  • ability to push BitsAndBytesConfig
  • Do we want to save the serialized model under the name pytorch_model.bin? I would say yes for simplicity, but we need to make sure that the user calls from_pretrained with load_in_8bit, so we should add a warning if there is a quantization_config.json on the Hub repo and the user is not passing load_in_8bit=True.
  • Force load_in_8bit=True if there is a quantization_config.json on the Hub repo?
  • Update docs
  • Update warnings
  • Safety checkers for bnb versions
  • Add a test to check if it works using sharded fp16 weights
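The warning proposed in the TODO list could look roughly like the sketch below. `check_8bit_load` is a hypothetical helper written for illustration, not a transformers function; it only assumes that the list of file names in the repo is already known.

```python
import warnings

# File name taken from the TODO list above.
QUANT_CONFIG_FILE = "quantization_config.json"


def check_8bit_load(repo_files, load_in_8bit):
    """Hypothetical helper: report whether the repo looks quantized.

    Emits a warning when a quantization config is present but the caller
    did not ask for 8-bit loading.
    """
    has_quant_config = QUANT_CONFIG_FILE in repo_files
    if has_quant_config and not load_in_8bit:
        warnings.warn(
            f"Found {QUANT_CONFIG_FILE} in the repo, but load_in_8bit=True "
            "was not passed; the saved weights are probably int8.",
            UserWarning,
        )
    return has_quant_config


# No warning: the caller already requested 8-bit loading.
check_8bit_load(["config.json", QUANT_CONFIG_FILE], load_in_8bit=True)
# Warning: quantized repo, but the flag is missing.
check_8bit_load(["config.json", QUANT_CONFIG_FILE], load_in_8bit=False)
```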

cc @sgugger I left a few open questions, would love to hear your thoughts on these!

younesbelkada avatar Mar 15 '23 08:03 younesbelkada

The documentation is not available anymore as the PR was closed or merged.

The design is not easy enough to use. If a user saves a quantized model and pushes to the Hub, it should work directly with from_pretrained. This is why I insisted that the quantization config should be saved inside the model config. This way you won't need to have the user pass load_in_8bit=True, as you can read it from the config.
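To illustrate the idea of keeping the quantization settings inside the model config, here is a minimal sketch using plain dictionaries. The `quantization_config` key name and its layout are assumptions made for this example, and `load_in_8bit_from_config` is a hypothetical helper, not the library's API.

```python
import json


def load_in_8bit_from_config(config: dict) -> bool:
    """Hypothetical helper: infer the 8-bit flag from a model config dict."""
    quant_cfg = config.get("quantization_config") or {}
    return bool(quant_cfg.get("load_in_8bit", False))


# Illustrative config.json contents; the key names are assumptions.
saved = json.dumps(
    {"model_type": "bloom", "quantization_config": {"load_in_8bit": True}}
)

quantized_config = json.loads(saved)
plain_config = {"model_type": "bloom"}

print(load_in_8bit_from_config(quantized_config))  # flag read from the config
print(load_in_8bit_from_config(plain_config))      # no quantization section
```

With this layout the user never passes the flag explicitly: the loader reads it from the same file it already parses.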

sgugger avatar Mar 15 '23 12:03 sgugger

Awesome, OK, I'll work on that. So if there is a quantization config on the repo, we should force device_map="auto" and load_in_8bit=True in that case.
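A sketch of what forcing those arguments might look like, assuming the loader has the parsed config as a dict. `resolve_load_kwargs` and the `quantization_config` key are hypothetical names chosen for illustration.

```python
def resolve_load_kwargs(config: dict, user_kwargs: dict) -> dict:
    """Hypothetical helper: override loading kwargs for quantized checkpoints."""
    kwargs = dict(user_kwargs)  # copy so the caller's dict is left untouched
    if config.get("quantization_config"):
        # A quantized checkpoint: force 8-bit loading and automatic placement.
        kwargs["load_in_8bit"] = True
        kwargs["device_map"] = "auto"
    return kwargs


quantized = {"quantization_config": {"load_in_8bit": True}}
forced = resolve_load_kwargs(quantized, {"device_map": None})
untouched = resolve_load_kwargs({}, {"device_map": None})
print(forced)
print(untouched)
```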

younesbelkada avatar Mar 15 '23 13:03 younesbelkada

The PR is ready for review @sgugger! It is not mergeable before the bnb release, of course.

younesbelkada avatar Mar 29 '23 12:03 younesbelkada

Thanks for the heads up! :D It should be much better now! For me the PR is ready for a review now

younesbelkada avatar Mar 31 '23 09:03 younesbelkada