
`LigerFusedLinearCrossEntropyLoss` Causes Training Loss to Diverge After Reaching ~8

Open penghui-yang opened this issue 10 months ago • 7 comments

🐛 Describe the bug

Description

When using LigerFusedLinearCrossEntropyLoss (Liger FLCE) from the Liger kernel to replace torch.nn.CrossEntropyLoss, the training loss becomes unstable and diverges after reaching a certain value (~8). In contrast, the loss computed using torch.nn.CrossEntropyLoss continues to decrease smoothly.

Expected Behavior

The loss computed with LigerFusedLinearCrossEntropyLoss should decrease similarly to torch.nn.CrossEntropyLoss without significant oscillations or divergence.

Observed Behavior

  • During the initial training phase, both loss functions exhibit similar behavior, and the loss decreases as expected.
  • When the loss computed with LigerFusedLinearCrossEntropyLoss reaches ~8, it becomes unstable, oscillates, and diverges, as shown in the attached graph.

Screenshots/Logs

Loss curve comparison (attached):

  • The orange curve shows the behavior with torch.nn.CrossEntropyLoss (stable).
  • The purple curve shows the behavior with LigerFusedLinearCrossEntropyLoss (unstable and divergent).


Additional Context

  • This issue appears to be related to gradient computation or numerical stability with LigerFusedLinearCrossEntropyLoss.
  • No hyperparameter changes were made between the two implementations.

Request for Assistance

  • Please investigate whether there are implementation issues with LigerFusedLinearCrossEntropyLoss.
  • Are there additional configurations or training parameters required to avoid instability?

Thank you for your assistance!

Reproduce

Code to Reproduce

Original compute_loss implementation (works as expected):

def compute_loss(self, hidden_states, labels):
    # Materialize the full logits and upcast to fp32 before computing the loss
    logits = self.lm_head(hidden_states).float()
    # Shift so each token predicts the next one, flatten, and apply torch.nn.CrossEntropyLoss
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
    return loss

New compute_fused_loss implementation (causes instability):

def compute_fused_loss(self, hidden_states, labels):
    # Shift hidden states and labels for next-token prediction
    shift_hidden_states = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten tokens
    shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
    shift_labels = shift_labels.view(-1)

    # FLCE fuses the lm_head projection with the cross-entropy, so it consumes the
    # (bf16) hidden states and lm_head weight directly instead of fp32 logits
    lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
    loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    return loss
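
To help triage, here is a minimal, self-contained sketch (separate from my training code) that compares the two paths on random bf16 data on a single GPU. The import path, tensor shapes, and init scale are assumptions; the Liger call mirrors compute_fused_loss above, and a CUDA device is required since the kernel is Triton-based.

import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss  # assumed import path

torch.manual_seed(0)
B, T, H, V = 2, 128, 4096, 32000  # hypothetical batch/seq/hidden/vocab sizes
hidden = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = (0.02 * torch.randn(V, H, device="cuda", dtype=torch.bfloat16)).requires_grad_()
labels = torch.randint(0, V, (B, T), device="cuda")

# Baseline path: materialize logits and upcast to fp32, as in compute_loss above
logits = (hidden @ weight.t()).float()
ref = torch.nn.CrossEntropyLoss()(logits[:, :-1].reshape(-1, V), labels[:, 1:].reshape(-1))

# Fused path: shift/flatten first, then hand (weight, hidden, labels) to Liger FLCE
shift_hidden = hidden[..., :-1, :].contiguous().view(-1, H)
shift_labels = labels[..., 1:].contiguous().view(-1)
fused = LigerFusedLinearCrossEntropyLoss(reduction="mean")(weight, shift_hidden, shift_labels)

print(f"loss  ref={ref.item():.6f}  fused={fused.item():.6f}  diff={(ref - fused).abs().item():.2e}")

# Gradient check on the shared lm_head weight
ref.backward()
ref_wgrad = weight.grad.detach().clone()
weight.grad = None
hidden.grad = None
fused.backward()
print(f"max |weight grad diff| = {(ref_wgrad - weight.grad).abs().max().item():.2e}")

If the losses or weight gradients already disagree noticeably on random bf16 data, that would point at the kernel/precision rather than at anything in the training setup.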

Steps to Reproduce

  1. Replace the original compute_loss function with the new compute_fused_loss function using LigerFusedLinearCrossEntropyLoss.
  2. Train a model using both implementations (torch.nn.CrossEntropyLoss and LigerFusedLinearCrossEntropyLoss) for comparison (a single-run alternative is sketched after this list).
  3. Observe the behavior of the loss curves during training.
    • With torch.nn.CrossEntropyLoss, the loss continues to decrease as expected.
    • With LigerFusedLinearCrossEntropyLoss, the loss starts to oscillate and then diverges when it reaches ~8.
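
As a single-run alternative to step 2, both losses can be computed on the same batch each step so the curves can be compared step-for-step rather than across separate runs. Below is a hypothetical compute_both_losses helper built on the two functions above; only the fused loss is used for the backward pass, and the baseline is evaluated under torch.no_grad() to keep the extra cost down.

def compute_both_losses(self, hidden_states, labels):
    # Hypothetical diagnostic helper: evaluate both paths on the same batch.
    with torch.no_grad():
        ce_loss = self.compute_loss(hidden_states, labels)      # fp32-logits baseline
    flce_loss = self.compute_fused_loss(hidden_states, labels)  # Liger FLCE path
    # A gap that widens as training progresses points at the fused loss itself
    # rather than at optimizer noise or data ordering.
    print(f"ce={ce_loss.item():.4f}  flce={flce_loss.item():.4f}  "
          f"diff={(ce_loss - flce_loss).abs().item():.4e}")
    return flce_loss  # backprop through the path being tested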

Versions

Environment

  • Liger Kernel Version: 0.3.1
  • Hardware: 8 × A100 GPUs
  • CUDA Version: 12.4
  • PyTorch Version: 2.5.1+cu124
  • Transformers Version: 4.46.3
  • Precision: torch.bfloat16
  • Optimizer: DeepSpeed ZeRO Stage 1

penghui-yang · Jan 04 '25, 09:01