Liger-Kernel
fix: don't drop kwargs from huggingface forward
Summary
The HuggingFace forward passes kwargs through: https://github.com/huggingface/transformers/blob/716819b8309324302e00a3488a3c3d6faa427f79/src/transformers/models/qwen2/modeling_qwen2.py#L712
Preserving these kwargs matters because it lets FlashAttention kwargs be computed once outside the forward, so they are not recomputed on every attention layer, which causes a number of issues: https://github.com/huggingface/transformers/issues/35588
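For context, a minimal sketch of the pattern (illustrative names, not the exact Liger-Kernel patch): the patched forward accepts `**kwargs` and forwards them to the inner model rather than silently dropping them.

```python
from typing import Optional

import torch


def patched_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.LongTensor] = None,
    **kwargs,  # e.g. FlashAttention kwargs precomputed by the caller
):
    # Forward **kwargs to the inner model, mirroring upstream transformers.
    # If they are dropped here, every attention layer has to recompute the
    # FlashAttention metadata the caller already prepared.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **kwargs,
    )
    return outputs
```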
Testing Done
- Hardware Type: H100
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
LGTM, @llllvvuu as soon as the merge conflict is resolved we can get this in
Done