
Triton error on AMD GPUs

Open eminorhan opened this issue 1 year ago • 8 comments

🐛 Describe the bug

I'm testing this library on an HPC cluster with AMD MI250X GPUs, and I get what looks like a Triton-related error, but only when I call model.train(). The following is a minimal example:

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B", attn_implementation='sdpa', torch_dtype=torch.bfloat16)

x = torch.zeros((4, 128), dtype=int)
x = x.cuda()
model = model.cuda()

model.train()  # runs without an issue when I comment out this line 

y = model(input_ids=x, labels=x)

I get the following error when I run this on an MI250X:

  File "/lustre/orion/stf218/scratch/emin/test_liger/test.py", line 22, in <module>
    y = model(input_ids=x, labels=x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/model/llama.py", line 109, in lce_forward
    loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/transformers/fused_linear_cross_entropy.py", line 13, in forward
    return LigerFusedLinearCrossEntropyFunction.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 193, in forward
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 73, in fused_linear_cross_entropy_forward
    liger_cross_entropy_kernel[(n_rows,)](
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]:  Code: 1, Messsage: invalid argument

The same code snippet works fine when I turn off model.train(). I also have access to another cluster with NVIDIA GPUs and I can confirm that it works fine (with or without model.train()) on NVIDIA GPUs (A100 and H100), so this is an AMD-specific issue. I would appreciate any help you could provide for debugging this issue.

Reproduce

No response

Versions

I'm running on PyTorch-nightly + ROCm 6.2 + liger-kernel-nightly:

torch==2.5.0.dev20240906+rocm6.2
triton==3.0.0
liger-kernel-nightly==0.2.1.dev20240908014422
transformers==4.44.2

eminorhan avatar Sep 08 '24 06:09 eminorhan

cc @jokeren do you have any ideas?

ByronHsu avatar Sep 08 '24 06:09 ByronHsu

Same error as https://github.com/triton-lang/triton/issues/4128. Also, @helloworld1 has tested on AMD GPUs before. Can you share your experience?

ByronHsu avatar Sep 08 '24 06:09 ByronHsu

Same error on MI210, and I wasn't able to resolve it myself. Looks like a Triton/ROCm compatibility issue.

helloworld1 avatar Sep 08 '24 15:09 helloworld1

@ByronHsu Thanks a lot for the pointer. I really appreciate the help. So, when I manually change the num_warps arguments to 64 in fused_linear_cross_entropy.py, it seems to fix this particular issue, but now I get another error in fused_linear_cross_entropy:

[rank16]:   File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 193, in forward
[rank16]:     loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
[rank16]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank16]:   File "/lustre/orion/stf218/scratch/emin/miniconda3/lib/python3.12/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 54, in fused_linear_cross_entropy_forward
[rank16]:     logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
[rank16]:                    ~~~~~~~~~~~~~^~~~~~~~~~~~
[rank16]: RuntimeError: size mismatch, got input (16), mat (16x4096), vec (16416768)

I should note that this doesn't use the minimal code above. It's a bit more complicated, with an FSDP wrapper around the model (I couldn't immediately create a minimal example), but I was wondering if you had any idea what might be triggering this size mismatch error.

eminorhan avatar Sep 08 '24 16:09 eminorhan

I got this same error a few weeks ago trying to train on MI300x with axolotl (on torch 2.4.0+rocm6.1). I got the training run to start once after fiddling with various dependencies, but unfortunately I could never reproduce that.

EDIT: Was able to get rope and rms_norm liger kernels to run without this error for a Llama 3.1 model on the setup I mentioned above. swiglu, cross entropy, and fused linear cross entropy all result in this error, in case that helps narrow anything down a little.

DocShotgun avatar Sep 12 '24 17:09 DocShotgun

I don't maintain the AMD backend. Better to try out triton/main or contact AMD people

Jokeren avatar Sep 12 '24 18:09 Jokeren

Following the logic in the issue linked here (https://github.com/linkedin/Liger-Kernel/issues/231#issuecomment-2336569380), noting that the warp size of AMD Instinct processors is 64 compared to 32 for NVIDIA GPUs, I halved num_warps across the board in my fork of liger kernel (https://github.com/DocShotgun/Liger-Kernel/commit/81db02cff944104dc7a28670211417af0577912d). This appears to solve the problem for me, allowing training on MI300x on a llama 3.1 8b architecture model.
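The warp-size reasoning above can be checked with a little arithmetic. This is a rough sketch, assuming the launch fails because num_warps times the warp size exceeds a per-block thread limit. The limit of 1024 and the starting value of num_warps=32 are illustrative assumptions, not values confirmed in this thread, and the function names are hypothetical.

```python
NVIDIA_WARP_SIZE = 32       # threads per warp on NVIDIA GPUs
AMD_WAVEFRONT_SIZE = 64     # threads per wavefront on AMD Instinct GPUs
MAX_THREADS_PER_BLOCK = 1024  # assumption: per-block thread limit on both vendors


def threads_per_program(num_warps: int, warp_size: int) -> int:
    """Threads launched per Triton program instance."""
    return num_warps * warp_size


def fits(num_warps: int, warp_size: int, limit: int = MAX_THREADS_PER_BLOCK) -> bool:
    """True if the launch stays within the assumed per-block thread limit."""
    return threads_per_program(num_warps, warp_size) <= limit


if __name__ == "__main__":
    for label, num_warps, warp_size in [
        ("NVIDIA, num_warps=32", 32, NVIDIA_WARP_SIZE),
        ("AMD, num_warps=32", 32, AMD_WAVEFRONT_SIZE),
        ("AMD, num_warps=16", 16, AMD_WAVEFRONT_SIZE),
    ]:
        total = threads_per_program(num_warps, warp_size)
        print(f"{label}: {total} threads, fits={fits(num_warps, warp_size)}")
```

Under these assumptions, the same num_warps value launches twice as many threads on AMD as on NVIDIA, which is why halving it restores the original thread count.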

Training appears to be working fine judging by my logs (slightly faster and significantly less memory while having similar loss and grad norms):

No liger:

{'loss': 1.3565, 'grad_norm': 26.5, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.0}                                                                                        
[axolotl.callbacks.on_step_end:128] [PID:8168] [RANK:0] GPU memory usage while training: 29.945GB (+50.608GB cache)
{'loss': 1.3419, 'grad_norm': 29.25, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.01}                                                                                       
{'loss': 1.3337, 'grad_norm': 25.625, 'learning_rate': 7.5e-07, 'epoch': 0.01}                                                                                                    
{'loss': 1.3301, 'grad_norm': 25.875, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.02}                                                                                     
{'loss': 1.2617, 'grad_norm': 24.375, 'learning_rate': 1.25e-06, 'epoch': 0.02}                                                                                                   
{'loss': 1.2703, 'grad_norm': 22.75, 'learning_rate': 1.5e-06, 'epoch': 0.02}                                                                                                     
{'loss': 1.2243, 'grad_norm': 21.375, 'learning_rate': 1.75e-06, 'epoch': 0.03}                                                                                                   
{'loss': 1.3376, 'grad_norm': 19.0, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.03}                                                                                       
{'loss': 1.287, 'grad_norm': 15.125, 'learning_rate': 2.25e-06, 'epoch': 0.04}                                                                                                    
{'loss': 1.2013, 'grad_norm': 10.0, 'learning_rate': 2.5e-06, 'epoch': 0.04}                                                                                                      
{'loss': 1.1887, 'grad_norm': 8.375, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.04}                                                                                      
  2%|███                                                                                                                                       | 11/506 [02:10<1:36:51, 11.74s/it]

With liger:

{'loss': 1.3556, 'grad_norm': 26.375, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.0}                                                                                      
[axolotl.callbacks.on_step_end:128] [PID:8571] [RANK:0] GPU memory usage while training: 29.948GB (+24.158GB cache)
{'loss': 1.3414, 'grad_norm': 29.25, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.01}                                                                                       
{'loss': 1.334, 'grad_norm': 25.625, 'learning_rate': 7.5e-07, 'epoch': 0.01}                                                                                                     
{'loss': 1.3305, 'grad_norm': 26.375, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.02}                                                                                     
{'loss': 1.262, 'grad_norm': 24.25, 'learning_rate': 1.25e-06, 'epoch': 0.02}                                                                                                     
{'loss': 1.2709, 'grad_norm': 23.0, 'learning_rate': 1.5e-06, 'epoch': 0.02}                                                                                                      
{'loss': 1.2239, 'grad_norm': 21.25, 'learning_rate': 1.75e-06, 'epoch': 0.03}                                                                                                    
{'loss': 1.3373, 'grad_norm': 18.375, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.03}                                                                                     
{'loss': 1.2871, 'grad_norm': 14.8125, 'learning_rate': 2.25e-06, 'epoch': 0.04}                                                                                                  
{'loss': 1.2012, 'grad_norm': 10.125, 'learning_rate': 2.5e-06, 'epoch': 0.04}                                                                                                    
{'loss': 1.1885, 'grad_norm': 8.25, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.04}                                                                                       
  2%|███                                                                                                                                       | 11/506 [01:54<1:24:52, 10.29s/it]

I know nothing about Triton kernels, so I wanted to ask: are there any potential adverse consequences to this? If not, would it be sufficient to simply halve num_warps when an AMD Instinct processor is detected?
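For illustration, the detection idea in the question could look roughly like the sketch below. It is not Liger-Kernel's actual implementation, and the helper names `is_rocm_build` and `adjust_num_warps` are hypothetical. It relies on `torch.version.hip` being set on ROCm builds of PyTorch, so it detects a ROCm build rather than an Instinct processor specifically; AMD GPUs with a 32-wide wavefront would need a finer check.

```python
def is_rocm_build() -> bool:
    """Best-effort check for a ROCm (HIP) build of PyTorch."""
    try:
        import torch
    except ImportError:
        return False
    # torch.version.hip is a version string on ROCm builds and None otherwise.
    return getattr(torch.version, "hip", None) is not None


def adjust_num_warps(num_warps: int, on_amd: bool) -> int:
    """Halve num_warps on AMD (64-wide wavefronts), never going below 1."""
    if not on_amd:
        return num_warps
    return max(1, num_warps // 2)


if __name__ == "__main__":
    requested = 32  # illustrative value tuned for a 32-wide warp
    print(adjust_num_warps(requested, on_amd=is_rocm_build()))
```

The adjusted value would then be passed as the `num_warps` argument at each kernel launch in place of the hard-coded one.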

DocShotgun avatar Sep 16 '24 19:09 DocShotgun

Ha! I had set num_warps=64, but I think you're right that it should have been 16 instead (I mixed up num_warps with warp_size).

eminorhan avatar Sep 16 '24 19:09 eminorhan