Liger-Kernel
Support offline `logits` for teacher model
🚀 The feature, motivation and pitch
In knowledge distillation, it would be more efficient to support teacher logits/logprobs pre-computed offline, rather than loading the teacher model and running its forward pass inside the kernel.
Some other thoughts on using logits or logprobs?
We apply temperature scaling here.
As @winglian mentioned here.
I'd actually like to see both a logit and logprob implementation since it's easy to get logprobs offline from vllm and that is a faster way to generate the dataset.
So rather than having the teacher model loaded during training, depending on the workload type, it can be faster and more compute-efficient to pre-compute the logits/logprobs offline beforehand. However, vllm and sglang only provide the logprobs, and those are not easily back-calculated to logits.
Meanwhile, @shivam15s raised a concern about temperature-scaled logprobs here:
Curious if vllm/sglang support temperature-scaled logprobs. This would be needed to enable https://github.com/huggingface/trl/blob/9c5388b69e0842f76edc46a2ff9d0b51e1db4337/trl/trainer/gkd_trainer.py#L174
In addition, @Tcc0403 suggested here that log-space is the right way to go. As I understand it, I agree with this idea, assuming temperature=1.
Sorry for the misleading question and late response. Passing logprobs is totally fine; it's actually better, since working in log-space avoids underflow issues. Torch's KLDivLoss also expects inputs in log-space, and the extra computation of log_softmax instead of softmax shouldn't be an issue anyway. So if most APIs expect logprobs as input, then I think that's the way to go.
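A minimal PyTorch sketch of the log-space point above: `torch.nn.functional.kl_div` accepts a log-space target when `log_target=True`, so teacher logprobs computed offline could be passed in directly. Random tensors stand in for real teacher and student outputs; the shapes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, vocab_size = 4, 8

# Stand-in for teacher logprobs pre-computed offline and loaded from disk.
teacher_logprobs = F.log_softmax(torch.randn(num_tokens, vocab_size), dim=-1)

# Student logits from the model being trained.
student_logits = torch.randn(num_tokens, vocab_size, requires_grad=True)
student_logprobs = F.log_softmax(student_logits, dim=-1)

# KL(teacher || student) with BOTH arguments in log-space:
# input = student logprobs, target = teacher logprobs (log_target=True).
# "batchmean" sums over the vocabulary and divides by the number of tokens.
kl = F.kl_div(student_logprobs, teacher_logprobs, log_target=True, reduction="batchmean")

# The same quantity written out: mean over tokens of sum_v p_t * (log p_t - log p_s).
manual = (teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)).sum(-1).mean()

kl.backward()  # gradients reach only the student logits
```

The teacher side never needs `exp` before the loss call, which is what keeps the computation in log-space.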
In my opinion, it's good to support values forwarded offline (e.g., logits) for the teacher model. However, I'm unsure how we should support log_probs/probs as arguments in distillation_loss_fn. Because of the normalization step, multiple input vectors can yield the same output probabilities, so softmax is not invertible in a strict sense. In short, it's hard to apply temperature scaling to these post-softmax values.
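A small PyTorch check of the normalization argument above. It confirms the non-invertibility (adding the same constant to every logit leaves softmax unchanged), and also shows that the lost information is exactly that per-token constant, since log_softmax(z) = z - logsumexp(z). As a consequence, re-normalizing full-vocabulary logprobs divided by T reproduces temperature-scaled logprobs; the mismatch only appears when just a top-k subset of logprobs is kept offline. The tensors, `T`, and `k` are arbitrary placeholders.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 10)  # placeholder teacher logits: 3 tokens, vocabulary of 10

# Softmax ignores a per-row constant shift, so probabilities cannot identify the logits.
same_probs = torch.allclose(F.softmax(logits, -1), F.softmax(logits + 5.0, -1), atol=1e-6)

# Full-vocabulary logprobs differ from the logits only by that per-row constant.
logprobs = F.log_softmax(logits, dim=-1)

# Therefore log_softmax(logprobs / T) equals log_softmax(logits / T).
T = 2.0
from_logits = F.log_softmax(logits / T, dim=-1)
from_logprobs = F.log_softmax(logprobs / T, dim=-1)
same_scaled = torch.allclose(from_logits, from_logprobs, atol=1e-5)

# With only the top-k logprobs stored, re-normalizing over those k entries yields a
# distribution over k tokens, which only approximates the full scaled distribution.
k = 4
topk_vals, topk_idx = logprobs.topk(k, dim=-1)
approx_topk = F.log_softmax(topk_vals / T, dim=-1)   # sums to 1 over k tokens
exact_on_topk = from_logits.gather(-1, topk_idx)     # sums to < 1 over the same tokens
```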
Alternatives
No response
Additional context
No response
I am currently conducting research related to KLD, although it is not about distillation. I am facing difficulties because I need to access logits in a limited GPU environment. It would be great if liger_kernel could support logits... Unfortunately, it is not realistically feasible for me to modify the entire codebase to selectively apply liger_kernel only where needed. I apologize for that. Is there any possible solution?
I am currently conducting research related to KLD, although it is not about distillation. I am facing difficulties because I need to access logits in a limited GPU environment. It would be great if liger_kernel could support logits...
@YooSungHyun Hi, can you be more specific about which logits you mean? Is it related to the offline reference logits proposed in this issue?
Unfortunately, it is not realistically feasible for me to modify the entire codebase to selectively apply liger_kernel only where needed.
The Liger kernel monkey patch can be selectively applied to whichever modules you want to replace. For instance, apply_liger_kernel_to_llama(rope=True, swiglu=True, cross_entropy=True, fused_linear_cross_entropy=False, rms_norm=False).
If you are talking about something else, please let me know.
@Tcc0403 Thanks for the reply.
This is in the same context as the "offline reference logits" being discussed.
By logits I mean the outputs of the lm_head. For KL-divergence training, the logits are typically passed through log-softmax and then compared. However, while debugging I found that although the cross-entropy loss is computed properly, the logits are not included in the outputs, which makes it impossible to compute the KL divergence.
Since cross-entropy loss only considers the conditional probability of the label token, I am experimenting with computing the KL divergence between the full-vocabulary probabilities of the reference model and those of the model being trained. To run this experiment on low-cost GPUs, I want to apply the Liger kernel across most of the forward pass while keeping the logits available right before the loss computation.
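For illustration, a minimal sketch of the full-vocabulary KL described above, computed from the materialized logits of both models. The function name, the KL(reference || policy) direction, and the use of -100 to mark positions excluded from the loss (the usual Hugging Face label convention) are assumptions, not part of any library API.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # label value commonly used for positions excluded from the loss


def full_vocab_kl(policy_logits, ref_logits, labels):
    """Hypothetical helper: token-averaged KL(ref || policy) over the full vocabulary.

    policy_logits, ref_logits: [batch, seq, vocab]; labels: [batch, seq].
    """
    # Causal LM: logits at position t predict token t + 1, so align with shifted labels.
    policy_logits = policy_logits[:, :-1]
    ref_logits = ref_logits[:, :-1]
    mask = (labels[:, 1:] != IGNORE_INDEX).float()

    # Upcast before log_softmax; half-precision logits are common in training.
    policy_logp = F.log_softmax(policy_logits.float(), dim=-1)
    ref_logp = F.log_softmax(ref_logits.float(), dim=-1)

    # Per-token KL, summed over every vocabulary entry (not just the label token).
    per_token = (ref_logp.exp() * (ref_logp - policy_logp)).sum(-1)
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
```

This only works if the model actually returns `logits`, i.e. when the fused linear cross-entropy path is not used.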
Currently, I'm using use_liger=True in Transformers with DeepSpeed, but I'm not sure how to configure this natively. That's why, although I was told:
Liger kernel monkey patch can selectively apply to whatever module you want to replace. For instance,
apply_liger_kernel_to_llama(rope=True, swiglu=True, cross_entropy=True, fused_linear_cross_entropy=False, rms_norm=False),
I'm having trouble applying this method properly. If I apply the method above, would I be able to use the Liger kernel in most parts while still obtaining the logits?
@YooSungHyun Thanks for your response
If I apply the method above, would I be able to use the Liger kernel in most parts while still obtaining the logits?
Yes. Disabling fused_linear_cross_entropy materializes the full logits, so they are returned in the outputs again. However, fused_linear_cross_entropy is the main Liger Kernel component for reducing peak memory; without it, the memory reduction won't be as significant.
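To illustrate why the fused path cannot hand back logits, here is a simplified pure-PyTorch sketch of the general chunking idea (not Liger's Triton kernel): hidden states are projected to logits one chunk at a time and the cross-entropy gradients are computed immediately, so only a [chunk, vocab] slice of logits exists at any moment. The function name and `chunk_size` are illustrative assumptions, and all target positions are assumed valid (no ignore index).

```python
import torch


@torch.no_grad()
def chunked_linear_cross_entropy(hidden, weight, targets, chunk_size=4):
    """Mean cross-entropy of (hidden @ weight.T) vs. targets, plus its gradients.

    hidden: [N, d], weight: [vocab, d], targets: [N].
    Returns (loss, grad_hidden, grad_weight); full [N, vocab] logits are never built.
    """
    n = hidden.shape[0]
    loss = hidden.new_zeros(())
    grad_hidden = torch.empty_like(hidden)
    grad_weight = torch.zeros_like(weight)
    for start in range(0, n, chunk_size):
        h = hidden[start:start + chunk_size]    # [c, d]
        t = targets[start:start + chunk_size]   # [c]
        logits = h @ weight.T                   # [c, vocab]: only this chunk
        logprobs = torch.log_softmax(logits, dim=-1)
        loss += -logprobs.gather(-1, t[:, None]).sum()

        # d(mean CE)/d(logits) = (softmax - one_hot(target)) / N
        grad_logits = logprobs.exp()
        grad_logits[torch.arange(h.shape[0]), t] -= 1.0
        grad_logits /= n

        # Chain rule through logits = h @ weight.T, then the chunk's logits are dropped.
        grad_hidden[start:start + chunk_size] = grad_logits @ weight
        grad_weight += grad_logits.T @ h
    return loss / n, grad_hidden, grad_weight
```

Because the loss and gradients are produced chunk by chunk, there is never a complete logits tensor to return, which is why disabling fused_linear_cross_entropy is what brings the logits back.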
Currently, I'm using use_liger=True in Transformers with DeepSpeed, but I'm not sure how to declare things natively. That's why, although I was told: Liger kernel monkey patch can selectively apply to whatever module you want to replace. For instance, apply_liger_kernel_to_llama(rope=True, swiglu=True, cross_entropy=True, fused_linear_cross_entropy=False, rms_norm=False),
I'm having trouble applying this method properly.
use_liger=True in Transformers is equivalent to apply_liger_kernel_to_your_model() with default arguments, and fused_linear_cross_entropy defaults to True.
Remove use_liger=True and patch manually instead (taking Llama as an example):
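```python
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# 1. Specify exactly which kernels are applied
#    (fused_linear_cross_entropy=False keeps the full logits in the model outputs)
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    cross_entropy=False,
    fused_linear_cross_entropy=False,
    rms_norm=True,
)
# 2. Instantiate the model after patching so it picks up the patched modules
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
```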
There is also LigerFusedLinearJSDLoss, which can compute KL divergence while achieving a large memory reduction:
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/jsd_loss.py
But currently there's no simple way to use it by just passing an argument; users have to compose their own models.
We are trying to integrate it into TRL's GKDTrainer: https://github.com/huggingface/trl/issues/2495. If it meets your needs, I will work on it.
@Tcc0403 Ah, I’ve also been very interested in JSD, and it seems you’ve been working on it! Thank you so much. The topic I’m currently researching is actually something that extends beyond JSD. (I can’t share the details just yet, haha.)
Anyway, thank you for kindly and thoroughly explaining everything. I’ll definitely give it a try!