AssertionError("cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor")
I have a simple example:
import torch
import triton
import triton.language as tl

M = N = K = n = 512
dev = 'cuda'

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_N': 256}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=8),
    ],
    key=['N'],
)
@triton.jit
def mm1_kernel(a_ptr, b_ptr, c_ptr, N,
               BLOCK_SIZE_N: tl.constexpr):
    mid = tl.program_id(0)
    nid = tl.program_id(1)
    # Starting row + BLOCK_SIZE_N more rows
    a_rows = mid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Starting col + BLOCK_SIZE_N more columns
    b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_N)[None, :]
    b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_N)[:, None] * N + b_cols[None, :]
    c = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_N], dtype=tl.float32)
    for k in range(K // BLOCK_SIZE_N):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        c += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_N
        b_ptrs += BLOCK_SIZE_N * N
    c = c.to(tl.float16)
    # C's block's offsets
    c_ptrs = a_rows[:, None] * N + b_cols[None, :]
    tl.store(c_ptr + c_ptrs, c)

def mm1(a, b):
    c = torch.empty([M, N], device=a.device, dtype=a.dtype)
    grid = lambda META: (
        triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    mm1_kernel[grid](a, b, c, N)
    return c

torch.manual_seed(0)
a = torch.randn((n, n), device='cuda', dtype=torch.float16)
b = torch.randn((n, n), device='cuda', dtype=torch.float16)
triton_output = mm1(a, b)
However, running it fails with:
Traceback (most recent call last):
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
generator.visit(fn.parse())
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
ret = super().visit(node)
^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 418, in visit
return visitor(node)
^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 426, in generic_visit
self.visit(item)
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
ret = super().visit(node)
^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 418, in visit
return visitor(node)
^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
self.visit_compound_statement(node.body)
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
ret_type = self.visit(stmt)
^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
ret = super().visit(node)
^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 418, in visit
return visitor(node)
^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 855, in visit_For
ub = language.core._to_tensor(ub, self.builder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/language/core.py", line 70, in _to_tensor
assert False, f"cannot convert {x} of type {type(x)} to tensor"
AssertionError: cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/linj/matrix/opt-gemm/tilingtriton_strassen.py", line 66, in <module>
triton_output = mm1(a,b)
^^^^^^^^
File "/home/linj/matrix/opt-gemm/tilingtriton_strassen.py", line 51, in mm1
mm1_kernel[grid](a, b, c,N)
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 167, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 122, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/testing.py", line 102, in do_bench
fn()
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call
self.fn.run(
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 416, in run
self.cache[device][key] = compile(
^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 191, in compile
module = src.make_ir(options)
^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 117, in make_ir
return ast_to_ttir(self.fn, self, options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/errors.py", line 29, in __init__
self.message = self._format_message()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/errors.py", line 13, in _format_message
source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
^^^^^^^^^^^
AttributeError: 'Constant' object has no attribute 'lineno'
A similar failure is also reported in #2211.
I am using Triton 2.3.0, torch 2.2.2, Python 3.11.8, and CUDA 12.1.
cc @bertmaher this might be a nice starter bug for someone
hey @LukeLIN-web, re: AttributeError: 'Constant' object has no attribute 'lineno'. It's already fixed in mainline by https://github.com/openai/triton/pull/3201. I don't know about the original assert: AssertionError: cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor.
I still have this problem: AttributeError: 'Constant' object has no attribute 'lineno'. I am using Triton for the first time, and I don't know why it raises AssertionError: cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor.
With which version? It was not patched in the Triton v2.3 release branch, from which the binaries published to PyPI were built. Mainline is tracking the future v3.0 release, and that is what my prior comment was referring to. If you want to try a Triton build that has resolved it, see the project's GitHub README for how to install the Triton nightly build. Nightly builds for v3.0 are not available on PyPI; you can only get the official v2.2 and v2.3 releases from PyPI.
Thank you for your reply. I installed the nightly Triton build, but it still shows:
Traceback (most recent call last):
File "/home/linj/matrix/try.py", line 56, in <module>
triton_output = mm1(a,b)
^^^^^^^^
File "/home/linj/matrix/try.py", line 50, in mm1
mm1_kernel[grid](a, b, c,N)
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 209, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 120, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/testing.py", line 103, in do_bench
fn()
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 105, in kernel_call
self.fn.run(
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 526, in run
kernel = self.compile(
^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 272, in compile
module = src.make_ir(options, codegen_fns, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 112, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 14:4:
mid = tl.program_id(0)
nid = tl.program_id(1)
# Starting row + BLOCK_SIZE_N more rows
a_rows = mid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
# Starting col + BLOCK_SIZE_N more columns
b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_N)[None, :]
b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_N)[:, None] * N + b_cols[None, :]
c = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_N], dtype=tl.float32)
for k in range(K//BLOCK_SIZE_N):
^
AssertionError("cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor")
cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor

This is due to K not being a tl.constexpr. If you pass it in as an argument (whether as a constexpr or a tensor), the example succeeds. The compiler is essentially evaluating a.__floordiv__(b); if a is a plain int, it does not know how to divide by a constexpr, hence the NotImplementedType.
The error message for not-implemented operations should definitely be better, but I'm wondering if we should auto-promote raw ints to constexprs / tensors here. I also noticed that if we reverse the order of the arguments in a // b, i.e. a is a constexpr but b is a raw int, the compilation succeeds (a // b gets evaluated to another constexpr). Maybe that's fine, not sure...
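For anyone following along, the NotImplemented mechanics can be reproduced in plain Python. A minimal sketch, where FakeConstexpr is a made-up stand-in and not Triton's actual constexpr class:

# int.__floordiv__ returns the NotImplemented sentinel (rather than raising)
# when it does not recognize the right-hand operand's type.
class FakeConstexpr:
    def __init__(self, value):
        self.value = value

k = 512                       # a plain Python int, like the module-level K
block = FakeConstexpr(128)

print(k.__floordiv__(block))  # prints: NotImplemented

# The // operator would normally fall back to block.__rfloordiv__ and raise
# TypeError if that hook is missing. Code that calls __floordiv__ directly
# instead receives NotImplemented and only fails later, here when the code
# generator tries to convert the loop bound to a tensor.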
@Jokeren @fkouteib do you have input on what the right behavior is for raw ints here? Should we raise an error when encountering them?