`bitsandbytes` - `Linear8bitLt` integration into `transformers` models

Open younesbelkada opened this issue 2 years ago • 19 comments

What does this PR do?

Adding mixed 8-bit quantization for large language models! 🚀 This feature could reduce the memory size of large models by up to 2x, without a significant loss in precision. Paper and main implementation from: @TimDettmers

Usage:

Anyone with a GPU that supports mixed 8-bit quantization can load a model using AutoModel.from_pretrained(xxx, load_in_8bit=True, device_map="auto"), and it works like a charm. It could work on any HF model!

Requirements

Needs the latest version of bitsandbytes (compiled manually) and of accelerate.

TODOs:

  • [x] Add custom tests
  • [x] Discuss potential improvements
  • [x] Verify that the weights are still in 8bit after the loading (once there are more advances on Tim's side)
  • [x] Add documentation (Younes first and then Tim)
  • [x] Add a demo / few lines to explain how to use it
  • [ ] Add flag that loads directly to 8bit @TimDettmers

Resources:

  • WIP branch of bitsandbytes: https://github.com/TimDettmers/bitsandbytes/tree/cublaslt

Many thanks to @justheuristic and @TimDettmers !! 🎉

younesbelkada avatar Jun 27 '22 18:06 younesbelkada

The documentation is not available anymore as the PR was closed or merged.

cc-ing also @michaelbenayoun in case you want to have a look as well ;)

younesbelkada avatar Jun 28 '22 09:06 younesbelkada

Nice, thanks for working on it @younesbelkada! I am also quite interested in the feature. I'd particularly like to see a bit of documentation so that we can better understand how it works under the hood and how to make the best use of the feature.

Thanks!

LysandreJik avatar Jun 28 '22 12:06 LysandreJik

Hi all! Just to summarise what is happening and the solution we came up with to implement this. In the previous version, we found 2 major bugs:
1- the function set_module_tensor_to_device seems to overwrite the Int8Params modules with nn.Parameter modules.
2- init_empty_weights also seems to replace the Int8Params modules with nn.Parameter modules.

I see two solutions to this:
1- Open a PR in accelerate to support correctly overwriting into the Int8Params class, as follows: https://github.com/huggingface/accelerate/compare/main...TimDettmers:accelerate:integration_8bit - only 2 functions are modified, and this should not break backward compatibility, but I am not sure.
2- Manually redefine the functions set_module_tensor_to_device and init_empty_weights as two new functions, set_module_8bit_tensor_to_device and init_8bit_empty_weights, as proposed in this PR.

I personally found option 1 cleaner, but option 2 might be safer for accelerate. Let us know what you think! cc @LysandreJik @sgugger @TimDettmers

younesbelkada avatar Jul 12 '22 09:07 younesbelkada

Thank you very much for your comments! has_fp16_weights comes from the class bnb.Int8Params, which is currently being developed in a WIP branch that should be merged soon into the main branch of bitsandbytes. The logic behind it is that if a module contains this attribute, then it has to be a bnb.Int8Params module. I will refactor the code with your proposed changes and ask for a second round of review 🚀
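To make the Int8Params issue concrete, here is a minimal pure-Python analogy (no torch or bitsandbytes involved). Param and Int8Param are hypothetical stand-ins for nn.Parameter and bnb.Int8Params, and row_scales is a made-up field standing in for the quantization state. It is only a sketch of the idea, not the accelerate or transformers code: rebuilding every tensor with the base class silently drops the subclass, while checking for has_fp16_weights and rebuilding with the original class keeps it.

```python
class Param:
    """Hypothetical stand-in for nn.Parameter: data plus a device tag."""

    def __init__(self, data, device="cpu"):
        self.data = list(data)
        self.device = device


class Int8Param(Param):
    """Hypothetical stand-in for bnb.Int8Params: carries extra quantization state."""

    def __init__(self, data, device="cpu", has_fp16_weights=False, row_scales=None):
        super().__init__(data, device)
        self.has_fp16_weights = has_fp16_weights
        self.row_scales = row_scales  # made-up name for the quantization statistics


def move_naive(param, device):
    # Generic path: always rebuilds with the base class, so subclass state is lost.
    return Param(param.data, device)


def move_preserving(param, device):
    # Duck-typing check discussed above: only Int8Params-like objects expose
    # has_fp16_weights, so rebuild those with their own class and state.
    if hasattr(param, "has_fp16_weights"):
        return type(param)(
            param.data,
            device,
            has_fp16_weights=param.has_fp16_weights,
            row_scales=param.row_scales,
        )
    return Param(param.data, device)


w = Int8Param([12, -127, 64], has_fp16_weights=False, row_scales=[0.02])
print(type(move_naive(w, "cuda:0")).__name__)
print(type(move_preserving(w, "cuda:0")).__name__)
```

In this analogy the first call prints Param and the second prints Int8Param, which is the difference between the generic helper and an 8-bit-aware one.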

younesbelkada avatar Jul 12 '22 13:07 younesbelkada

I think before merging we need:

  • [x] Memory footprint benchmarking
  • [x] Inference speed benchmarking
  • [x] lm-eval benchmarking for large models (it has been done for small models)
  • [x] Merging the WIP branch of bitsandbytes into main

younesbelkada avatar Jul 13 '22 16:07 younesbelkada

Added another PR to support int8 quantization + accelerate on multi-GPU setup here: https://github.com/huggingface/accelerate/pull/539 !

younesbelkada avatar Jul 20 '22 11:07 younesbelkada

Thanks @sgugger for your review! Fixed the suggestions ;) I think we are good to go to merge https://github.com/huggingface/accelerate/pull/539 if you don't mind 🙏 I just need to wait for the release of bitsandbytes to become more stable (I am facing some issues when installing the library, but they should be fixed very soon; I am syncing with @TimDettmers). Once this is fixed, I think we should be good to go for merging 🚀

younesbelkada avatar Jul 27 '22 09:07 younesbelkada

Merged the PR in Accelerate! Don't forget to add some documentation and also setup some tests for this so it doesn't get broken by future PRs :-)

sgugger avatar Jul 27 '22 11:07 sgugger

TODOs:

  • [x] Have a working colab demo for inference
  • [x] Add more documentation
  • [x] Implement tests

younesbelkada avatar Jul 28 '22 12:07 younesbelkada

Before moving forward, I would like to have a comment from @michaelbenayoun @mfuntowicz and @echarlaix

About this PR

We replace all the nn.Linear modules with the bnb.Linear8bitLt modules from the recent release of bitsandbytes, which proposes a new post-training quantization technique with zero performance degradation on large-scale models (>1B parameters). With that, we have managed to fit BLOOM-176B on 3xA100 80GB instead of 6xA100 80GB, with no performance degradation.
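As a rough sanity check on the GPU counts, here is a back-of-the-envelope sketch that counts weight memory only (2 bytes per parameter in fp16, 1 byte in int8). It deliberately ignores activations, the fp16 outlier path, quantization statistics and framework overhead, so it gives a lower bound rather than the real footprint; the 176e9 parameter count and 80GB per GPU come from the model name and GPU size above.

```python
import math

# Bytes needed to store one weight in each format (weights only, no overhead).
BYTES_PER_PARAM = {"fp16": 2, "int8": 1}


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Approximate weight-only memory in GB (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


def min_gpus(n_params: float, dtype: str, gpu_gb: float = 80.0) -> int:
    """Lower bound on the number of GPUs needed just to hold the weights."""
    return math.ceil(weight_memory_gb(n_params, dtype) / gpu_gb)


n = 176e9  # BLOOM-176B
for dtype in ("fp16", "int8"):
    print(dtype, weight_memory_gb(n, dtype), "GB ->", min_gpus(n, dtype), "x A100 80GB (lower bound)")
```

The weight-only estimate gives 3 GPUs for int8, in line with the number above; the fp16 estimate is only a lower bound, since the extra headroom needed in practice is not modeled here.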

About the mixed quantization method in a few words

In this technique, the operations on the outliers are done in fp16 and the rest of the operations are done in int8, to achieve zero performance degradation on large-scale models.
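A minimal pure-Python sketch of this decomposition, assuming the scheme from the LLM.int8() paper: input feature columns whose magnitude reaches a threshold (6.0 here, an assumed default) are multiplied in floating point, and everything else is quantized to int8 with absmax scaling (one scale per row of X, one per column of W), multiplied with integer accumulation, then rescaled. Python floats stand in for fp16, and the function names are hypothetical; this illustrates the arithmetic, not the bitsandbytes kernels.

```python
def absmax_quantize(values):
    """Scale a vector into the int8 range [-127, 127]; return (ints, scale)."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = 127.0 / amax if amax > 0 else 1.0
    return [round(v * scale) for v in values], scale


def mixed_int8_matmul(X, W, threshold=6.0):
    """Approximate Y = X @ W, keeping outlier feature columns of X in float."""
    n, k, m = len(X), len(W), len(W[0])
    # Feature dimensions where any input value reaches the threshold are outliers.
    outliers = [j for j in range(k) if any(abs(X[i][j]) >= threshold for i in range(n))]
    regular = [j for j in range(k) if j not in outliers]

    # Vector-wise quantization: one scale per row of X and one per column of W.
    xq = [absmax_quantize([X[i][j] for j in regular]) for i in range(n)]
    wq = [absmax_quantize([W[j][c] for j in regular]) for c in range(m)]

    Y = [[0.0] * m for _ in range(n)]
    for i in range(n):
        xi, sx = xq[i]
        for c in range(m):
            wc, sw = wq[c]
            acc = sum(a * b for a, b in zip(xi, wc))  # integer multiply-accumulate
            y = acc / (sx * sw)  # dequantize back to float
            y += sum(X[i][j] * W[j][c] for j in outliers)  # float path for outliers
            Y[i][c] = y
    return Y


X = [[0.5, 8.0, -1.0],
     [1.5, -7.0, 0.25]]  # column 1 holds the outlier feature
W = [[0.1, -0.2],
     [0.3, 0.05],
     [-0.4, 0.6]]
print(mixed_int8_matmul(X, W))
```

Keeping the large-magnitude column out of the int8 path means it does not inflate the absmax scales of the other values in its row, which is the intuition behind doing the outliers in higher precision.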

Usage

This does not run on CPU: you will need a GPU that supports 8-bit tensor core operations (e.g. T4 or A100) to make it run. Here is a tutorial on Google Colab on how to run the mixed-int8 model: https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4#scrollTo=YJlldexxwnhM
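A small pre-flight check could look like the sketch below. It assumes the requirement translates to NVIDIA compute capability 7.5 or newer (Turing, e.g. T4; Ampere, e.g. A100), since int8 tensor cores first appeared with Turing. The threshold, the table and the helper name are assumptions for illustration, not something bitsandbytes exposes.

```python
# (major, minor) compute capabilities of GPUs that come up in this thread, plus V100 for contrast.
KNOWN_CAPABILITIES = {
    "Tesla T4": (7, 5),
    "A100": (8, 0),
    "RTX A6000": (8, 6),
    "Tesla V100": (7, 0),
}

# Assumption: int8 tensor cores first shipped with Turing (sm_75).
MIN_INT8_CAPABILITY = (7, 5)


def supports_int8_matmul(capability) -> bool:
    """Return True if a (major, minor) capability meets the assumed minimum."""
    # Tuples compare lexicographically, so (8, 0) > (7, 5) > (7, 0).
    return tuple(capability) >= MIN_INT8_CAPABILITY


for name, cap in KNOWN_CAPABILITIES.items():
    status = "ok" if supports_int8_matmul(cap) else "not supported"
    print(f"{name} (sm_{cap[0]}{cap[1]}): {status}")
```

On a machine with PyTorch and a CUDA device, supports_int8_matmul(torch.cuda.get_device_capability()) would apply the same check to the current GPU.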

younesbelkada avatar Aug 01 '22 10:08 younesbelkada

I can confirm that the slow tests I have designed are passing on my testing machine (2x Tesla T4 15GB). However, for now it is not possible to load saved int8 checkpoints, because you need to load the quantization statistics, which are not saved by model.state_dict() in bitsandbytes. For now, I propose to just raise an error if int8 weights are loaded and tell users that the feature is not supported (as proposed in 1326a42795033410dae6c5a8a07b81f12ee7a41c). No strong opinions, but I personally advocate keeping this feature inside transformers since the method also relies on accelerate plus an additional library (bitsandbytes); that said, I am not the most knowledgeable person regarding the optimum integration, which might be a bit different from the transformers one. cc @sgugger @mfuntowicz @TimDettmers 🙏
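The proposed guard can be sketched as follows, assuming the check only needs the dtype of each tensor in the checkpoint. The function name, the dict-of-dtype-names input and the error text are hypothetical and are not the code in the referenced commit.

```python
def check_no_int8_weights(state_dict_dtypes: dict) -> None:
    """Raise if a checkpoint already stores int8 weights.

    state_dict_dtypes maps parameter names to dtype names,
    e.g. {"transformer.h.0.mlp.c_fc.weight": "float16"}.
    """
    int8_keys = sorted(k for k, dtype in state_dict_dtypes.items() if dtype == "int8")
    if int8_keys:
        # The int8 values alone cannot be dequantized: the per-row scales
        # (quantization statistics) are not part of model.state_dict().
        raise ValueError(
            "Loading int8 checkpoints is not supported yet because the quantization "
            f"statistics are missing ({len(int8_keys)} int8 tensors, e.g. {int8_keys[0]})."
        )


check_no_int8_weights({"lm_head.weight": "float16"})  # passes silently
try:
    check_no_int8_weights({"h.0.attn.weight": "int8"})
except ValueError as err:
    print(err)
```

Failing early with an explicit message is friendlier than letting the load succeed and produce garbage outputs from weights that can no longer be rescaled.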

younesbelkada avatar Aug 01 '22 15:08 younesbelkada

Thank you for all the work on this PR @younesbelkada, @sgugger, @michaelbenayoun!

Regarding the transformers vs optimum question: from my understanding of the libraries, if people want to deploy models or run them with high efficiency, optimum seems to be the right tool, whereas general-purpose "inefficient" access to models is more suitable for transformers.

As such, I think it's best to keep this feature in transformers. I think it fits better there, since it is not meant for fast inference but for memory-efficient inference in as many use cases as possible.

TimDettmers avatar Aug 02 '22 04:08 TimDettmers

Forgive me for jumping the gun -

On Colab (T4, 12GB RAM) I tried:

!nvidia-smi

| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|  No running processes found                                                 |

Then

!pip install https://github.com/younesbelkada/transformers/archive/refs/heads/integration-8bit.zip accelerate 
!pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda112

Loading model with

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-mono", load_in_8bit=True, device_map="auto")

And got this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-8-40073518cc86> in <module>()
----> 1 model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-mono", load_in_8bit=True, device_map="auto")

7 frames
/usr/local/lib/python3.7/dist-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    444         elif type(config) in cls._model_mapping.keys():
    445             model_class = _get_model_class(config, cls._model_mapping)
--> 446             return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
    447         raise ValueError(
    448             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"

/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   2282         # Dispatch model with hooks on all devices if necessary
   2283         if device_map is not None:
-> 2284             dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
   2285 
   2286         if output_loading_info:

/usr/local/lib/python3.7/dist-packages/accelerate/big_modeling.py in dispatch_model(model, device_map, main_device, state_dict, offload_dir, offload_buffers, preload_module_classes)
    246         offload_buffers=offload_buffers,
    247         weights_map=weights_map,
--> 248         preload_module_classes=preload_module_classes,
    249     )
    250     model.hf_device_map = device_map

/usr/local/lib/python3.7/dist-packages/accelerate/hooks.py in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, preload_module_classes)
    446             place_submodules=True,
    447         )
--> 448         add_hook_to_module(module, hook)
    449         attach_execution_device_hook(module, execution_device[module_name])
    450     elif module_name in execution_device:

/usr/local/lib/python3.7/dist-packages/accelerate/hooks.py in add_hook_to_module(module, hook)
    136         module._old_forward = old_forward
    137 
--> 138     module = hook.init_hook(module)
    139     module._hf_hook = hook
    140 

/usr/local/lib/python3.7/dist-packages/accelerate/hooks.py in init_hook(self, module)
    219         if not self.offload and self.execution_device is not None:
    220             for name, _ in named_module_tensors(module, recurse=self.place_submodules):
--> 221                 set_module_tensor_to_device(module, name, self.execution_device)
    222         elif self.offload:
    223             self.original_devices = {

/usr/local/lib/python3.7/dist-packages/accelerate/utils/modeling.py in set_module_tensor_to_device(module, tensor_name, device, value)
    128         module._buffers[tensor_name] = new_value
    129     else:
--> 130         new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
    131         module._parameters[tensor_name] = new_value
    132 

/usr/local/lib/python3.7/dist-packages/torch/nn/parameter.py in __new__(cls, data, requires_grad)
     40         t = data.detach().requires_grad_(requires_grad)
     41         if type(t) is not type(data):
---> 42             raise RuntimeError(f"Creating a Parameter from an instance of type {type(data).__name__} "
     43                                "requires that detach() returns an instance of the same type, but return "
     44                                f"type {type(t).__name__} was found instead. To use the type as a "

RuntimeError: Creating a Parameter from an instance of type Int8Params requires that detach() returns an instance of the same type, but return type Tensor was found instead. To use the type as a Parameter, please correct the detach() semantics defined by its __torch_dispatch__() implementation.

Interestingly, on AWS SageMaker (T4, 16GB RAM):

!pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda114

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-6B-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-6B-mono", load_in_8bit=True, device_map="auto")

got me

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/tmp/ipykernel_141/3855166932.py in <cell line: 1>()
----> 1 model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-6B-mono", load_in_8bit=True, device_map="auto")

~/.conda/envs/default/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    444         elif type(config) in cls._model_mapping.keys():
    445             model_class = _get_model_class(config, cls._model_mapping)
--> 446             return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
    447         raise ValueError(
    448             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"

~/.conda/envs/default/lib/python3.9/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   2177             init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
   2178         elif load_in_8bit:
-> 2179             init_contexts = [init_empty_weights()]  # Force enable init empty weights
   2180             logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
   2181         elif low_cpu_mem_usage:

NameError: name 'init_empty_weights' is not defined

I suppose the second case may have something to do with the environment setup, but what would trigger the first issue?

Thanks,

cnbeining avatar Aug 04 '22 06:08 cnbeining

Hi @cnbeining! Thanks for your interest in this feature, and happy to see that you are already excited to run it on Codegen! 🚀 Your first problem is related to the version of accelerate you are installing. Make sure you install the latest version from source using a command like:

pip install git+https://github.com/huggingface/accelerate.git@24c28a1adc284db0126b7c17ebef275597ddc6b7

Here 24c28a1adc284db0126b7c17ebef275597ddc6b7 is the latest commit hash of accelerate. At the time of writing, the most recent release (i.e. the accelerate library you get from pip install accelerate) is not compatible with this PR, so you will need the latest version from source.

However, when using load_in_8bit, torch_dtype=torch.float16 is set internally. It happens that there might be a small bug in Codegen when using torch_dtype=torch.float16, which we propose to fix in https://github.com/huggingface/transformers/pull/18467. If you are interested in reproducing the issue, you can run this small snippet:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-mono", device_map="auto", torch_dtype=torch.float16)

text = "def quicksort(l):"

encoded_input = tokenizer(text, return_tensors='pt')
output_sequences = model.generate(input_ids=encoded_input['input_ids'], attention_mask=encoded_input['attention_mask'])
print(tokenizer.decode(output_sequences[0], skip_special_tokens=True))

Since this might take time to be merged, and as I saw that you wanted to run on Google Colab, I made a special branch that you can build from Colab and that should work (I tested it) here: Open In Colab. Just run those cells and everything should work.

If you follow the same installation instructions as in this Colab, I think everything should work smoothly in SageMaker as well, but we never know! Let us know if this helps; happy to help you again if necessary 💪 Also, if you face any other issues, I think it would be better to move this discussion into an issue! 🐛

Thanks Younes

younesbelkada avatar Aug 04 '22 08:08 younesbelkada

@LysandreJik I have a question regarding the slow tests for this feature! I would prefer to build another Docker image for these tests and run them separately, because the import of bitsandbytes sometimes fails on some specific configurations. We found an issue that will be fixed in bitsandbytes asap, but I think having a separate image and running these tests independently is safer so as not to affect other tests: since bitsandbytes is always imported whenever it is available, if the Docker image installs it and the import fails, all tests will fail at import time. I can also try to come up with a solution where we import this library only if load_in_8bit is triggered. What do you think is best in this case?
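The second option (importing only when load_in_8bit is triggered) could be sketched with the standard library alone: probe for the package with importlib.util.find_spec, which locates a top-level package without executing it, and call importlib.import_module only on the 8-bit path. The helper names below are hypothetical, not existing transformers utilities.

```python
import importlib
import importlib.util


def is_package_available(name: str) -> bool:
    # find_spec locates a top-level package without running its __init__,
    # so a broken CUDA setup inside the package cannot fail this check.
    return importlib.util.find_spec(name) is not None


def load_8bit_backend(load_in_8bit: bool, name: str = "bitsandbytes"):
    """Import the quantization backend only when 8-bit loading is requested."""
    if not load_in_8bit:
        return None  # regular loading never touches the package
    if not is_package_available(name):
        raise ImportError(f"load_in_8bit=True requires the `{name}` package to be installed.")
    # Any import-time failure inside the package now only surfaces on this path.
    return importlib.import_module(name)
```

With this pattern, an environment where the backend is installed but broken only fails the tests that actually request 8-bit loading.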

younesbelkada avatar Aug 05 '22 10:08 younesbelkada

Slow tests are now all passing on our Docker image with the latest fix of bitsandbytes. I would love to have a potential final round of review! cc @sgugger @LysandreJik

younesbelkada avatar Aug 08 '22 08:08 younesbelkada

Thanks for the review! Going to do a last sanity check: testing with Docker to see if the slow tests pass on our Docker image, and I will merge once it's green! 🟢

younesbelkada avatar Aug 09 '22 15:08 younesbelkada

GJ!

Non-blocking comment: how about incorporating (optionally) bnb.nn.StableEmbedding as recommended by the authors, or is the added benefit limited?

cnbeining avatar Aug 09 '22 18:08 cnbeining

Thanks @cnbeining! I think this can be done in a separate PR, since we probably need to release the beta version of this feature ASAP! Also, I am not sure how the current implementation will handle tied weights if we replace Embedding layers with StableEmbedding, so this needs further tests/investigation.

younesbelkada avatar Aug 10 '22 05:08 younesbelkada

Yeah let's get this rolled out to unleash GPT-J-6B and CodeGen to ordinary folks :-) I will continue with my testing with StableEmbedding and will report results as they come by.

Again thanks so much for all the effort!

cnbeining avatar Aug 10 '22 06:08 cnbeining

Great, that would be awesome! I would definitely be interested in seeing the results and a comparison against the current approach (i.e. without StableEmbedding). Let's maybe keep the results in this thread even after merging the PR.

younesbelkada avatar Aug 10 '22 06:08 younesbelkada

Ultimate checks are passing: https://github.com/huggingface/transformers/actions/runs/2830688091 Merging!

younesbelkada avatar Aug 10 '22 07:08 younesbelkada

Looks nice, will try it out :)

fxmarty avatar Aug 10 '22 09:08 fxmarty

You can check the Google Colab: https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4#scrollTo=W8tQtyjp75O_ to see how to run it! We will publish that today with the beta release on Twitter

younesbelkada avatar Aug 10 '22 09:08 younesbelkada

Great work! Can big models, such as the models used in the example Colab, be fine-tuned just by loading them in int8? Are you thinking about releasing a Colab for fine-tuning a model, not just for inference? Thanks in advance

mrm8488 avatar Aug 10 '22 12:08 mrm8488

Thanks for the remark @mrm8488! Indeed, it would be very nice to have a fine-tuning demo on Colab. After discussing with @TimDettmers, it appears that the current implementation would support classic torch optimizers. Also, I think @justheuristic has some experience with fine-tuning int8 models using Linear8bitLt modules for prompt tuning ;) so I will let him answer on the feasibility of that! 🚀

younesbelkada avatar Aug 10 '22 12:08 younesbelkada

tl;dr soon :)

Right now you can fine-tune with bitsandbytes, but it's gonna take up the same memory as 16bit - but we hope this can soon be fixed.

@timdettmers and @dbaranchuk are working on memory-efficient fine-tuning on the bitsandbytes side. After they're done, you'll be able to write trainable adapters, soft-prompts and other parameter-efficient methods.

There is also a group in BigScience that works on 8-bit finetuning for very large models (think bloom-176B) in colab, but we are still polishing the code. I'll tag you once it becomes public.

justheuristic avatar Aug 10 '22 12:08 justheuristic

@younesbelkada, thank you for integrating this awesome feature! May I suggest that we also expose it in the docs where users are likely to search for such features, not only in the API docs; otherwise all these cool features will remain hidden.

I propose to add a new section at https://huggingface.co/docs/transformers/main/perf_train_gpu_one so that those searching for performance improvement will find it. Thank you!

stas00 avatar Aug 11 '22 19:08 stas00

Thanks for the comment! Sounds really good to me 💪 I was planning to open a PR at the beginning of next week to add the link to the blog post + paper; I will most likely use this PR to propose your refactoring as well.

younesbelkada avatar Aug 11 '22 22:08 younesbelkada

Hi @younesbelkada, I want to try bitsandbytes with BLOOM on an 8x A6000 server (CUDA version 11.3). Unfortunately, the following error is thrown:

RuntimeError: Creating a Parameter from an instance of type Int8Params requires that detach() returns an instance of the same type, but return type Tensor. was found instead. To use the type as a Parameter, please correct the detach() semantics defined by __torch_dispatch__() implementation.

I am using the code downloaded from the Colab, with model.generate() for inference instead of the HuggingFace pipeline. Do you know how to solve the issue? I installed bitsandbytes==0.31.8 from https://pypi.org/project/bitsandbytes/, the latest transformers package from the master branch, and the latest Accelerate from pip.

pai4451 avatar Aug 13 '22 03:08 pai4451

@pai4451 the code has not been released to PyPI yet - you probably want to use pip install git+https://github.com/huggingface/transformers.git to get the HEAD that includes this PR.

cnbeining avatar Aug 13 '22 04:08 cnbeining

@cnbeining I didn’t install transformers from PyPI, instead I installed from this repo.

pai4451 avatar Aug 13 '22 04:08 pai4451

Hi @pai4451! Thanks a lot for your message! This error is related to accelerate; I ran the Colab demo this morning and everything seems to work fine. I think you most likely didn't install the correct version of accelerate, as happened to @cnbeining before. Could you please share the output of pip list with us? If you see that the accelerate version is below 0.11.x, then you should re-install it with pip install --force accelerate. Let us know if this works!
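Instead of scanning pip list, the same check can be run from Python. This is a sketch using the standard library's importlib.metadata (Python 3.8+); the default minimum of (0, 12) is an assumption taken from the version reported to work later in this thread, not an official requirement, and the parsing is deliberately naive (it only reads the leading numeric components).

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v: str, parts: int = 2) -> tuple:
    # "0.12.0.dev0" -> (0, 12); non-numeric suffixes are ignored.
    return tuple(int(x) for x in re.findall(r"\d+", v)[:parts])


def is_recent_enough(installed: str, minimum=(0, 12)) -> bool:
    # Compare numerically: a plain string comparison would rank "0.9" above "0.12".
    return version_tuple(installed, len(minimum)) >= tuple(minimum)


def check_accelerate(minimum=(0, 12)) -> str:
    try:
        installed = version("accelerate")
    except PackageNotFoundError:
        return "accelerate is not installed"
    if is_recent_enough(installed, minimum):
        return f"accelerate {installed} looks recent enough"
    return f"accelerate {installed} is older than {'.'.join(map(str, minimum))}"


print(check_accelerate())
```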

younesbelkada avatar Aug 13 '22 06:08 younesbelkada

@younesbelkada The error was indeed just related to accelerate. After upgrading accelerate to 0.12, the issue is solved. Thanks for providing such a wonderful feature.

pai4451 avatar Aug 14 '22 14:08 pai4451

@pai4451 No problem at all! I am very happy that you made it run! Let us know if you run into any other issues.

younesbelkada avatar Aug 14 '22 17:08 younesbelkada

Hi @younesbelkada, thank you again for the bitsandbytes integration with transformers models. I wonder if it is possible to use DeepSpeed for int8-quantized inference with the BLOOM model in a similar way, just as bitsandbytes does for transformers. Or is there any chance of loading the model using bitsandbytes with DeepSpeed? DeepSpeed has advantages in terms of model loading and inference speed. Do you have any thoughts on how I can achieve that? I appreciate any comments you can provide.

Currently, I'm trying the following DeepSpeed inference script, but DeepSpeed loads the model with deepspeed.init_inference instead of the from_pretrained method in transformers.

pai4451 avatar Aug 16 '22 02:08 pai4451

deepspeed-inference+int8 is being worked on, please give us a bit of time.

As you discovered the ds-inference script, it'll be shortly updated to support int8.

stas00 avatar Aug 16 '22 04:08 stas00

Looking forward to trying it! Maybe the "illegal memory access" issue will also be resolved along the way, thanks to the reduced GPU memory consumption on each GPU.

pohunghuang-nctu avatar Aug 16 '22 06:08 pohunghuang-nctu