Liger-Kernel
LigerGEGLUMLP error with torch.compile
🐛 Describe the bug
When using LigerGEGLUMLP with torch.compile, I get the following error.
UserWarning: Traceback (most recent call last):
Encountered an exception in identify_mutated_tensors, assuming every input is mutated:
File "~/.local/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 595, in identify_mutated_tensors
ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
Encountered an exception in identify_mutated_tensors, assuming every input is mutated:
File "~/.local/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 115, in generate_ttir
raise Exception("Incorrect number of arguments passed to kernel")
Encountered an exception in identify_mutated_tensors, assuming every input is mutated:
Exception: Incorrect number of arguments passed to kernel
BackendCompilerFailed: backend='inductor' raised:
CompilationError: at 21:18:
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask=mask, other=0)
# tanh approximation form of GELU is computed with:
# 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
a_cubed = a_row * a_row * a_row
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
tanh_result = tanh(tanh_arg)
^
NameError('tanh is not defined')
I can run it with the torch.compile line commented out with no problem.
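For reference, the formula quoted in the kernel comment above can be checked without a GPU. The following is a minimal pure-Python sketch, not Liger-Kernel code: the function names `gelu_tanh` and `gelu_exact` are made up for illustration, and `math.tanh` stands in for the Triton `tanh` that the compiler failed to resolve.

```python
import math

# Same constant as in the kernel snippet from the traceback: sqrt(2 / pi).
SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)


def gelu_tanh(a: float) -> float:
    """Tanh approximation of GELU, as written in the kernel comment."""
    a_cubed = a * a * a
    tanh_arg = SQRT_2_OVER_PI * (a + 0.044715 * a_cubed)
    return 0.5 * a * (1.0 + math.tanh(tanh_arg))


def gelu_exact(a: float) -> float:
    """Exact GELU via the error function, for comparison only."""
    return 0.5 * a * (1.0 + math.erf(a / math.sqrt(2.0)))


if __name__ == "__main__":
    for value in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(value, gelu_tanh(value), gelu_exact(value))
```

The two functions agree closely on typical inputs, which is a quick way to confirm that the failure is a name-resolution problem at compile time rather than a numerical one.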
Reproduce
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from dataclasses import dataclass
import torch

torch.set_float32_matmul_precision('high')

@dataclass
class Config:
    hidden_size: int = 768
    intermediate_size: int = 768 * 4
    hidden_act: str = 'gelu_pytorch_tanh'

cfg = Config()
gegelu = LigerGEGLUMLP(cfg).cuda()
gegelu = torch.compile(gegelu)
x = torch.randn(1,1024,768,device='cuda')
out = gegelu(x)  # torch.compile is lazy, so a forward call is needed to trigger compilation
Versions
Python Version: 3.10.12
CUDA Version: 12.1
PyTorch Version: 2.3.0+cu121
Triton Version: 2.3.0
Hi @Luke-Chesley, thanks for reporting! Are you using the nightly version of Liger or 0.1.1? I'm asking because I can't reproduce it with 0.1.1, using the same torch and triton versions (apart from the CUDA build, cu118 rather than cu121).
jobuser [ ~/resources ]$ pip show torch
Name: torch
Version: 2.3.0+cu118
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3
Location: /home/jobuser/.local/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu11, nvidia-cuda-cupti-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-runtime-cu11, nvidia-cudnn-cu11, nvidia-cufft-cu11, nvidia-curand-cu11, nvidia-cusolver-cu11, nvidia-cusparse-cu11, nvidia-nccl-cu11, nvidia-nvtx-cu11, sympy, triton, typing-extensions
Required-by: accelerate, ella, flash-attn, liger-kernel, lightning, pytorch-lightning, torchmetrics
jobuser [ ~/resources ]$ pip show triton
Name: triton
Version: 2.3.0
Summary: A language and compiler for custom Deep Learning operations
Home-page: https://github.com/openai/triton/
Author: Philippe Tillet
Author-email: [email protected]
License:
Location: /home/jobuser/.local/lib/python3.10/site-packages
Requires: filelock
Required-by: ella, liger-kernel, torch
jobuser [ ~/resources ]$ python test.py
jobuser [ ~/resources ]$ echo $?
0
I'm using 0.1.1. I upgraded triton to 3.0.0 and it works, but it still fails on 2.3.0.
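Based only on the observation above (works on triton 3.0.0, fails on 2.3.0), one could gate torch.compile on the installed Triton version. This is a rough sketch: the helper names `parse_version` and `triton_ok_for_compiled_geglu` are hypothetical, and the 3.0.0 threshold is an assumption taken from this single report, not a documented compatibility boundary.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn '2.3.0+cu121' into (2, 3, 0), ignoring local build suffixes."""
    public_part = version.split("+")[0]
    numbers = [int(n) for n in re.findall(r"\d+", public_part)[:3]]
    # Pad short versions such as '3.0' so tuple comparison behaves as expected.
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def triton_ok_for_compiled_geglu(triton_version: str, minimum: str = "3.0.0") -> bool:
    """Assumption from this thread: the compile error goes away at triton >= 3.0.0."""
    return parse_version(triton_version) >= parse_version(minimum)


if __name__ == "__main__":
    print(triton_ok_for_compiled_geglu("2.3.0"))
    print(triton_ok_for_compiled_geglu("3.0.0"))
```

In a training script this could wrap the compile step, for example by calling torch.compile on the module only when `triton_ok_for_compiled_geglu(triton.__version__)` returns True and leaving the module in eager mode otherwise.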
Hmm, I tried to reproduce it on Colab with a T4 GPU instead of my own env, but still couldn't: https://colab.research.google.com/drive/1O6ySBz0K9_73oztAzOLyzELaC3TM2AhS#scrollTo=L4_b1U34zmdf
@Luke-Chesley could you try to run from triton.language.math import tanh in your triton 2.3 env and tell us what the output is?
hi @yundai424, I can run from triton.language.math import tanh in the env without errors. I'll try to include as much local venv info here as I can.
I was able to reproduce it here in colab on T4. https://colab.research.google.com/drive/1wzGW8wY1b3rIUyiS4vFpJ59w88PqI9ii?usp=sharing
OS: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Python Version: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
GPU: NVIDIA GeForce RTX 3070
pip freeze
asttokens==2.4.1
certifi==2024.7.4
charset-normalizer==3.3.2
comm==0.2.2
debugpy==1.8.5
decorator==5.1.1
exceptiongroup==1.2.2
executing==2.0.1
filelock==3.15.4
fsspec==2024.6.1
huggingface-hub==0.24.6
idna==3.8
ipykernel==6.29.5
ipython==8.26.0
jedi==0.19.1
Jinja2==3.1.4
jupyter_client==8.6.2
jupyter_core==5.7.2
liger-kernel==0.1.1
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mpmath==1.3.0
nest-asyncio==1.6.0
networkx==3.3
numpy==2.1.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
packaging==24.1
parso==0.8.4
pexpect==4.9.0
platformdirs==4.2.2
prompt_toolkit==3.0.47
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
Pygments==2.18.0
python-dateutil==2.9.0.post0
PyYAML==6.0.2
pyzmq==26.2.0
regex==2024.7.24
requests==2.32.3
safetensors==0.4.4
six==1.16.0
stack-data==0.6.3
sympy==1.13.2
tiktoken==0.7.0
tokenizers==0.19.1
torch==2.3.0
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
transformers==4.44.2
triton==2.3.0
typing_extensions==4.12.2
urllib3==2.2.2
wcwidth==0.2.13
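For anyone else trying to reproduce this, a shorter environment summary than a full pip freeze may be enough. The snippet below is a small sketch using only the standard library; `collect_versions` is a hypothetical helper name, and the default package list is simply the three packages discussed in this thread.

```python
import platform
from importlib import metadata


def collect_versions(packages=("torch", "triton", "liger-kernel")) -> dict:
    """Return Python/platform info plus installed versions of the given packages."""
    info = {
        "python": platform.python_version(),
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages instead of raising, so the summary is always complete.
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print(f"{key}: {value}")
```

The GPU model is not covered here; with torch installed, `torch.cuda.get_device_name(0)` reports it.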
Hi @Luke-Chesley, I requested access to the Colab. Thanks!
Minimal reproducible script to narrow it down:
import torch
from triton.language.math import tanh
import triton.language as tl
import triton

@triton.jit
def _triton_tanh(A_ptr, B_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0)
    A_ptr += program_id * stride
    B_ptr += program_id * stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    a_row = tl.load(A_ptr + col_offsets, mask=mask, other=0)
    tanh_res = tanh(a_row)
    tl.store(B_ptr + col_offsets, tanh(a_row), mask=mask)

class CustomTanhFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        b = torch.zeros_like(a)
        BLOCK_SIZE=1024
        _triton_tanh[(a.shape[0],)](a, b, a.stride(0), n_cols=a.shape[-1], BLOCK_SIZE=BLOCK_SIZE)
        return b

x = torch.randn((4,4)).to('cuda')
torch.compile(CustomTanhFunction.apply)(x)
I'll take a look at Inductor to see what's happening there. In the meantime, please feel free to either manually turn off GEGLU in the patch API or add the lines below:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
to work around the issue. Thanks a lot!
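Since suppress_errors is a process-wide setting, a narrower alternative is to fall back to eager mode only for the module that fails to compile. The following is a hedged sketch rather than anything from Liger-Kernel or PyTorch: `with_eager_fallback` is a hypothetical helper, and it relies on the fact that torch.compile compiles lazily, so a backend failure surfaces as an exception on the first call.

```python
from typing import Any, Callable


def with_eager_fallback(compiled_fn: Callable[..., Any],
                        eager_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Call compiled_fn; after its first failure, use eager_fn for all later calls."""
    compile_failed = False

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal compile_failed
        if not compile_failed:
            try:
                return compiled_fn(*args, **kwargs)
            except Exception as exc:  # broad on purpose: backend errors vary by version
                compile_failed = True
                print(f"compiled path failed ({type(exc).__name__}); using eager mode")
        return eager_fn(*args, **kwargs)

    return wrapper


# Hypothetical usage with the module from the report:
#   mlp = LigerGEGLUMLP(cfg).cuda()
#   mlp_safe = with_eager_fallback(torch.compile(mlp), mlp)
#   out = mlp_safe(x)
```

Catching every exception also hides genuine runtime errors on the first call, so this is best treated as a temporary workaround while the underlying compile issue is open.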
Is this still a problem?