Liger-Kernel

Tensor Parallel support

Open • winglian opened this issue 4 months ago • 1 comment

🐛 Describe the bug

When I enable tensor parallelism during training, I get the following new error. It is raised from Liger's RMSNorm Triton kernel, which replaces the Llama `input_layernorm`:

[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 287, in forward
[rank1]:     hidden_states = self.input_layernorm(hidden_states)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank1]:     return inner()
[rank1]:            ^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1805, in inner
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/transformers/rms_norm.py", line 33, in forward
[rank1]:     return LigerRMSNormFunction.apply(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/ops/utils.py", line 40, in wrapper
[rank1]:     return fn(ctx, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/ops/rms_norm.py", line 556, in forward
[rank1]:     Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
[rank1]:                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/ops/rms_norm.py", line 400, in rms_norm_forward
[rank1]:     _rms_norm_forward_kernel[(n_rows,)](
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/triton/runtime/jit.py", line 347, in <lambda>
[rank1]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank1]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/triton/runtime/jit.py", line 591, in run
[rank1]:     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
[rank1]:   File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 529, in __call__
[rank1]:     self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
[rank1]: RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
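The traceback ends at `_rms_norm_forward_kernel[(n_rows,)]`, which launches one Triton program per row of the flattened `hidden_states`. To make the invariants such a launch depends on concrete, here is a minimal plain-Python sketch of that per-row pattern. It assumes the usual RMSNorm formula `Y = X * rstd * (offset + W)` with `rstd = 1 / sqrt(mean(X**2) + eps)`, which matches the `offset` and `eps` arguments visible in the traceback. The function `rms_norm_rows` and its parameters are hypothetical illustrations, not Liger's kernel or API, and the sketch does not establish what goes wrong under tensor parallelism.

```python
import math
from typing import List, Sequence, Tuple


def rms_norm_rows(
    x: Sequence[float],
    n_rows: int,
    n_cols: int,
    row_stride: int,
    w: Sequence[float],
    eps: float = 1e-6,
    offset: float = 0.0,
) -> Tuple[List[float], List[float]]:
    """Hypothetical row-wise RMSNorm over a flat, row-major buffer.

    Each loop iteration plays the role of one program in a grid of
    (n_rows,). Illustrative only; this is not Liger's implementation.
    """
    # Invariants a per-row GPU kernel typically relies on. A kernel that
    # does not check them reads or writes out of bounds instead of raising.
    if len(w) != n_cols:
        raise ValueError(f"weight has {len(w)} elements but rows have {n_cols} columns")
    if row_stride < n_cols:
        raise ValueError("row_stride must be >= n_cols")
    if n_rows > 0 and (n_rows - 1) * row_stride + n_cols > len(x):
        raise IndexError("last row would run past the end of the input buffer")

    y: List[float] = []
    rstd: List[float] = []
    for row in range(n_rows):
        start = row * row_stride
        xs = x[start:start + n_cols]
        mean_sq = sum(v * v for v in xs) / n_cols
        r = 1.0 / math.sqrt(mean_sq + eps)  # reciprocal RMS for this row
        rstd.append(r)
        y.extend(v * r * (offset + wv) for v, wv in zip(xs, w))
    return y, rstd
```

In this sketch, a weight whose length differs from the row width raises `ValueError`, and a buffer too short for the requested rows raises `IndexError`. A Triton kernel indexing raw pointers performs no such checks, so a comparable shape or stride mismatch could surface as an illegal memory access like the one above.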

Reproduce

No response

Versions

Environment Report:

Operating System: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Python version: 3.11.13
Liger Kernel version: 0.6.0
PyTorch version: 2.7.1+cu126
CUDA version: 12.6
HIP(ROCm) version: Not available
Triton version: 3.3.1
Transformers version: 4.53.2
XPU version: XPU Not Available

winglian, Jul 20 '25 20:07