Liger-Kernel
Apple's cross entropy computation
Hi, thanks for the library! Today I saw a paper https://openreview.net/forum?id=E4Fk3YuG56 (code: https://github.com/apple/ml-cross-entropy) that discusses a memory-efficient way to compute cross entropy, so I'm sharing it here in case it is useful for this repository.
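For context: the linked kernel computes the classifier projection and the cross-entropy loss in one fused pass, so the full logit matrix never has to be materialized. A minimal unfused PyTorch reference of the quantity a fused linear_cross_entropy computes (the function and tensor names below are mine, purely illustrative, not the library's API):

import torch
import torch.nn.functional as F

def reference_linear_cross_entropy(embeddings, classifier, labels):
    # embeddings: (N, D), classifier: (V, D), labels: (N,)
    # Materializes the full (N, V) logit matrix -- for large vocabularies
    # this is exactly the memory cost a fused kernel avoids.
    logits = embeddings @ classifier.T
    return F.cross_entropy(logits.float(), labels)

e = torch.randn(8, 64)
c = torch.randn(32000, 64)  # e.g. a 32k-token vocabulary
y = torch.randint(0, 32000, (8,))
print(reference_linear_cross_entropy(e, c, y))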
Hi @fzyzcjy, did you try the original repo? Does it work as expected?
Hi, no I have not tried it yet
I ran into problems installing and using Apple's kernel from that repo. I assume the Triton kernels themselves are sound and do what they say, but it's research code and isn't well tested across versions and platforms. It would be amazing to have it be part of liger-kernel, because everything here is well tested and "just works" out of the box.
We are more than happy to host and maintain innovative kernels like https://github.com/apple/ml-cross-entropy. @erikwijmans, are you interested in collaborating? We are committed to long-term maintenance at the company level.
FYI @ByronHsu, this is a very simple reproduction of why the CCE kernel isn't working for me. I know you use Modal for CI, so you should pretty easily be able to reproduce this. I wish I could debug it myself, but I am not a Triton god like you. This is the simplest setup I can think of: a fresh install of only the cut-cross-entropy package, an attempt to import linear_cross_entropy, and the error happens. I did not make any modifications to the code.
This uses Python 3.10, triton==3.1.0, torch==2.5.1.
import modal

image = modal.Image.debian_slim(python_version="3.10").apt_install("git").pip_install(
    "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git"
)

app = modal.App("cut-cross-entropy")

@app.function(
    image=image,
    gpu=modal.gpu.A10G(),
)
def test_cce():
    from cut_cross_entropy import linear_cross_entropy

    print("success!")
Traceback:
Traceback (most recent call last):
  File "/root/test.py", line 14, in test_cce
    from cut_cross_entropy import linear_cross_entropy
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/__init__.py", line 2, in <module>
    from cut_cross_entropy.linear_cross_entropy import (
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/linear_cross_entropy.py", line 20, in <module>
    from cut_cross_entropy.cce import cce_linear_cross_entropy
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/cce.py", line 7, in <module>
    from cut_cross_entropy.cce_backward import cce_backward_kernel
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/cce_backward.py", line 81, in <module>
    def _cce_backward_kernel(
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 882, in jit
    return decorator(fn)
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 871, in decorator
    return JITFunction(
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 717, in __init__
    self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
AttributeError: 'NoneType' object has no attribute 'start'
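(For anyone curious about the mechanics of the error: re.search returns None when the pattern finds no match, and the Triton line above calls .start() on the result unconditionally, so any retrieved kernel source without a matching top-level "def" line fails exactly like this. A standalone illustration, not Triton's actual code path:)

import re

# If the retrieved source has no line starting with "def", the pattern
# from triton/runtime/jit.py matches nothing and re.search returns None...
src = "@some_decorator(...)  # source without a matching def line"
match = re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE)

# ...and calling .start() on None raises the AttributeError in the traceback.
match.start()  # AttributeError: 'NoneType' object has no attribute 'start'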
OK, it turns out the problem is related to Triton's regex search over the kernel's source code combined with the multiple stacked decorators. The fix is to comment out the decorators on _cce_backward_kernel:
# @cce_backward_autotune()
# @triton.heuristics(
#     {
#         "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
#         "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
#         "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
#         "HAS_VALIDS": lambda args: args["Valids"] is not None,
#         "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
#         "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
#         "HAS_TARGETS": lambda args: args["Targets"] is not None,
#         "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
#         "ITEM_DO": lambda args: args["dOut"].numel() == 1,
#         "GROUP_B": lambda args: 8,
#     }
# )
# @triton.jit
def _cce_backward_kernel(
...and instead apply them "manually" like this:
_cce_backward_kernel = triton.jit(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
    }
)(_cce_backward_kernel)
_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
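Note the ordering: stacked decorators apply bottom-up, which is why the manual version applies triton.jit first and cce_backward_autotune() last, matching the original stack. A generic, self-contained illustration (the decorator and function names here are made up):

def shout(fn):
    # Outermost decorator in the stacked form.
    def wrapper():
        return fn().upper()
    return wrapper

def exclaim(fn):
    # Innermost decorator in the stacked form (closest to the def).
    def wrapper():
        return fn() + "!"
    return wrapper

@shout
@exclaim
def greet():
    return "hello"

# Equivalent manual application: innermost decorator first.
def greet_manual():
    return "hello"

greet_manual = exclaim(greet_manual)
greet_manual = shout(greet_manual)

assert greet() == greet_manual() == "HELLO!"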
This has been merged into https://github.com/apple/ml-cross-entropy.
Any update on CCE integration?
I got CCE working with transformers, but it was a hacked mess. The Unsloth guys just announced that it's now supported in their new blog post here.
Here is the patch for the forward pass: https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py#L932
Here is where they bring in CCE: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/loss_utils.py#L139
I am not an expert on the kernel side but can help with the integration.