[RFC] Liger FlexChunkLoss: Alignment and Distillation loss
🚀 The feature, motivation and pitch
We want to support various alignment and distillation loss functions. Refer to this PR on ORPO: #362
Progress
Alignment
- [x] ORPO https://github.com/linkedin/Liger-Kernel/pull/362
- [x] CPO https://github.com/linkedin/Liger-Kernel/pull/382
- [x] DPO https://github.com/linkedin/Liger-Kernel/pull/378
- [x] SimPO https://github.com/linkedin/Liger-Kernel/pull/386
- [x] IRPO
- [x] KTO https://github.com/linkedin/Liger-Kernel/pull/475
- [ ] f-PO
Distillation
- [ ] KL divergence
- [ ] cosine_similarity
- [ ] earth_mover_distance
- [x] JSD https://github.com/linkedin/Liger-Kernel/pull/425
- [ ] KVD
Design
Approach Overview:
The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:
- Modular Optimization Process:
  - Every alignment algorithm’s optimization can be broken into three key steps:
    - Linear layer computation
    - Loss computation
    - Gradient calculation
- Fused Linear and Loss Computation:
  - Similar to FLCE, we aim to fuse the linear layer with the loss computation for efficiency.
- Chunking & Forward Optimization:
  - Since this is the final step in the model’s forward pass, we can compute gradients directly during the forward pass instead of waiting for a separate backward pass.
  - We also chunk the input within the forward pass, which significantly reduces the peak GPU memory required.
- Torch Compile for Kernel Optimization:
  - Instead of manually handling kernel-level optimizations, we let torch.compile automatically optimize kernel execution. This reduces the need for low-level optimizations while still achieving performance gains.
By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.
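As a rough illustration of the combined pattern, the sketch below fuses the linear projection with a per-chunk loss, takes gradients during the forward pass, and hands the per-chunk step to torch.compile. It is not the actual Liger implementation; the helper names, the `chunk_size` default, and the `(N, H)` hidden / `(V, H)` weight shape assumptions are made up for the example.

```python
import torch

def _chunk_forward(hidden_chunk, weight, target_chunk, compute_loss):
    # Fused linear + loss on a single chunk: only a (chunk_size, V) logits
    # tensor is ever materialized, never the full (N, V) logits.
    logits = hidden_chunk @ weight.t()
    loss = compute_loss(logits, target_chunk)
    # Gradients are taken inside the forward pass, so no separate backward
    # over the full logits is needed later. `weight` is assumed to be a
    # parameter with requires_grad=True.
    grad_hidden, grad_weight = torch.autograd.grad(loss, (hidden_chunk, weight))
    return loss.detach(), grad_hidden, grad_weight

# Let torch.compile optimize the per-chunk kernels instead of hand-writing Triton.
_chunk_forward = torch.compile(_chunk_forward)

def chunked_fused_linear_loss(hidden, weight, target, compute_loss, chunk_size=1024):
    grad_hidden = torch.zeros_like(hidden)
    grad_weight = torch.zeros_like(weight)
    total_loss = hidden.new_zeros(())
    for start in range(0, hidden.shape[0], chunk_size):
        end = start + chunk_size
        h = hidden[start:end].detach().requires_grad_(True)
        loss, g_h, g_w = _chunk_forward(h, weight, target[start:end], compute_loss)
        grad_hidden[start:end] = g_h  # per-chunk input gradient
        grad_weight += g_w            # accumulated weight gradient
        total_loss += loss            # accumulated (detached) chunk loss
    return total_loss, grad_hidden, grad_weight
```

The peak-memory saving comes from only ever materializing a chunk-sized slice of the logits, while accumulating gradients across chunks reproduces the gradient of the summed loss.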
Key Findings
By leveraging torch.compile alongside optimization techniques like chunking, online softmax, etc., we observed performance close to that of custom Triton kernels with significantly reduced development time. This is why we want to introduce torch.compile as a key component of Liger. References:
- #227
- https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
Interface
Have a base class FlexChunkLoss that handles the chunking, accumulation, and compilation strategies.
A custom loss class wraps FlexChunkLoss and implements the loss function that operates on a given chunk:
```python
class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, *args, **kwargs):
        # compute and return the loss for a single chunk
        ...
```
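As a concrete, hypothetical example (the exact FlexChunkLoss API is still open; the loss_fn signature below is an assumption), a chunked cross-entropy loss would only need to define the per-chunk computation, with chunking, accumulation, and compilation left to the base class:

```python
import torch.nn.functional as F

class ChunkedCrossEntropyLoss(FlexChunkLoss):
    # Assumed contract: the base class calls loss_fn once per chunk with the
    # chunk's logits and targets, then accumulates the returned losses.
    def loss_fn(self, chunk_logits, chunk_targets):
        return F.cross_entropy(chunk_logits, chunk_targets, reduction="sum")
```

Alignment losses such as ORPO or DPO would follow the same pattern, with loss_fn operating on the chosen/rejected slices of the chunk instead.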
Alternatives
No response
Additional context
No response