Tensor Parallel support
🐛 Describe the bug
When training with Tensor Parallelism enabled, I get the following error (a hedged sketch of this kind of setup follows the trace):
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 287, in forward
[rank1]: hidden_states = self.input_layernorm(hidden_states)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1805, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/transformers/rms_norm.py", line 33, in forward
[rank1]: return LigerRMSNormFunction.apply(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank1]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/ops/utils.py", line 40, in wrapper
[rank1]: return fn(ctx, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/ops/rms_norm.py", line 556, in forward
[rank1]: Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/liger_kernel/ops/rms_norm.py", line 400, in rms_norm_forward
[rank1]: _rms_norm_forward_kernel[(n_rows,)](
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/triton/runtime/jit.py", line 347, in <lambda>
[rank1]: return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/triton/runtime/jit.py", line 591, in run
[rank1]: kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 529, in __call__
[rank1]: self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
[rank1]: RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
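No reproduction script was attached, so the following is only a minimal, hedged sketch of the kind of setup the trace suggests: Liger's Triton kernels patched into a Hugging Face Llama model that is then sharded with torch.distributed's tensor-parallel API. The checkpoint name, sharding plan, and launch command are assumptions, not the reporter's actual configuration.

```python
# Hedged sketch only; model, plan, and launch details are assumed.
# Launch with: torchrun --nproc-per-node=2 repro.py
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

# Monkey-patch HF Llama with Liger's Triton kernels (LigerRMSNorm included)
# before the model is instantiated.
apply_liger_kernel_to_llama(rms_norm=True)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # assumed checkpoint; any HF Llama works
    torch_dtype=torch.bfloat16,
).cuda()

# 2-way tensor parallelism over each decoder layer's projections.
# (A complete repro would also adjust per-rank attention head counts;
# that is omitted here since the crash occurs earlier, in input_layernorm.)
mesh = init_device_mesh("cuda", (2,))
plan = {
    "self_attn.q_proj": ColwiseParallel(),
    "self_attn.k_proj": ColwiseParallel(),
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(),
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(),
}
for layer in model.model.layers:
    parallelize_module(layer, mesh, plan)

# A training step now routes activations through LigerRMSNormFunction,
# which is where the illegal memory access in the trace surfaces.
input_ids = torch.randint(0, model.config.vocab_size, (1, 128), device="cuda")
model(input_ids, labels=input_ids).loss.backward()
```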
Reproduce
No response
Versions
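The report below comes from Liger's bundled environment script (python -m liger_kernel.env_report in current releases):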
Environment Report:
Operating System: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Python version: 3.11.13
Liger Kernel version: 0.6.0
PyTorch version: 2.7.1+cu126
CUDA version: 12.6
HIP(ROCm) version: Not available
Triton version: 3.3.1
Transformers version: 4.53.2
XPU version: XPU Not Available