
[Question] No gain in VRAM usage with LigerFusedLinearCrossEntropyLoss

Open mrinaldi97 opened this issue 1 month ago • 3 comments

Hello, I am writing a codebase to train transformer models (almost finished, but it's still too early to share the entire framework), and I've just added Liger-Kernel support. I was expecting decreased VRAM usage with the Fused Linear Cross Entropy implementation compared to PyTorch's naive F.cross_entropy, however:

Torch => 22785MiB
Liger Fused => 22155MiB

Only about 600 MiB saved. The vocab size is 32777 (maybe too small to see a gain? Maybe because it is not a power of 16?)

It's training a small test autoregressive model:

    "hidden_size": 768,
    "ffn_factor": 3.0,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,

Testing on an NVIDIA RTX 3090 GPU with PyTorch 2.6 and CUDA 12.4; training is done in AMP with PyTorch Lightning at bf16 precision.

Here is the transformer block with lm_head:

    class TransformerWithLMHead(nn.Module):
        """
        Adding an LM Head to TransformerWithEmbeddingHead. This is enough for Bert-like/GPT-like models.
        """
        def __init__(self, config: ModelConfig, cache=None):
            super().__init__()
            self.cache = ensure_cache_and_registry(cache)
            cache = self.cache
            self.lm_head = ModuleWrapper(self.cache.registry.create(
                "linear", "linear",
                in_features=config.hidden_size,
                out_features=config.vocab_size,
            ))
            self.transformer = TransformerWithEmbeddingHead(config, cache=cache)
            if config.tie_word_embeddings:
                self.lm_head.weight = self.transformer.embed_tokens.weight
            self.config = config

        def forward(self, x, return_type='logits', **kwargs):
            x = self.transformer(x, **kwargs)
            if return_type == 'logits':
                return self.lm_head(x)
            else:
                return x

Here is the relevant snippet from the training step:

        if self.loss_type == 'fused':
            model_return_type = 'hidden'
            flattening_dimension = self.config.hidden_size
            loss_kwargs = {"lm_head_weight": self.model.lm_head.module.inner.weight}
            if hasattr(self.model.lm_head, "bias"):
                # TODO: better way to access inner attributes of wrapped modules
                loss_kwargs["lm_head_bias"] = self.model.lm_head.module.inner.bias

And finally, here is how the Liger kernel is used:

    @registry.register("loss", "cross_entropy_loss_fused", "liger", requires=["liger_kernel"], priority=0)
    class LigerCrossEntropyLossFused(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            cls = _load("liger_kernel.transformers", "LigerFusedLinearCrossEntropyLoss")
            self.inner = cls(*args, **kwargs)

        def forward(self, hidden, targets, **kwargs):
            return self.inner(
                _input=hidden,
                target=targets,
                lin_weight=kwargs['lm_head_weight'],
                bias=kwargs.get("lm_head_bias", None),
            )
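
As a sanity check that separates the kernel call from the rest of the framework, here is a minimal standalone sketch (not part of the codebase above) that runs the naive F.cross_entropy path and LigerFusedLinearCrossEntropyLoss on the same random data. The shapes, bf16 dtype, and bias-free lm_head are assumptions chosen to mirror the setup described in this issue, and the keyword names follow the snippet above:

    import torch
    import torch.nn.functional as F
    from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

    # Assumed shapes/dtype mirroring the issue: B=2, T=1024, H=768, V=32777, bf16.
    B, T, H, V = 2, 1024, 768, 32777
    hidden = torch.randn(B * T, H, device="cuda", dtype=torch.bfloat16)
    targets = torch.randint(0, V, (B * T,), device="cuda")
    lm_head = torch.nn.Linear(H, V, bias=False, device="cuda", dtype=torch.bfloat16)

    # Naive path: materializes the full (B*T, V) logits tensor.
    loss_torch = F.cross_entropy(lm_head(hidden), targets)

    # Fused path: the full logits tensor is never materialized.
    flce = LigerFusedLinearCrossEntropyLoss()
    loss_liger = flce(_input=hidden, target=targets,
                      lin_weight=lm_head.weight, bias=None)

    print(loss_torch.item(), loss_liger.item())  # the two values should agree closely

If the two values disagree here, the issue is in the kernel call itself rather than in the surrounding integration.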

Moreover, the loss diverges compared to torch:

[Figure: training loss with Liger vs. training loss with Torch]

Is it implemented correctly? Thank you

mrinaldi97 · Nov 14 '25

Try a larger batch size or sequence length; most of the memory saving comes from logits-related memory. With Liger FLCE, you can train with settings that would normally OOM without it.

Tcc0403 · Nov 14 '25

Hi @mrinaldi97, I would be interested in debugging this because, in my understanding, FLCE reduces the peak memory of that op from B×T×V to B×T×H; in terms of absolute memory, the reduction from that specific op should be a factor of V/H. Can you post pre- and post-Liger PyTorch memory profiles?
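
For the configuration in this issue (V = 32777, H = 768), that ratio works out to roughly 43x:

    # Approximate V/H reduction factor for the logits/loss op, using the numbers from this issue.
    V, H = 32777, 768
    print(V / H)  # ~42.7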

The profiles would look something like this:

Pre:

[Image: memory profile before Liger FLCE]

Post:

[Image: memory profile after Liger FLCE]

Note how the initial graph has a ~700 MB sudden jump. If you zoom in, it is caused by these ops:

[Image: zoomed-in trace of the ops responsible for the jump]

which are exactly the ones Liger's FLCE gets rid of.

I think I ran this with 2×1024×32768 (B×T×V) in fp32 (it's been a while, so I could be wrong), but 256 MB comes directly from:

    2 × 1024 × 32768 × 4 bytes / 1024 / 1024 = 256 MiB
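
For reference, here is that arithmetic as a small sketch, extended to the vocab size reported in this issue; the batch size and sequence length used in the second line are illustrative assumptions, not values taken from the issue:

    # Memory of the materialized (B, T, V) logits tensor, in MiB.
    def logits_mib(batch, seqlen, vocab, bytes_per_elem):
        return batch * seqlen * vocab * bytes_per_elem / 1024 / 1024

    print(logits_mib(2, 1024, 32768, 4))  # fp32 example above: 256.0 MiB
    print(logits_mib(2, 1024, 32777, 2))  # bf16, V=32777, illustrative B and T: ~128 MiB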

Feel free to email me, happy to help!

mayankagarwals · Nov 20 '25

Hi @mayankagarwals, thank you a lot for your interest! I am going to DM you :) By the way, at the moment the biggest issue with integrating Liger's loss functions into my architecture is that the model does not seem to actually train, while with the standard PyTorch loss the codebase works perfectly. I will keep this issue open, so that if, with the help of @mayankagarwals, we find something wrong with the loss itself rather than just with its integration into the rest of the code, we can help the rest of the community too.

mrinaldi97 · Nov 22 '25