Liger-Kernel
`LigerFusedLinearCrossEntropyLoss` Causes Training Loss to Diverge After Reaching ~8
🐛 Describe the bug
Description
When using `LigerFusedLinearCrossEntropyLoss` (Liger FLCE) from the Liger kernel to replace `torch.nn.CrossEntropyLoss`, the training loss becomes unstable and diverges after reaching a certain value (~8). In contrast, the loss computed with `torch.nn.CrossEntropyLoss` continues to decrease smoothly.
Expected Behavior
The loss computed with `LigerFusedLinearCrossEntropyLoss` should decrease similarly to the loss computed with `torch.nn.CrossEntropyLoss`, without significant oscillations or divergence.
Observed Behavior
- During the initial training phase, both loss functions exhibit similar behavior, and the loss decreases as expected.
- When the loss computed with `LigerFusedLinearCrossEntropyLoss` reaches ~8, it becomes unstable, oscillates, and diverges, as shown in the attached graph.
Screenshots/Logs
Loss curve comparison (attached):
- The orange curve shows the behavior with `torch.nn.CrossEntropyLoss` (stable).
- The purple curve shows the behavior with `LigerFusedLinearCrossEntropyLoss` (unstable and divergent).
Additional Context
- This issue appears to be related to gradient computation or numerical stability in `LigerFusedLinearCrossEntropyLoss` (see the parity-check sketch after this list).
- No hyperparameter changes were made between the two implementations.
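To help narrow down whether this is a numerical-accuracy or a gradient issue, here is a minimal, self-contained parity check. It is only a sketch under assumptions: the sizes are placeholders rather than the real model config, it assumes a CUDA device with bf16 support, and it assumes `LigerFusedLinearCrossEntropyLoss` is importable from `liger_kernel.transformers` (adjust the import to match your install):

```python
import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss  # adjust if your version exposes it elsewhere

torch.manual_seed(0)
hidden_size, vocab_size, n_tokens = 1024, 32000, 512  # placeholder sizes, not the real model config

# Random bf16 inputs, mirroring the training precision reported below
hidden = torch.randn(n_tokens, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, vocab_size, (n_tokens,), device="cuda")

# Reference path: materialize the logits, upcast to fp32, standard cross entropy
ref_loss = torch.nn.CrossEntropyLoss()(hidden.float() @ weight.float().t(), labels)
ref_loss.backward()
ref_hidden_grad, ref_weight_grad = hidden.grad.clone(), weight.grad.clone()
hidden.grad, weight.grad = None, None

# Fused path: weight first, then hidden states, then labels; logits are never materialized
fused_loss = LigerFusedLinearCrossEntropyLoss(reduction="mean")(weight, hidden, labels)
fused_loss.backward()

print("loss diff:        ", abs(ref_loss.item() - fused_loss.item()))
print("hidden grad diff: ", (ref_hidden_grad - hidden.grad).abs().max().item())
print("weight grad diff: ", (ref_weight_grad - weight.grad).abs().max().item())
```

If the forward losses match but the gradients drift apart in bf16, that would point at numerical accumulation rather than a logic bug.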
Request for Assistance
- Please investigate whether there are implementation issues in `LigerFusedLinearCrossEntropyLoss`.
- Are there additional configurations or training parameters required to avoid this instability?
Thank you for your assistance!
Reproduce
Code to Reproduce
Original `compute_loss` implementation (works as expected):
```python
def compute_loss(self, hidden_states, labels):
    logits = self.lm_head(hidden_states).float()
    # Using torch.nn.CrossEntropyLoss for loss computation
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
    return loss
```
New `compute_fused_loss` implementation (causes instability):
```python
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

def compute_fused_loss(self, hidden_states, labels):
    shift_hidden_states = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten tokens
    shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
    shift_labels = shift_labels.view(-1)
    # Liger FLCE takes the lm_head weight first, then the hidden states and labels,
    # and never materializes the full logits tensor
    lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
    loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    return loss
```
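One difference between the two paths worth noting: the original `compute_loss` upcasts the logits to fp32 before the loss, while the fused version keeps everything in `torch.bfloat16` end to end. As a diagnostic only (not a confirmed fix), a variant that upcasts the inputs to the fused loss could show whether the divergence is precision-related. `compute_fused_loss_fp32` below is a hypothetical helper for this experiment, not part of the original code:

```python
def compute_fused_loss_fp32(self, hidden_states, labels):
    # Same shifting/flattening as compute_fused_loss above
    shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, self.config.hidden_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
    # Upcast both the lm_head weight and the hidden states so the fused kernel
    # runs in fp32, mirroring the .float() cast in the original compute_loss
    loss = lce(self.lm_head.weight.float(), shift_hidden_states.float(), shift_labels)
    return loss
```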
Steps to Reproduce
- Replace the original `compute_loss` function with the new `compute_fused_loss` function using `LigerFusedLinearCrossEntropyLoss`.
- Train a model with both implementations (`torch.nn.CrossEntropyLoss` and `LigerFusedLinearCrossEntropyLoss`) for comparison.
- Observe the behavior of the loss curves during training:
  - With `torch.nn.CrossEntropyLoss`, the loss continues to decrease as expected.
  - With `LigerFusedLinearCrossEntropyLoss`, the loss starts to oscillate and then diverges when it reaches ~8.
Versions
Environment
- Liger Kernel Version: 0.3.1
- Hardware: 8 × A100 GPUs
- CUDA Version: 12.4
- PyTorch Version: 2.5.1+cu124
- Transformers Version: 4.46.3
- Precision: torch.bfloat16
- Optimizer: ZeRO Stage 1