Adding Embedding4bit and StableEmbedding4bit
Feature request
Is it possible to have two new modules, Embedding4bit and StableEmbedding4bit, at nn/modules.py? This would enable quantizing embedding lookup tables as well.
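For context, this is roughly the usage I have in mind, mirroring how Linear4bit is used today. This is only a sketch: the module and its constructor signature don't exist yet and follow the rough implementation further below.

import torch

# Hypothetical usage of the requested module (see the rough Embedding4bit sketch below);
# quantization would happen on the move to GPU, as with Linear4bit.
emb = Embedding4bit(
    num_embeddings=32000,
    embedding_dim=4096,
    compute_dtype=torch.bfloat16,
    quant_type="nf4",
).to("cuda")

ids = torch.randint(0, 32000, (2, 16), device="cuda")
out = emb(ids)  # lookup against the dequantized 4-bit table -> shape (2, 16, 4096)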
Motivation
I was actually working on a PR at huggingface/peft that adds LoRA support for quantized embedding modules, but it turns out that bitsandbytes does not have a module like Embedding4bit the way it does Linear4bit.
Your contribution
Using Params4bit, I think it wouldn't be too difficult to make a 4-bit version of both. I have a rough implementation for Embedding4bit, and I would be happy to work on a PR for this. Any help/thoughts/comments are much appreciated! :)
# This is for the bnb version you can currently install with `pip install bitsandbytes`.
# I noticed that the current state of the repo differs noticeably, e.g. attributes like self.quant_storage are new.
import warnings
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from bitsandbytes.functional import dequantize_4bit
from bitsandbytes.nn import Params4bit


class Embedding4bit(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device: Optional[torch.device] = None,
        compute_dtype=None,
        compress_statistics=True,
        quant_type='fp4',
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device=device,
        )
        # Replace the dense weight with a 4-bit parameter; quantization happens on .cuda()/.to(device).
        self.weight = Params4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False
        self.quant_state = None
        # GlobalOptimManager.get_instance().register_module_override(
        #     self, "weight", {"optim_bits": 32}
        # )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the layer compatible with PyTorch < 1.9.
        This means that if this changes in future PyTorch releases this needs to change too,
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def set_compute_type(self, x):
        # Note: for an embedding, the forward input is an integer index tensor, so these
        # branches are mostly a no-op; they are kept for parity with Linear4bit.
        if x.dtype in [torch.float32, torch.bfloat16]:
            # the input is in a dtype that is safe to compute in; we switch
            # to this type for speed and stability
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16:
            # we take the compute dtype passed into the layer
            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
                # single-batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                # warn the user about this
                warnings.warn('Input type into Embedding4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
                warnings.filterwarnings('ignore', message='.*inference.')
            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
                warnings.warn('Input type into Embedding4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
                warnings.filterwarnings('ignore', message='.*inference or training')

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """
        Save weight and bias, then fill the state_dict with the components of quant_state.
        """
        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
        if getattr(self.weight, "quant_state", None) is not None:
            for k, v in self.weight.quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
    def forward(self, input: Tensor) -> Tensor:
        if getattr(self.weight, 'quant_state', None) is None:
            if getattr(self, 'quant_state', None) is not None:
                # the quant state got lost when the parameter got converted (this happens e.g. for FSDP);
                # since we registered it on the module, we can recover the state here
                assert self.weight.shape[1] == 1
                if not isinstance(self.weight, Params4bit):
                    self.weight = Params4bit(self.weight)
                self.weight.quant_state = self.quant_state
            else:
                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the Embedding4bit layer first.')
        if not self.compute_type_is_set:
            self.set_compute_type(input)
            self.compute_type_is_set = True

        # F.embedding cannot index into the packed uint8 storage, so dequantize the
        # 4-bit weight for the lookup and cast it to compute_dtype if one was set.
        weight = dequantize_4bit(self.weight.data, self.weight.quant_state)
        if self.compute_dtype is not None:
            weight = weight.to(self.compute_dtype)

        emb = F.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return emb
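For StableEmbedding4bit, I imagine a thin wrapper around the class above that follows what StableEmbedding already does, i.e. add a LayerNorm and apply it in full precision. A rough, untested sketch building on the Embedding4bit above:

class StableEmbedding4bit(Embedding4bit):
    # Sketch only: mirrors bnb.nn.StableEmbedding by adding a LayerNorm after the 4-bit lookup.
    def __init__(self, num_embeddings: int, embedding_dim: int, *args, **kwargs) -> None:
        super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
        self.norm = torch.nn.LayerNorm(embedding_dim)

    def forward(self, input: Tensor) -> Tensor:
        emb = super().forward(input)
        # As in StableEmbedding, apply the layer norm in full precision for stability.
        emb = emb.to(torch.get_default_dtype())
        return self.norm(emb)

The 32-bit optimizer override that StableEmbedding registers via GlobalOptimManager is probably moot here, since the 4-bit weight is frozen (requires_grad=False).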