Liger-Kernel

Apple's cross entropy computation

Open • fzyzcjy opened this issue 1 year ago • 8 comments

Hi, thanks for the library! Today I saw a paper, https://openreview.net/forum?id=E4Fk3YuG56 (code: https://github.com/apple/ml-cross-entropy), which proposes a memory-efficient way to compute cross entropy. I'm sharing it here in case it is useful for this repository.
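As a rough illustration of the general idea (not Apple's kernel, which does this inside fused Triton kernels on the GPU), the loss for one token can be computed from its hidden state and the classifier weights by taking the dot product for the target class only and streaming over vocabulary chunks with an online log-sum-exp, so the full logit vector is never held at once. This is a minimal sketch under stated assumptions: plain Python lists stand in for tensors, it handles a single token, and the names `chunked_linear_cross_entropy`, `naive_linear_cross_entropy`, and `chunk_size` are hypothetical.

```python
import math
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def chunked_linear_cross_entropy(
    hidden: Sequence[float],
    classifier: List[Sequence[float]],  # one weight row per vocabulary entry
    target: int,
    chunk_size: int = 2,
) -> float:
    """Hypothetical helper: loss for one token without storing all V logits at once."""
    # Logit of the correct class only: a single dot product.
    target_logit = dot(hidden, classifier[target])

    # Online log-sum-exp over vocabulary chunks.
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(logit - running_max) over chunks seen so far
    for start in range(0, len(classifier), chunk_size):
        chunk = [dot(hidden, row) for row in classifier[start:start + chunk_size]]
        new_max = max(running_max, max(chunk))
        # Rescale the previous partial sum to the new maximum, then add this chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(z - new_max) for z in chunk
        )
        running_max = new_max

    log_sum_exp = running_max + math.log(running_sum)
    return log_sum_exp - target_logit


def naive_linear_cross_entropy(hidden, classifier, target):
    """Reference version that materializes every logit first."""
    logits = [dot(hidden, row) for row in classifier]
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


# Usage: a 3-dim hidden state and a 5-entry vocabulary.
hidden = [0.5, -1.0, 2.0]
classifier = [
    [0.1, 0.2, 0.3],
    [1.0, 0.0, -1.0],
    [0.3, 0.3, 0.3],
    [-0.5, 0.4, 0.9],
    [2.0, 1.0, 0.0],
]
print(chunked_linear_cross_entropy(hidden, classifier, target=3))
print(naive_linear_cross_entropy(hidden, classifier, target=3))
```

Both functions return the same value; the chunked one only ever holds `chunk_size` logits at a time, which is the property that matters when the vocabulary is large.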

fzyzcjy, Nov 17 '24

Hi @fzyzcjy, did you try the original repo? Does it work as expected?

leng-yue, Dec 03 '24

Hi, no, I have not tried it yet.

fzyzcjy, Dec 03 '24

I ran into problems installing and using Apple's kernel from that repo. I assume the Triton code they wrote is sound and does what it claims, but it's research code and isn't well tested across versions and platforms. It would be amazing to have it be part of liger-kernel, because everything here is well tested and "just works" out of the box.

andersonbcdefg, Dec 05 '24

We are more than happy to host and maintain innovative kernels like https://github.com/apple/ml-cross-entropy. @erikwijmans, are you interested in collaborating? We are committed to long-term maintenance at the company level.

ByronHsu, Dec 05 '24

FYI @ByronHsu, here is a very simple reproduction of why the CCE kernel isn't working for me. I know you use Modal for CI, so you should be able to reproduce this pretty easily. I wish I could debug it myself, but I am not a Triton god like you. This is the simplest setup I can think of: a fresh install of only the cut-cross-entropy package, then an attempt to import linear_cross_entropy, and the error happens. I did not make any modifications to the code.

This uses Python 3.10, triton==3.1.0, and torch==2.5.1.

import modal

image = modal.Image.debian_slim(python_version="3.10").apt_install("git").pip_install(
    "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git"
)

app = modal.App("cut-cross-entropy")

@app.function(
    image=image,
    gpu=modal.gpu.A10G()
)
def test_cce():
    from cut_cross_entropy import linear_cross_entropy
    print("success!")

Backtrace:

Traceback (most recent call last):
  File "/root/test.py", line 14, in test_cce
    from cut_cross_entropy import linear_cross_entropy
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/__init__.py", line 2, in <module>
    from cut_cross_entropy.linear_cross_entropy import (
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/linear_cross_entropy.py", line 20, in <module>
    from cut_cross_entropy.cce import cce_linear_cross_entropy
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/cce.py", line 7, in <module>
    from cut_cross_entropy.cce_backward import cce_backward_kernel
  File "/usr/local/lib/python3.10/site-packages/cut_cross_entropy/cce_backward.py", line 81, in <module>
    def _cce_backward_kernel(
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 882, in jit
    return decorator(fn)
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 871, in decorator
    return JITFunction(
  File "/usr/local/lib/python3.10/site-packages/triton/runtime/jit.py", line 717, in __init__
    self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
AttributeError: 'NoneType' object has no attribute 'start'
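The last frame explains the error message: Triton locates the start of the kernel's source with `re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE)` and immediately calls `.start()` on the result. If the source text it captured contains no line beginning with `def name(`, `re.search` returns `None` and the call fails with exactly this `AttributeError`. The snippet below reproduces that pattern in isolation; the helper `strip_to_def` and both source strings are hypothetical, and it makes no claim about why Triton captured the wrong text in this particular case.

```python
import re

# The same pattern Triton's JITFunction uses in the traceback above.
DEF_PATTERN = r"^def\s+\w+\s*\("


def strip_to_def(src: str) -> str:
    """Hypothetical helper: drop everything before the first `def name(` line."""
    match = re.search(DEF_PATTERN, src, re.MULTILINE)
    if match is None:
        # Triton calls .start() without this check, hence the AttributeError.
        raise ValueError("no line starting with 'def name(' in captured source")
    return src[match.start():]


# Hypothetical captured source that includes the decorated function's def line.
good_src = "@triton.jit\ndef kernel(x):\n    pass\n"
# Hypothetical capture that contains only decorator text.
bad_src = '@triton.heuristics(\n    {"GROUP_B": lambda args: 8}\n)\n'

print(strip_to_def(good_src).splitlines()[0])  # def kernel(x):
print(re.search(DEF_PATTERN, bad_src, re.MULTILINE))  # None
```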

andersonbcdefg, Dec 07 '24

OK, it turns out the problem is related to Triton's regex search over the function's source code combined with stacking multiple decorators. The workaround is to comment out the decorators on _cce_backward_kernel:

# @cce_backward_autotune()
# @triton.heuristics(
#     {
#         "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
#         "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
#         "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
#         "HAS_VALIDS": lambda args: args["Valids"] is not None,
#         "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
#         "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
#         "HAS_TARGETS": lambda args: args["Targets"] is not None,
#         "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
#         "ITEM_DO": lambda args: args["dOut"].numel() == 1,
#         "GROUP_B": lambda args: 8,
#     }
# )
# @triton.jit
def _cce_backward_kernel(

...and instead apply them "manually" like this:

_cce_backward_kernel = triton.jit(_cce_backward_kernel)
_cce_backward_kernel = triton.heuristics(
    {
        "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0,
        "MM_BACK_BLOCK_D": lambda args: args["BLOCK_D"] * 2,
        "MM_BACK_EVEN_D": lambda args: (args["D"] % (args["BLOCK_D"] * 2)) == 0,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None,
        "FILTER_GRAD": lambda args: args["filter_eps"] is not None,
        "HAS_TARGETS": lambda args: args["Targets"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "ITEM_DO": lambda args: args["dOut"].numel() == 1,
        "GROUP_B": lambda args: 8,
    }
)(_cce_backward_kernel)
_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel)
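The manual version has to apply the wrappers in the reverse of their visual order, because stacked decorators are applied bottom-up: `@a` over `@b` over `@c` over `def f` is equivalent to `f = a(b(c(f)))`. That is why `triton.jit` is applied first and `cce_backward_autotune()` last. The sketch below checks that equivalence in plain Python with a hypothetical `layer` decorator factory; no Triton is involved.

```python
from typing import Callable


def layer(name: str) -> Callable:
    """Hypothetical decorator factory: wraps fn and records `name` as the new outer layer."""
    def decorator(fn: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        # Layers listed outermost first, e.g. ["autotune", "heuristics", "jit"].
        wrapper.layers = [name] + getattr(fn, "layers", [])
        return wrapper
    return decorator


# Stacked form, like the original decorators in cce_backward.py.
@layer("autotune")
@layer("heuristics")
@layer("jit")
def stacked_kernel():
    return "ran"


# Manual form, like the workaround: the bottom (innermost) decorator goes first.
def manual_kernel():
    return "ran"


manual_kernel = layer("jit")(manual_kernel)
manual_kernel = layer("heuristics")(manual_kernel)
manual_kernel = layer("autotune")(manual_kernel)

print(stacked_kernel.layers)  # ['autotune', 'heuristics', 'jit']
print(manual_kernel.layers)   # ['autotune', 'heuristics', 'jit']
```

If the three assignments were reordered (for example, the autotuner applied first), the layer lists would no longer match, which in the Triton case would hand each wrapper a different object than the stacked form does.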

andersonbcdefg, Dec 07 '24

This fix has been merged into https://github.com/apple/ml-cross-entropy. Any update on CCE integration?

ccdv-ai, Dec 10 '24

I got CCE working with transformers, but it was a hacky mess. The Unsloth team just announced in their new blog post that it's now supported.

Here is their patch of the forward pass: https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py#L932

Here is where they bring in CCE: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/loss_utils.py#L139

I am not an expert on the kernel side but can help with the integration.

amazingvince, Dec 11 '24