
[Question] No gain in VRAM usage with LigerFusedLinearCrossEntropyLoss

Open mrinaldi97 opened this issue 1 month ago • 3 comments

Hello, I am writing a codebase for training transformer models (almost finished, but it's still too early to share the entire framework), and I've just added Liger-Kernel support. I was expecting lower VRAM usage with the fused linear cross-entropy implementation compared to naive torch F.cross_entropy, however:

Torch => 22785MiB
Liger Fused => 22155MiB

Only about 600 MiB saved. The vocab size is 32777 (maybe too small to see a gain? Or maybe because it is not a power of 16?)
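
For context, my expectation came from the size of the full logits tensor that the naive path has to materialize before the loss. A rough back-of-the-envelope (batch size and sequence length below are placeholders, not my actual settings):

    # Rough size of the bf16 logits tensor that the naive F.cross_entropy path materializes.
    # batch_size and seq_len are hypothetical placeholders, not my real training settings.
    batch_size, seq_len, vocab_size = 8, 1024, 32777
    bytes_per_bf16 = 2
    logits_bytes = batch_size * seq_len * vocab_size * bytes_per_bf16
    print(f"full logits tensor: {logits_bytes / 2**20:.0f} MiB")  # ~512 MiB for these shapes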

I'm training a small autoregressive test model:

    "hidden_size": 768,
    "ffn_factor": 3.0,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,

Testing is on an NVIDIA RTX 3090 GPU with PyTorch 2.6 and CUDA 12.4; training is done in AMP via PyTorch Lightning at bf16 precision.

Here is the transformer block with lm_head:

    class TransformerWithLMHead(nn.Module):
        """
        Adding an LM Head to TransformerWithEmbeddingHead. This is enough for Bert-like/GPT-like models.
        """
        def __init__(self, config: ModelConfig, cache=None):
            super().__init__()
            self.cache = ensure_cache_and_registry(cache)
            cache = self.cache
            self.lm_head = ModuleWrapper(
                self.cache.registry.create(
                    "linear", "linear",
                    in_features=config.hidden_size,
                    out_features=config.vocab_size,
                )
            )
            self.transformer = TransformerWithEmbeddingHead(config, cache=cache)
            if config.tie_word_embeddings:
                self.lm_head.weight = self.transformer.embed_tokens.weight
            self.config = config

        def forward(self, x, return_type='logits', **kwargs):
            x = self.transformer(x, **kwargs)
            if return_type == 'logits':
                return self.lm_head(x)
            else:
                return x
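
To make the shapes explicit, here is a stripped-down analogue of the class above without the registry/wrapper machinery (purely illustrative; TinyLMHead is not part of my codebase):

    import torch
    import torch.nn as nn

    class TinyLMHead(nn.Module):
        """Simplified stand-in for TransformerWithLMHead, just to show the two return types."""
        def __init__(self, hidden_size=768, vocab_size=32777):
            super().__init__()
            self.body = nn.Identity()  # stand-in for the transformer stack
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, x, return_type='logits'):
            x = self.body(x)
            return self.lm_head(x) if return_type == 'logits' else x

    m = TinyLMHead()
    h = torch.randn(2, 16, 768)
    print(m(h).shape)                        # (2, 16, 32777) -> fed to F.cross_entropy
    print(m(h, return_type='hidden').shape)  # (2, 16, 768)   -> fed to the fused Liger loss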

Here is the relevant snippet from the training step:

    if self.loss_type == 'fused':
        model_return_type = 'hidden'
        flattening_dimension = self.config.hidden_size
        loss_kwargs = {"lm_head_weight": self.model.lm_head.module.inner.weight}
        if hasattr(self.model.lm_head, "bias"):
            # TODO: better way to access inner attributes of wrapped modules
            loss_kwargs["lm_head_bias"] = self.model.lm_head.module.inner.bias

And finally, here is how the Liger kernel is used:

    @registry.register("loss", "cross_entropy_loss_fused", "liger", requires=["liger_kernel"], priority=0)
    class LigerCrossEntropyLossFused(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            cls = _load("liger_kernel.transformers", "LigerFusedLinearCrossEntropyLoss")
            self.inner = cls(*args, **kwargs)

        def forward(self, hidden, targets, **kwargs):
            return self.inner(
                _input=hidden,
                target=targets,
                lin_weight=kwargs['lm_head_weight'],
                bias=kwargs.get("lm_head_bias", None),
            )
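
For completeness, this is how I expect the wrapper to be fed, with the hidden states already flattened to 2D (sizes below are illustrative, not my actual batch):

    import torch

    # Illustrative shapes: hidden flattened to (num_tokens, hidden_size),
    # targets flattened to (num_tokens,), lm_head weight of shape (vocab_size, hidden_size).
    loss_fn = LigerCrossEntropyLossFused()
    hidden = torch.randn(4096, 768, device="cuda", dtype=torch.bfloat16)
    targets = torch.randint(0, 32777, (4096,), device="cuda")
    lm_head_weight = torch.randn(32777, 768, device="cuda", dtype=torch.bfloat16)
    loss = loss_fn(hidden, targets, lm_head_weight=lm_head_weight)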

Moreover, the loss diverges compared to torch:

[Plots: training loss with Liger vs. training loss with Torch]
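
As a sanity check on the divergence, this is the kind of minimal standalone comparison I would expect to agree (it assumes only the public LigerFusedLinearCrossEntropyLoss call signature with default arguments, nothing from my framework):

    import torch
    import torch.nn.functional as F
    from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

    torch.manual_seed(0)
    num_tokens, hidden_size, vocab_size = 4096, 768, 32777
    hidden = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
    weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16) * 0.02
    targets = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

    # Naive path: materialize the full logits, then cross entropy in fp32.
    ref = F.cross_entropy((hidden @ weight.t()).float(), targets)

    # Fused path: same keyword names as in my wrapper above.
    fused = LigerFusedLinearCrossEntropyLoss()(_input=hidden, target=targets, lin_weight=weight)

    print(ref.item(), fused.item())  # I would expect these to agree up to bf16 rounding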

Is it implemented correctly? Thank you

mrinaldi97 avatar Nov 14 '25 11:11 mrinaldi97