
Adding Embedding4bit and StableEmbedding4bit

GM-git-dotcom opened this issue on Feb 24, 2024 · 1 comment

Feature request

Is it possible to have two new modules, Embedding4bit and StableEmbedding4bit, in nn/modules.py? This would make it possible to quantize embedding lookup tables as well.

Motivation

I was working on a PR at huggingface/peft that adds LoRA support for quantized embedding modules, but it turns out that bitsandbytes does not have an Embedding4bit module the way it has Linear4bit.

Your contribution

Using Params4bit, I think it wouldn't be too difficult to make a 4-bit version for both. I have a rough implementation for Embedding4bit, and I would be happy to work on a PR for this. Any help/thoughts/comments are much appreciated! :)

# This is for the bnb version you can currently download with pip install bitsandbytes
# I noticed that the current state of the repo has stark differences, e.g. attributes like self.quant_storage are new

from typing import Optional
import warnings

import torch
from torch import Tensor, device
import torch.nn.functional as F

import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit


class Embedding4bit(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device: Optional[device] = None,
        compute_dtype=None,
        compress_statistics=True,
        quant_type='fp4',
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device=device
        )
        # replace the dense weight with a 4-bit Params4bit; quantization happens when the module is moved to the GPU
        self.weight = Params4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False
        self.quant_state = None
        
        #GlobalOptimManager.get_instance().register_module_override(
        #    self, "weight", {"optim_bits": 32}
        #)

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    
    def set_compute_type(self, x):
        if x.dtype in [torch.float32, torch.bfloat16]:
            # the input is in a dtype that is safe to compute in, we switch
            # to this type for speed and stability
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16:
            # we take the compute dtype passed into the layer
            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
                # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                # warn the user about this
                warnings.warn('Input type into Embedding4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
                warnings.filterwarnings('ignore', message='.*inference.')
            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
                warnings.warn('Input type into Embedding4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
                warnings.filterwarnings('ignore', message='.*inference or training')

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """
        save weight and bias,
        then fill state_dict with components of quant_state
        """
        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias

        if getattr(self.weight, "quant_state", None) is not None:
            for k, v in self.weight.quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()


    def forward(self, input: Tensor) -> Tensor:
        if getattr(self.weight, 'quant_state', None) is None:
            if getattr(self, 'quant_state', None) is not None:
                # the quant state got lost when the parameter got converted. This happens for example for fsdp
                # since we registered the module, we can recover the state here
                assert self.weight.shape[1] == 1
                if not isinstance(self.weight, Params4bit):
                    self.weight = Params4bit(self.weight)
                self.weight.quant_state = self.quant_state
            else:
                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the Embedding4bit layer first.')

        if not self.compute_type_is_set:
            self.set_compute_type(input)
            self.compute_type_is_set = True

        # F.embedding needs integer indices and a dense weight matrix, so dequantize the
        # packed 4-bit weight back to its original shape before the lookup
        weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)
        if self.compute_dtype is not None:
            weight = weight.to(self.compute_dtype)

        emb = F.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb
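
For the StableEmbedding4bit half of the request, a minimal sketch could reuse the class above and add the LayerNorm that the existing StableEmbedding layers on top of Embedding. This is only a rough idea, not a settled design; in particular, casting the normalized output back to compute_dtype is my assumption, since the 8-bit StableEmbedding casts back to self.weight.dtype, which no longer makes sense once the weight is stored as packed 4-bit data.

class StableEmbedding4bit(Embedding4bit):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device: Optional[device] = None,
        compute_dtype=None,
        compress_statistics=True,
        quant_type='fp4',
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device=device,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        # StableEmbedding pairs the lookup table with a LayerNorm for training stability
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)

    def forward(self, input: Tensor) -> Tensor:
        emb = super().forward(input)
        # as in StableEmbedding, apply the LayerNorm in full precision
        emb = emb.to(torch.get_default_dtype())
        out = self.norm(emb)
        # assumption: cast back to compute_dtype if one was requested
        if self.compute_dtype is not None:
            out = out.to(self.compute_dtype)
        return out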

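For reference, here is a rough usage sketch for the classes above, assuming a CUDA device and that quantization is triggered by .cuda()/.to(device), the same way Params4bit behaves inside Linear4bit (the sizes are just illustrative):

emb = Embedding4bit(num_embeddings=1024, embedding_dim=128, quant_type='nf4')
emb = emb.cuda()  # Params4bit quantizes the weight here and attaches its quant_state

ids = torch.randint(0, 1024, (2, 16), device='cuda')
out = emb(ids)
print(out.shape)  # torch.Size([2, 16, 128])
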