Liger-Kernel

Triton error on AMD GPUs

Open eminorhan opened this issue 1 year ago • 8 comments

🐛 Describe the bug

I'm trying to test this library on an HPC cluster with AMD MI250X GPUs, but I get what looks like a Triton-related error, and only when I call model.train(). The following is a minimal example:

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B", attn_implementation='sdpa', torch_dtype=torch.bfloat16)

x = torch.zeros((4, 128), dtype=int)
x = x.cuda()
model = model.cuda()

model.train()  # runs without an issue when I comment out this line 

y = model(input_ids=x, labels=x)

I get the following error when I run this on an MI250X:

  File "/lustre/orion/stf218/scratch/emin/test_liger/test.py", line 22, in <module>
    y = model(input_ids=x, labels=x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/model/llama.py", line 109, in lce_forward
    loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/fused_linear_cross_entropy.py", line 13, in forward
    return LigerFusedLinearCrossEntropyFunction.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 193, in forward
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 73, in fused_linear_cross_entropy_forward
    liger_cross_entropy_kernel[(n_rows,)](
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]:  Code: 1, Messsage: invalid argument

The same code snippet works fine when I don't call model.train(). I also have access to another cluster with NVIDIA GPUs, and I can confirm that it works fine there (A100 and H100), with or without model.train(), so this appears to be an AMD-specific issue. I would appreciate any help with debugging it.

Reproduce

No response

Versions

I'm running on PyTorch-nightly + ROCm 6.2 + liger-kernel-nightly:

torch==2.5.0.dev20240906+rocm6.2
triton==3.0.0
liger-kernel-nightly==0.2.1.dev20240908014422
transformers==4.44.2
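
For anyone who wants to compare their environment against the list above, here is a minimal sketch that prints the same four entries from a live environment. It relies only on the standard-library `importlib.metadata`; the helper names `collect_versions` and `format_versions` are made up for illustration, and the package names are assumed to match the ones listed above (a given ROCm install may ship Triton under a different distribution name, in which case it will show as "not installed").

```python
from importlib import metadata

# Distribution names as they appear in the list above (assumption: your
# environment uses the same names).
PACKAGES = ["torch", "triton", "liger-kernel-nightly", "transformers"]


def collect_versions(packages=PACKAGES):
    """Return a dict mapping each package to its version, or "not installed"."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


def format_versions(versions):
    """Render the mapping as pip-style `name==version` lines."""
    return "\n".join(f"{name}=={ver}" for name, ver in versions.items())


if __name__ == "__main__":
    print(format_versions(collect_versions()))
```

Running the script directly prints one `name==version` line per package, which can be pasted into a reply as-is.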

eminorhan · Sep 08 '24 06:09