pytorch-lightning
Dequantize model with Bitsandbytes precision plugin
Description & Motivation
Hi there 👋
As @carmocca proposed, I would like to add functionality for the BNB precision plugin to dequantize weights. Why is it needed? Someone might want to use quantization during training to lower the VRAM requirements, but afterwards also want to save the model in a dequantized form so it can be used anywhere else. However, I have trouble implementing it. The trouble is not of a technical nature (thanks to @awaelchli I have an idea of how to do it), but rather a design one: I'm not sure when the process of dequantization should be executed. So before proceeding, I decided to create this issue to confirm that I'm steering in the right direction.
Pitch
Note: the dequantize logic itself will live in the BNB precision plugin, there is no doubt about that. The options below only describe when the dequantization is called/executed.
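For reference, a minimal sketch of what that logic could look like, assuming only the 4-bit case (`Linear4bit`); the helper name is a placeholder, and the 8-bit (`Linear8bitLt`) path would need its own handling:

```python
import torch
import bitsandbytes as bnb


def _dequantize_module(module: torch.nn.Module) -> None:
    """Placeholder helper: replace bnb 4-bit weights with plain floating-point tensors, in-place."""
    for submodule in module.modules():
        if isinstance(submodule, bnb.nn.Linear4bit):
            # dequantize_4bit unpacks the weight back to its original dtype on the same device
            weight = bnb.functional.dequantize_4bit(submodule.weight.data, submodule.weight.quant_state)
            # note: after this the layer can no longer run its quantized forward
            submodule.weight = torch.nn.Parameter(weight, requires_grad=False)
```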
Option 1: `state_dict` hooks
This variant mimics what is currently used in the BNB precision plugin.
We can have a pre-hook and a regular `state_dict` hook, so whenever we call `model.state_dict()` all the quantized weights are dequantized:
- With `self._register_load_state_dict_pre_hook` we dequantize the weights in-place, so in the end the `state_dict` contains them.
- With `self._register_state_dict_hook` we take the quantized weights from the `state_dict` and replace the values with dequantized variants.

The first variant changes the model, which is definitely unexpected behavior when you call `model.state_dict()`.
With the second variant, the `state_dict` will contain values that aren't in the model, as the model still holds the quantized variants. That not only adds a discrepancy, but also means storing both the dequantized values (in the `state_dict`) and the quantized ones (in the model).
Another issue with both variants is that dequantization runs on every `.state_dict()` call, which is not ideal in some situations.
Not the best option in my opinion.
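To make the two variants concrete, here is a minimal sketch, assuming the pre-hook is wired through `nn.Module.register_state_dict_pre_hook` (which fires before the `state_dict` is built), only the 4-bit case, and a placeholder `_dequantize` helper:

```python
import torch
import bitsandbytes as bnb


def _dequantize(param):
    # placeholder helper: unpack a bnb 4-bit parameter back to a regular tensor
    return bnb.functional.dequantize_4bit(param.data, param.quant_state)


# Variant 1: pre-hook -- mutate the model in-place before the state_dict is built,
# so the collected state_dict naturally contains dequantized weights (but the model changes).
def _pre_hook(module, prefix, keep_vars):
    for submodule in module.modules():
        if isinstance(submodule, bnb.nn.Linear4bit):
            submodule.weight = torch.nn.Parameter(_dequantize(submodule.weight), requires_grad=False)


# Variant 2: post-hook -- leave the model untouched and swap the quantized entries in the
# returned state_dict for dequantized copies (so both versions exist at the same time).
def _post_hook(module, state_dict, prefix, local_metadata):
    for name, submodule in module.named_modules():
        if isinstance(submodule, bnb.nn.Linear4bit):
            key = f"{prefix}{name}.weight" if name else f"{prefix}weight"
            state_dict[key] = _dequantize(submodule.weight)
            # (the extra quant-state entries would also need to be dropped from state_dict)


def attach_dequantize_hook(model: torch.nn.Module) -> None:
    model.register_state_dict_pre_hook(_pre_hook)    # variant 1
    # model._register_state_dict_hook(_post_hook)    # variant 2 (one or the other, not both)
```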
Option 1.1: `.state_dict` with a `dequantize` argument
When `.state_dict` is called on a model that is wrapped in Fabric, the wrapped variant is executed first.
So we could add an argument there and, if it is provided, dequantize the model in-place.
But it doesn't feel right; dequantization doesn't belong there, in my opinion.
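For illustration, a minimal sketch of what the wrapped call could do; the `dequantize` flag and the `dequantize_module` plugin method are hypothetical:

```python
import torch


def wrapped_state_dict(module: torch.nn.Module, precision, *, dequantize: bool = False):
    # sketch of a Fabric-wrapped .state_dict with an opt-in flag;
    # precision.dequantize_module is a hypothetical plugin method
    if dequantize:
        precision.dequantize_module(module)  # converts the weights in-place
    return module.state_dict()
```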
Option 2: dequantize during model saving
As described above, the main reason to dequantize a model is to be able to save it in a format that can be used elsewhere.
So, naturally, the best place for dequantization would be in `fabric.save`.
This method would get an additional argument, `dequantize`.
But here is a catch: you cannot dequantize the whole model at once, because the user will most likely get an OOM error; the main reason to use quantization in the first place is a limited amount of video memory.
So an incremental save should be added: take a layer, dequantize it, save it, take the next layer, and so on.
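A minimal sketch of that incremental save, assuming only the 4-bit case and a placeholder `save_dequantized` name; here each layer is dequantized and offloaded to CPU one at a time (a true streaming write to disk would need a sharded/streaming format on top of this):

```python
import torch
import bitsandbytes as bnb


def save_dequantized(model: torch.nn.Module, path: str) -> None:
    """Dequantize one layer at a time so the full-precision model never sits in VRAM."""
    state_dict = {}
    quantized_prefixes = []
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            prefix = f"{name}." if name else ""
            quantized_prefixes.append(prefix)
            weight = bnb.functional.dequantize_4bit(module.weight.data, module.weight.quant_state)
            state_dict[f"{prefix}weight"] = weight.cpu()  # offload immediately to free VRAM
            if module.bias is not None:
                state_dict[f"{prefix}bias"] = module.bias.detach().cpu()
    # copy everything that is not part of a quantized layer (embeddings, norms, ...)
    for name, tensor in model.state_dict().items():
        if name in state_dict or name.startswith(tuple(quantized_prefixes)):
            continue  # skips the packed weights and their quant-state entries
        state_dict[name] = tensor.cpu()
    torch.save(state_dict, path)
```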
So overall my solution is to:
- add the dequantize logic to the BNB precision plugin.
- add a `dequantize` argument to `fabric.save`.
- dequantize and save incrementally, layer by layer.
But maybe there is an easier solution that I don't see?
Alternatives
No response
Additional context
No response
cc @borda @tchaton @justusschock @awaelchli @carmocca