Liger-Kernel
torch.compile() throws exception when LigerKernel is used
🐛 Describe the bug
...
File "/home/tromero/workspace/seahorse/.venv/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 173, in triton
kernel = TritonCodeCache.load(kernel_name, source_code)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tromero/workspace/seahorse/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3112, in load
return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tromero/workspace/seahorse/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3049, in load
return cls.load_by_key_path(key, path, linemap, attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tromero/workspace/seahorse/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3062, in load_by_key_path
mod = _reload_python_module(key, path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tromero/workspace/seahorse/.venv/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_tromero/rg/crgwelnbq5utprhg6blafhwyxbaxibrfeh7n53w5xhpi5jkmp26h.py", line 82, in <module>
_CASTING_MODE_LLAMA = constexpr[0]
^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 'constexpr' is not defined
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Seems to be related to these constexpr: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rms_norm.py#L25
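For context, the declarations in question are module-level tl.constexpr values, roughly of this form (paraphrased; see the linked rms_norm.py for the exact code). When inductor re-emits the kernel source, tl.constexpr(0) gets rendered as constexpr[0], which is the undefined name in the traceback above.
import triton.language as tl

# Paraphrased from the linked rms_norm.py: casting-mode flags declared as module-level constexpr.
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)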
Reproduce
PR that provides a test that repros the bug: https://github.com/linkedin/Liger-Kernel/pull/173
Versions
Environment Report:
Operating System: Linux-6.5.0-44-generic-x86_64-with-glibc2.35
Python version: 3.10.13
PyTorch version: 2.3.0
CUDA version: 12.1
Triton version: 2.3.0
Transformers version: 4.42.3
Liger-kernel version: 0.2.1
cc @davidgonmar
I think this happens because, when the FX graph is being built (Python bytecode analysis -> FX graph), tl.constexpr params are represented with the tl.constexpr type, so that symbol gets loaded into the FX graph. Torch doesn't recognize this, so I don't think applying Liger before compiling the model with torch.compile will work, unless there's a cheap way to substitute the constexpr value directly as the param's value instead of specifying the param type.
If my understanding of the desired behaviour is correct, the following should fix the issue:
model = create_model(model_name).to(dtype).to("cuda")

if with_torch_compile:
    model = torch.compile(model)

if with_liger is True:
    kwargs = {
        "rope": True,
        "rms_norm": True,
        "cross_entropy": True,
    }
    if "gemma" in model_name:
        kwargs["geglu"] = True
    else:
        kwargs["swiglu"] = True
    MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
This simply reorders the steps so the model is compiled first and the Liger kernels are then applied to the compiled output.
As an example, here is the dumped pre-fusion graph and generated Triton code from the torch dynamo logs:
buf0: SchedulerNode(ComputedBuffer)
buf0.writes = [MemoryDep('buf0', c0, {c0: 2097152}, None)]
buf0.unmet_dependencies = []
buf0.met_dependencies =
[ MemoryDep('primals_1', c1 + 1024*tmp0, {c0: 2048, c1: 1024}, None),
MemoryDep('primals_2', c0, {c0: 2048}, None)]
buf0.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
buf0.group.device = cuda:0
buf0.group.iteration = (2097152, 1)
buf0.sizes = ([2048, 1024], [])
primals_2_layout = FixedLayout('cuda', torch.int64, size=[16, 128], stride=[128, 1])
primals_1_layout = FixedLayout('cuda', torch.bfloat16, size=[32000, 1024], stride=[1024, 1])
buf0_layout = FixedLayout('cuda', torch.bfloat16, size=[16, 128, 1024], stride=[131072, 1024, 1])
class buf0_loop_body:
    var_ranges = {z0: 2048, z1: 1024}
    index0 = z0
    index1 = 1024*indirect0 + z1
    index2 = 1024*z0 + z1
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('primals_2', get_index)
        set_indirect0 = self.set_indirect0(load)
        get_index_1 = self.get_index('index1')
        load_1 = ops.load('primals_1', get_index_1)
        get_index_2 = self.get_index('index2')
        store = ops.store('buf0', get_index_2, load_1, None)
        return store
buf0 Triton code:
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[2097152],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: '*bf16', 2: '*bf16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=82), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '521822B9F3468AF6E86806E85D57CDE414A140644FBC5A96ABF3FC1C38FC8077', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 1024)
    x0 = xindex % 1024
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x1), None, eviction_policy='evict_last')
    tmp1 = tl.full([XBLOCK], 32000, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert((0 <= tmp4) & (tmp4 < 32000), "index out of bounds: 0 <= tmp4 < 32000")
    tmp6 = tl.load(in_ptr1 + (x0 + (1024*tmp4)), None).to(tl.float32)
    tl.store(out_ptr0 + (x2), tmp6, None)
buf1: SchedulerNode(ComputedBuffer)
buf1.writes = [MemoryDep('buf1', c0, {c0: 128}, None)]
buf1.unmet_dependencies = []
buf1.met_dependencies = []
buf1.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
buf1.group.device = cuda:0
buf1.group.iteration = (128, 1)
buf1.sizes = ([128], [])
buf1_layout = FixedLayout('cuda', torch.int64, size=[128], stride=[1])
class buf1_loop_body:
    var_ranges = {z0: 128}
    index0 = z0
    def body(self, ops):
        get_index = self.get_index('index0')
        index_expr = ops.index_expr(get_index, torch.int64)
        get_index_1 = self.get_index('index0')
        store = ops.store('buf1', get_index_1, index_expr, None)
        return store
buf1 Triton code:
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[128],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=82), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '521822B9F3468AF6E86806E85D57CDE414A140644FBC5A96ABF3FC1C38FC8077', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
torch fails to recognize tl.constexpr.
5/7 tests pass with this:
=============================================================================== short test summary info ===============================================================================
FAILED test/convergence/test_mini_models.py::test_mini_model_with_torch_compile[mini_gemma1.1] - torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
FAILED test/convergence/test_mini_models.py::test_mini_model_with_torch_compile[mini_qwen2] - torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
========================================================== 2 failed, 5 passed, 14 deselected, 1 warning in 82.54s (0:01:22) ===========================================================
The two that fail are due to a torch inductor issue: it doesn't recognize the symbol 's10', which is external to Liger.
Awesome. Thanks for the insight!
Maybe it's still worthwhile to make a PR adding it as a test? Also, interestingly, when I run those two tests individually they pass. I'll have to look into why that's happening.
Ok I'll still add the test. TBH I'm surprised that patching after compiling works.
Just checked that the tests passed; I didn't really check whether the kernels were actually substituted, so there might still be something to verify there. But I think it should be okay?
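For what it's worth, one quick way to sanity-check that the kernels were actually substituted is to look at where each patched module's forward is defined. This is only a sketch: the module paths assume a HF Llama-style model, the checkpoint name is just an example, and you'd swap in whichever patch function you actually use.
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import _apply_liger_kernel_to_instance

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16
)
_apply_liger_kernel_to_instance(model)

# If patching worked, the forward of the norm / MLP modules should resolve to a
# liger_kernel.* module rather than transformers.models.llama.modeling_llama.
layer = model.model.layers[0]
for name, mod in [("input_layernorm", layer.input_layernorm), ("mlp", layer.mlp)]:
    print(name, type(mod).__name__, mod.forward.__module__)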
What's the exact status of torch.compile compatibility with Liger at the moment? I checked the ordering of operations and I'm still running into the same constexpr issues as in this issue for the Llama family of models, despite compiling prior to applying Liger (and I'm getting the same thing with Gemma models). My setup is currently a single A100-80G GPU. I think this is still related to this: https://github.com/linkedin/Liger-Kernel/blob/de12602d858a6e83aaacc56e5cb64ab218c75a0a/src/liger_kernel/ops/rms_norm.py#L39 as above.
The following is for a Llama-3.2-1B with torch.compile applied then with Liger applied after:
File "/home/ray/anaconda3/lib/python3.11/site-packages/liger_kernel/transformers/swiglu.py", line 21, in forward
LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/autograd/function.py", line 598, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/liger_kernel/ops/utils.py", line 30, in wrapper
return fn(ctx, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/liger_kernel/ops/swiglu.py", line 111, in forward
a, b, c = swiglu_forward(a, b)
^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/liger_kernel/ops/swiglu.py", line 74, in swiglu_forward
_swiglu_forward_kernel[(n_rows,)](
File "/home/ray/anaconda3/lib/python3.11/site-packages/triton/runtime/jit.py", line 167, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/triton/runtime/jit.py", line 416, in run
self.cache[device][key] = compile(
^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/triton/compiler/compiler.py", line 191, in compile
module = src.make_ir(options)
^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/triton/compiler/compiler.py", line 117, in make_ir
return ast_to_ttir(self.fn, self, options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 4:31:
def _swiglu_forward_kernel(
a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
program_id = tl.program_id(0).cast(tl.int64)
^
AttributeError("'tensor' object has no attribute 'cast'")
Reproduce
The same error is reproducible by running the following snippet on my current setup, and is triggered on the call to model.generate
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
ckpt = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
ckpt,
torch_dtype=torch.bfloat16,
)
model.to(device)
model = torch.compile(model)
from liger_kernel.transformers import _apply_liger_kernel_to_instance
_apply_liger_kernel_to_instance(model)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
print(inputs)
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
Environment Report
Operating System: Ubuntu 20.04.6 LTS
Python version: 3.11.9
PyTorch version: 2.3.1
CUDA version: 12.1
Triton version: 2.3.1
Transformers version: 4.43.2
Liger-kernel version 0.3.1
There are still some caveats. We are on it
I have the same issue even without torch.compile. Should I maybe open a separate issue?
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda"
ckpt = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
ckpt,
torch_dtype=torch.bfloat16,
)
model.to(device)
# model = torch.compile(model)
from liger_kernel.transformers import _apply_liger_kernel_to_instance
_apply_liger_kernel_to_instance(model)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
print(inputs)
outputs = model.generate(**inputs, do_sample=False)
response = tokenizer.batch_decode(outputs)[0]
Environment Report:
Operating System: Linux-5.14.0-70.30.1.el9_0.x86_64-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.0a0+3bcc3cddb5.nv24.07
CUDA version: 12.5
Triton version: 3.0.0
Transformers version: 4.45.1
I'm using the NGC PyTorch image with triton installed already
pytorch-triton 3.0.0+989adb9a2
torch 2.4.0a0+3bcc3cddb5.nv24.7
I installed liger_kernel with --no-build-isolation and --no-deps since the triton package name is different:
pip install --no-deps --no-build-isolation git+https://github.com/linkedin/[email protected]
Full trace:
Traceback (most recent call last):
File "/scratch/moalla/alignment-as-translation/dev/tests/bug_liger.py", line 21, in <module>
outputs = model.generate(**inputs, do_sample=False)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2048, in generate
result = self._sample(
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 3008, in _sample
outputs = self(**model_inputs, return_dict=True)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/liger_kernel/transformers/model/llama.py", line 82, in lce_forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1000, in forward
layer_outputs = decoder_layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 745, in forward
hidden_states = self.mlp(hidden_states)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/liger_kernel/transformers/swiglu.py", line 21, in forward
LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 573, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/liger_kernel/ops/utils.py", line 30, in wrapper
return fn(ctx, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/liger_kernel/ops/swiglu.py", line 111, in forward
a, b, c = swiglu_forward(a, b)
File "/usr/local/lib/python3.10/dist-packages/liger_kernel/ops/swiglu.py", line 74, in swiglu_forward
_swiglu_forward_kernel[(n_rows,)](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 180, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 401, in run
self.cache[device][key] = compile(
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 263, in compile
module = src.make_ir(options, context)
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 108, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options)
triton.compiler.errors.CompilationError: at 4:17:
def _swiglu_forward_kernel(
a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
program_id = tl.program_id(0).cast(tl.int64)
^
AttributeError("'tensor' object has no attribute 'cast'")
Hmm, yeah, good call actually. It seems that specific error comes from some incompatibility between Torch 2.3.1 and Liger 0.3.1; I realized it goes away if you upgrade to Torch 2.4.1. However, with Torch 2.4.1 the earlier torch.compile + Liger issues remain regardless of the ordering (compile before or after _apply_liger_kernel_to_instance).
File "/tmp/torchinductor_ray/yj/cyjrzusp5sr4lkc5w5qemtq5vcepmync6kc3bvwyilj4sniykftq.py", line 31, in <module>
_rms_norm_forward_kernel_0 = async_compile.triton('_rms_norm_forward_kernel', '''
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 173, in triton
kernel = TritonCodeCache.load(kernel_name, source_code)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3123, in load
return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3060, in load
return cls.load_by_key_path(key, path, linemap, attrs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3073, in load_by_key_path
mod = _reload_python_module(key, path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_ray/wh/cwhzznnjk7pijfx2gtwkmxcuwwcflwon2kuh527use5nhkrh4r5c.py", line 82, in <module>
_CASTING_MODE_LLAMA = constexpr[0]
^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 'constexpr' is not defined
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Reproduce
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from liger_kernel.transformers import _apply_liger_kernel_to_instance
device = "cuda"
ckpt = "meta-llama/Llama-3.2-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
ckpt,
torch_dtype=torch.bfloat16,
)
model.to(device)
model = torch.compile(model)
_apply_liger_kernel_to_instance(model)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
print(inputs)
outputs = model.forward(**inputs)
Environment Report
Operating System: Ubuntu 20.04.6 LTS
Python version: 3.11.9
PyTorch version: 2.4.1
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.43.2
Liger-kernel version 0.3.1
Currently using
torch: 2.4.1+cu24 liger-kernel: 0.3.1 triton: 3.0.0
I'm using this in a DDP setup with torch.compile. The code and model work (training + inference), but the torch dynamo stack trace is beyond annoying. I'll attach mine here if it helps anyone on the dev side!
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] exec(code, mod.__dict__, mod.__dict__)
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] File "/tmp/torchinductor_root/gh/cghzaah6fo2pnkla3kbhkyfukp2m6qdqpzxn5xnl62vjigkhzg3b.py", line 71, in <module>
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] async_compile.wait(globals())
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 247, in wait
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] scope[key] = result.result()
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] ^^^^^^^^^^^^^^^
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3432, in result
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] result = self.future.result()
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] ^^^^^^^^^^^^^^^^^^^^
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] File "/opt/conda/lib/python3.11/concurrent/futures/_base.py", line 456, in result
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] return self.__get_result()
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] ^^^^^^^^^^^^^^^^^^^
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] File "/opt/conda/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] raise self._exception
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] CompilationError: at 6:11:
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] def triton_(out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] xnumel = 1
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] xoffset = tl.program_id(0) * XBLOCK
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] xindex = xoffset + tl.arange(0, XBLOCK)[:]
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] xmask = xindex < xnumel
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] tmp0 = (256 / ks0) ** 1.00787401574803
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] ^
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] AttributeError("'tensor' object has no attribute '__pow__'")
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009]
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009] Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
W1024 01:50:24.288000 130062018127680 torch/_dynamo/convert_frame.py:1009]
So even though there's a supposed error, the entire model works (no clue how), but it would be amazing if this could get addressed!
Also, I have added this at the top level of my training and inference scripts:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision('high')
I am seeing the same issue, stacktrace:
[06:08:30.735]: File "/opt/conda/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 361, in train
[06:08:30.735]: output = super().train(*args, **kwargs)
[06:08:30.735]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.735]: File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2052, in train
[06:08:30.735]: return inner_training_loop(
[06:08:30.735]: ^^^^^^^^^^^^^^^^^^^^
[06:08:30.735]: File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
[06:08:30.736]: tr_loss_step = self.training_step(model, inputs)
[06:08:30.736]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.736]: File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3485, in training_step
[06:08:30.736]: loss = self.compute_loss(model, inputs)
[06:08:30.736]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.736]: File "/host_home/research/code/research_mono/research_mono/math_ai/trainers/trl_sft.py", line 133, in compute_loss
[06:08:30.736]: loss, outputs = super().compute_loss(model=model, inputs=inputs, return_outputs=True)
[06:08:30.736]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.736]: File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3532, in compute_loss
[06:08:30.736]: outputs = model(**inputs)
[06:08:30.736]: ^^^^^^^^^^^^^^^
[06:08:30.736]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[06:08:30.736]: return self._call_impl(*args, **kwargs)
[06:08:30.736]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.736]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[06:08:30.737]: return forward_call(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
[06:08:30.737]: return fn(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[06:08:30.737]: return self._call_impl(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[06:08:30.737]: return forward_call(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/operations.py", line 820, in forward
[06:08:30.737]: return model_forward(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/operations.py", line 808, in __call__
[06:08:30.737]: return convert_to_fp32(self.model_forward(*args, **kwargs))
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[06:08:30.737]: return func(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/liger_kernel/transformers/model/llama.py", line 21, in lce_forward
[06:08:30.737]: @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[06:08:30.737]: return self._call_impl(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[06:08:30.737]: return forward_call(*args, **kwargs)
[06:08:30.737]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 913, in forward
[06:08:30.737]: @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1000, in torch_dynamo_resume_in_forward_at_970
[06:08:30.737]: layer_outputs = decoder_layer(
[06:08:30.737]: ^^^^^^^^^^^^^^
[06:08:30.737]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[06:08:30.738]: return self._call_impl(*args, **kwargs)
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[06:08:30.738]: return forward_call(*args, **kwargs)
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 690, in forward
[06:08:30.738]: def forward(
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[06:08:30.738]: return self._call_impl(*args, **kwargs)
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[06:08:30.738]: return forward_call(*args, **kwargs)
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/liger_kernel/transformers/rms_norm.py", line 25, in forward
[06:08:30.738]: def forward(self, hidden_states):
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
[06:08:30.738]: return super().apply(*args, **kwargs) # type: ignore[misc]
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
[06:08:30.738]: return self._torchdynamo_orig_callable(
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
[06:08:30.738]: result = self._inner_convert(
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
[06:08:30.738]: return _compile(
[06:08:30.738]: ^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
[06:08:30.738]: return StrobelightCompileTimeProfiler.profile_compile_time(
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
[06:08:30.738]: return func(*args, **kwargs)
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[06:08:30.738]: return func(*args, **kwds)
[06:08:30.738]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.738]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
[06:08:30.739]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[06:08:30.739]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[06:08:30.739]: r = func(*args, **kwargs)
[06:08:30.739]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
[06:08:30.739]: out_code = transform_code_object(code, transform)
[06:08:30.739]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
[06:08:30.739]: transformations(instructions, code_options)
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
[06:08:30.739]: return fn(*args, **kwargs)
[06:08:30.739]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
[06:08:30.739]: tracer.run()
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
[06:08:30.739]: super().run()
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
[06:08:30.739]: while self.step():
[06:08:30.739]: ^^^^^^^^^^^
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
[06:08:30.739]: self.dispatch_table[inst.opcode](self, inst)
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2642, in RETURN_VALUE
[06:08:30.739]: self._return(inst)
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2627, in _return
[06:08:30.739]: self.output.compile_subgraph(
[06:08:30.739]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1123, in compile_subgraph
[06:08:30.740]: self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
[06:08:30.740]: File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[06:08:30.740]: return func(*args, **kwds)
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1318, in compile_and_call_fx_graph
[06:08:30.740]: compiled_fn = self.call_user_compiler(gm)
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[06:08:30.740]: r = func(*args, **kwargs)
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1409, in call_user_compiler
[06:08:30.740]: raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1390, in call_user_compiler
[06:08:30.740]: compiled_fn = compiler_fn(gm, self.example_inputs())
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
[06:08:30.740]: compiled_gm = compiler_fn(gm, example_inputs)
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/__init__.py", line 1951, in __call__
[06:08:30.740]: return compile_fx(model_, inputs_, config_patches=self.config)
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[06:08:30.740]: return func(*args, **kwds)
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1505, in compile_fx
[06:08:30.740]: return aot_autograd(
[06:08:30.740]: ^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
[06:08:30.740]: cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 954, in aot_module_simplified
[06:08:30.740]: compiled_fn, _ = create_aot_dispatcher_function(
[06:08:30.740]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.740]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[06:08:30.741]: r = func(*args, **kwargs)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
[06:08:30.741]: compiled_fn, fw_metadata = compiler_fn(
[06:08:30.741]: ^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 168, in aot_dispatch_base
[06:08:30.741]: compiled_fw = compiler(fw_module, updated_flat_args)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[06:08:30.741]: r = func(*args, **kwargs)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1410, in fw_compiler_base
[06:08:30.741]: return inner_compile(
[06:08:30.741]: ^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
[06:08:30.741]: inner_compiled_fn = compiler_fn(gm, example_inputs)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/debug.py", line 304, in inner
[06:08:30.741]: return fn(*args, **kwargs)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[06:08:30.741]: return func(*args, **kwds)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[06:08:30.741]: return func(*args, **kwds)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[06:08:30.741]: r = func(*args, **kwargs)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 527, in compile_fx_inner
[06:08:30.741]: compiled_graph = fx_codegen_and_compile(
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
[06:08:30.741]: return func(*args, **kwds)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 831, in fx_codegen_and_compile
[06:08:30.741]: compiled_fn = graph.compile_to_fn()
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1749, in compile_to_fn
[06:08:30.741]: return self.compile_to_module().call
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[06:08:30.741]: r = func(*args, **kwargs)
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1699, in compile_to_module
[06:08:30.741]: mod = PyCodeCache.load_by_key_path(
[06:08:30.741]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.741]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3062, in load_by_key_path
[06:08:30.742]: mod = _reload_python_module(key, path)
[06:08:30.742]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.742]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
[06:08:30.742]: exec(code, mod.__dict__, mod.__dict__)
[06:08:30.742]: File "/tmp/torchinductor_jovyan/u2/cu2zxzdpkztlotnmkuornqe3gmuke2rh2nix7vvextvdnzcwghda.py", line 67, in <module>
[06:08:30.742]: _rms_norm_forward_kernel_0 = async_compile.triton('_rms_norm_forward_kernel', '''
[06:08:30.742]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.742]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 173, in triton
[06:08:30.742]: kernel = TritonCodeCache.load(kernel_name, source_code)
[06:08:30.742]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.742]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3112, in load
[06:08:30.742]: return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
[06:08:30.742]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.742]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3049, in load
[06:08:30.742]: return cls.load_by_key_path(key, path, linemap, attrs)
[06:08:30.742]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.742]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3062, in load_by_key_path
[06:08:30.742]: mod = _reload_python_module(key, path)
[06:08:30.743]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[06:08:30.743]: File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
[06:08:30.743]: exec(code, mod.__dict__, mod.__dict__)
[06:08:30.743]: File "/tmp/torchinductor_jovyan/lv/clveyjbklq3ogkc2uykxbw2p4qn4r6bwsoirxoknduuocs76ppwv.py", line 82, in <module>
[06:08:30.743]: _CASTING_MODE_LLAMA = constexpr[0]
[06:08:30.743]: ^^^^^^^^^
[06:08:30.743]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[06:08:30.743]: NameError: name 'constexpr' is not defined
[06:08:30.743]:
[06:08:30.743]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[06:08:30.743]:
[06:08:30.743]:
[06:08:30.743]: You can suppress this exception and fall back to eager by setting:
[06:08:30.743]: import torch._dynamo
[06:08:30.743]: torch._dynamo.config.suppress_errors = True
So, has anybody had success doing this? The README says:
Since Liger Kernel is 100% Triton-based, it works seamlessly with torch.compile.
So I guess this should be supported :(
Thanks for any suggestions!
Thanks for the feedback. We realize that we still have some work to do to integrate torch.compile. I will modify the description in the README to avoid confusion.
Would you happen to have an ETA by any chance? I do not mind installing the beta versions and trying it out when you have them
We are currently working with the torch team. For now, I think you can use either torch.compile or Liger Kernel on its own to unblock yourself. The benefit of Liger is larger memory reduction; if you only care about throughput, the two achieve similar perf.
Awesome. Yes, it's a weird scenario: if you add any Liger components and also do torch.compile, it does train, but it all goes to hell during validation. My goal was to get the speed from torch.compile while getting the VRAM benefits of Liger, to be honest. Being able to shave off many GB of VRAM is absolutely impressive when training.
+1 Achieving small memory + fast speed would be super great
@zachNA2 @fzyzcjy you can try to do this:
- torch.compile only the model itself (the transformer backbone), but not the entire CausalLM module.
- Modify the CausalLM's forward by importing our fused linear cross entropy to fuse the linear layer and CE together.
Our FLCE kernel is the key to reducing memory so much. This way you can get the benefit of both torch.compile and our FLCE, but it needs a bit of manual work; see the sketch below.
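A rough sketch of that approach, assuming a HF Llama-style CausalLM and the LigerFusedLinearCrossEntropyLoss call convention of (lm_head_weight, hidden_states, labels); both the module paths and the call convention should be double-checked against the current code, so treat this as an illustration rather than an official recipe.
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16
).to("cuda")

# Compile only the transformer backbone; keep lm_head + loss out of the compiled region.
model.model = torch.compile(model.model)
flce = LigerFusedLinearCrossEntropyLoss()

def compute_loss(input_ids, attention_mask, labels):
    hidden = model.model(input_ids=input_ids, attention_mask=attention_mask)[0]
    # Shift so tokens predict the next token, as in the standard HF causal LM loss.
    shift_hidden = hidden[..., :-1, :].contiguous().view(-1, hidden.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1)
    # FLCE fuses the lm_head matmul with cross entropy, so the full-vocab logits are never materialized.
    return flce(model.lm_head.weight, shift_hidden, shift_labels)
The point is that the compiled region never sees Liger's Triton-backed autograd functions, while the uncompiled loss path still gets FLCE's memory savings.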
Thank you! I may have a try, but still looking forward to the fix.
Or would it be possible to add an argument to the patching function that lets users control which kernels get patched? For example, maybe we could enable only the CE module and disable all the others.
You can configure whether to enable a given layer via https://github.com/linkedin/Liger-Kernel/blob/9ad8f89373b2206e86e9bb1cdc6e63c37275bd81/src/liger_kernel/transformers/monkey_patch.py#L52. However, I think HF will still attempt to compile the full CausalLM, which might crash because I suspect it cannot compile FLCE. You can try to call torch.compile yourself instead of using the HF argument.
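For example, something along these lines (flag names assumed from the linked monkey_patch.py; verify against the signature in your installed version):
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch only the cross entropy kernel and leave the other modules untouched,
# so torch.compile only ever traces vanilla HF code for them.
apply_liger_kernel_to_llama(
    rope=False,
    rms_norm=False,
    swiglu=False,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
)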
@ByronHsu Yes, I mean the HF one; currently there seems to be only a bool flag. Maybe we could add something like liger_kernel_kwargs: Dict to TrainingArguments.
I had the same problem of wanting to use Liger kernels with torch.compile. I followed the flash-attn repo / torch docs to register the Liger implementations of RMSNorm and SwiGLU (I needed those two) with torch.library.custom_op and torch.library.register_fake.
One also needs to refactor the wrapper functions a little so that they don't return the unchanged input pointer (return input.clone() in this case), as in-place operations on gradients don't seem to be supported for the backward pass.
After that, compile(fullgraph=True) seems to work, but this requires torch 2.5. A rough sketch of the registration is shown below.
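A minimal, forward-only sketch of that registration for RMSNorm. The op name, the wrapper, and the LigerRMSNormFunction argument order are my assumptions rather than Liger's API, and a real integration also needs a backward registered (e.g. via torch.library.register_autograd).
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

@torch.library.custom_op("liger::rms_norm", mutates_args=())
def liger_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Run Liger's Triton kernel; clone so the op never returns an alias of its input,
    # which is the in-place issue mentioned above. Extra args such as offset/casting_mode
    # are omitted here -- check the actual LigerRMSNormFunction signature.
    return LigerRMSNormFunction.apply(x, weight, eps).clone()

@liger_rms_norm.register_fake
def _(x, weight, eps):
    # Fake (meta) implementation so dynamo/inductor can infer shapes and dtypes
    # without tracing into the Triton kernel.
    return torch.empty_like(x)
With the op registered, the model's RMSNorm forward can call torch.ops.liger.rms_norm, which dynamo treats as an opaque call instead of tracing the Triton source.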
I was looking into torchtune and found
def compile_model(
    model: Union[TransformerDecoder, DeepFusionModel],
    verbose: bool = True,
) -> None:
    """
    Utility to compile a transformer model inplace. On PyTorch nightlies we use per-layer compile
    to reduce compile times. Otherwise we compile the full model, which takes longer.

    Args:
        model (Union[TransformerDecoder, DeepFusionModel]): A model to compile.
            Can be a TransformerDecoder or DeepFusionModel; in the latter case only
            the model's decoder will be compiled.
        verbose (bool): Whether to log compile info. Default: True

    Returns:
        None
    """
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if isinstance(model, DeepFusionModel):
        model = model.decoder
    # Per-layer compilation by default
    if verbose:
        log.info("Compiling model layers with torch.compile...")
    for m in reversed(list(model.modules())):
        if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
            m, TransformerCrossAttentionLayer
        ):
            m.compile(backend=backend)
It looks like they compile only the attention layers in the decoder part of the model (and not the vision part, if it's a vision-language model). They also compile the chunked CE, but I think applying the same approach to FLCE in Liger-Kernel would have to be a bit more complicated.
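Borrowing that per-layer idea for a HF model patched with Liger might look roughly like this. The module paths assume a Llama-style architecture, and whether this fully sidesteps the constexpr issue depends on which submodules Liger has patched, so treat it as a sketch rather than a verified fix.
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import _apply_liger_kernel_to_instance

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16
).to("cuda")
_apply_liger_kernel_to_instance(model)

# Mirror torchtune's strategy: compile per submodule instead of the whole CausalLM.
# Here only the (unpatched) attention blocks are compiled in place, so dynamo never
# traces into Liger's Triton-backed modules.
for layer in model.model.layers:
    layer.self_attn.compile(backend="inductor")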