Liger-Kernel
Triton error on AMD GPUs
🐛 Describe the bug
I'm trying to test this library on an HPC cluster with AMD MI250X GPUs, but I'm hitting what looks like a Triton-related error, and only when I enable model.train(). The following is a minimal example:
```python
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
)

x = torch.zeros((4, 128), dtype=torch.long).cuda()
model = model.cuda()
model.train()  # runs without an issue when I comment out this line
y = model(input_ids=x, labels=x)
```
I get the following error when I run this on an MI250X:
File "/lustre/orion/stf218/scratch/emin/test_liger/test.py", line 22, in <module>
y = model(input_ids=x, labels=x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/model/llama.py", line 109, in lce_forward
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/fused_linear_cross_entropy.py", line 13, in forward
return LigerFusedLinearCrossEntropyFunction.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 193, in forward
loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 73, in fused_linear_cross_entropy_forward
liger_cross_entropy_kernel[(n_rows,)](
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 691, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/backends/amd/driver.py", line 418, in __call__
self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
The same snippet runs fine when I leave model.train() off. I also have access to another cluster with NVIDIA GPUs (A100 and H100), and there it works with or without model.train(), so this appears to be an AMD-specific issue. I would appreciate any help debugging it.
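As a first isolation step, since the traceback ends inside liger_cross_entropy_kernel, I plan to patch everything except the cross-entropy kernels. This is my reading of the patching API (assuming apply_liger_kernel_to_llama accepts the per-kernel boolean flags shown in the README; if the names differ in the nightly, the idea is the same):

```python
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Keep the RoPE, RMSNorm and SwiGLU patches, but leave the loss unpatched
# so the Triton cross-entropy kernel where the traceback ends never runs.
apply_liger_kernel_to_llama(
    rope=True,
    rms_norm=True,
    swiglu=True,
    cross_entropy=False,
    fused_linear_cross_entropy=False,
)
```

If training works with that, the problem is confined to the Triton cross-entropy kernel rather than the patching machinery as a whole.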
Reproduce
No response
Versions
I'm running PyTorch nightly + ROCm 6.2 + liger-kernel-nightly:

```
torch==2.5.0.dev20240906+rocm6.2
triton==3.0.0
liger-kernel-nightly==0.2.1.dev20240908014422
transformers==4.44.2
```
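One hypothesis I want to flag (purely a guess on my part, not verified against the Liger or Triton sources): HIP error code 1 is hipErrorInvalidValue, which is what a kernel launch returns when it requests more threads per block than the hardware allows. On MI250X a "warp" is a 64-lane wavefront, so a Triton kernel launched with num_warps=32 would ask for 32 × 64 = 2048 threads per block, over the usual 1024-thread limit, while the identical launch on A100/H100 (32 × 32 = 1024) is legal. That would explain why the failure only shows up on AMD. The sketch below is a self-contained way to test this on the MI250X node; the kernel is a trivial copy, so only the launch configuration is being exercised:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Trivial masked copy; only the launch configuration matters here.
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

print(torch.cuda.get_device_properties(0))  # basic device info for the report

x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
for num_warps in (4, 8, 16, 32):  # 32 warps x 64 lanes = 2048 threads/block on CDNA
    try:
        grid = (triton.cdiv(x.numel(), 1024),)
        _copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=1024, num_warps=num_warps)
        torch.cuda.synchronize()
        print(f"num_warps={num_warps}: OK")
    except RuntimeError as e:
        print(f"num_warps={num_warps}: FAILED ({e})")
```

If only the num_warps=32 launch fails with the same "invalid argument" error, that would point at the launch configuration of the cross-entropy kernel rather than at ROCm or Triton more broadly.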