
[feat] FusedLinearCrossEntropy support for Gemma2

yundai424 opened this issue 1 year ago • 4 comments

🚀 The feature, motivation and pitch

FLCE needs special handling for the soft capping in gemma2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054
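For context, the soft-capping step at the linked line divides the logits by the cap, applies tanh, then multiplies back. A rough paraphrase (written from memory; treat the exact shape of the helper below as an assumption, not the Transformers code):

```python
import torch

def soft_cap(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash logits into (-softcap, softcap) while preserving sign and ordering.
    return torch.tanh(logits / softcap) * softcap
```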

Alternatives

No response

Additional context

No response

yundai424 avatar Aug 27 '24 20:08 yundai424

#take @yundai424 I'd like to take a shot at implementing this.

I'm thinking of this approach:

  • Introduce an optional dict parameter in the forward() method here, with a softcap key and a non-linearity key (tanh in the case of gemma2)
  • Perform the 3 steps from modeling_gemma2 after this matmul operation if the optional dict parameter is not None (see the sketch below)
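A minimal sketch of that idea, assuming a hypothetical forward signature (`flce_forward`, `softcap_args`, and the tensor names below are placeholders, not the actual Liger API):

```python
import torch

def flce_forward(hidden: torch.Tensor,
                 weight: torch.Tensor,
                 softcap_args: dict | None = None) -> torch.Tensor:
    # Projection to logits (simplified; the real FLCE kernel works chunk by chunk).
    logits = hidden @ weight.t()

    # Proposed: apply the three soft-capping steps from modeling_gemma2
    # only when the optional dict is provided.
    if softcap_args is not None:
        cap = softcap_args["softcap"]                         # gemma2's final logit cap
        act = softcap_args.get("non_linearity", torch.tanh)   # tanh in the case of gemma2
        logits = act(logits / cap) * cap

    # ... cross-entropy computation on `logits` follows here ...
    return logits
```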

Can you assign it to me if this sounds okay?

troy1729 avatar Aug 28 '24 06:08 troy1729

@troy1729 Sounds reasonable to me. Assigned and feel free to kick off the implementation and ping us to discuss or review on any issues. Thank you!

qingquansong avatar Aug 28 '24 08:08 qingquansong

Hi @qingquansong, I've made the changes but still have to add the tests, so I've kept the PR in draft. Might be a silly question, but we'd want a Triton kernel implementation for tanh (or any other non-linearity), wouldn't we? Right now I've added a torch.tanh callable. Sorry if this is obvious, but I thought I'd clarify.

troy1729 avatar Aug 29 '24 06:08 troy1729

Hey @troy1729, thanks for the question (no silly questions) and the fast kick-off! I think:

  1. Having Triton functions that operate on a single element/block is good in certain cases, such as the silu function we have for swiglu, which can be fused with other operations. In the end we'd like to reduce element-wise operation overhead (as in geglu/swiglu or ReLU + matmul, etc.) rather than call a single activation kernel directly, which would be the same as calling torch.tanh, especially after torch.compile. Also, check my comment 3 here and you'll find that implementing a standalone activation kernel would not be super helpful for fusing it with other operations, especially in the backward pass (isolated forward/backward functions could be helpful, though).

  2. The soft cap idea is mainly scaling + capping the range to (-1, 1), which is why tanh is used (it keeps both positive and negative information), so some other torch activations may not be a good fit here (though I agree we may have cases in the future that call for extra torch activation functions).

  3. You may want to think about how the backprop is computed given this activation applied to the logits. It's not as straightforward as just adding the activation: you'll need to compute the grad_input (the gradient of the hidden states) and the grad_weight.

  4. One more option is to put this inside the liger normal CE loss (which also needs to handle the backward if the option is enabled); then, outside the chunked kernel calls, i.e. in the flce kernel, you don't need to worry about the backprop.

In sum, my suggestion would be: implement only the tanh option for now, and follow the geglu backward to see how the tanh gradient is computed with the chain rule, derive the equation, and implement it here.
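For reference, the chain rule through the tanh soft cap works out to an element-wise scale on the upstream gradient; a sketch only (not the kernel implementation, and the names below are illustrative):

```python
import torch

def soft_cap_backward(grad_capped: torch.Tensor,
                      capped_logits: torch.Tensor,
                      softcap: float) -> torch.Tensor:
    # Forward:  y = softcap * tanh(x / softcap)
    # Backward: dy/dx = 1 - tanh(x / softcap)^2 = 1 - (y / softcap)^2
    # so the upstream gradient is scaled element-wise before flowing back
    # into the matmul (grad_input and grad_weight as usual).
    return grad_capped * (1.0 - (capped_logits / softcap) ** 2)
```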

qingquansong avatar Aug 29 '24 07:08 qingquansong

I believe I've implemented softcap correctly in the cross-entropy function, along with FLCE support for gemma2. But since gemma2 currently can't pass the test even without FLCE, do I need to find a way to pass the relevant convergence test (test_mini_models_no_logits.py)? cc @yundai424

Tcc0403 avatar Oct 22 '24 20:10 Tcc0403