
AssertionError("cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor")

Open LukeLIN-web opened this issue 10 months ago • 7 comments

I have a simple example:

import torch
import triton
import triton.language as tl

M = N = K = n = 512
dev = 'cuda'

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_N': 256}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=8),
    ],
    key=['N'],
)

@triton.jit
def mm1_kernel(a_ptr, b_ptr, c_ptr, N, 
            BLOCK_SIZE_N: tl.constexpr):
    mid = tl.program_id(0)
    nid = tl.program_id(1)
    # Starting row + BLOCK_SIZE_N more rows
    a_rows = mid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Starting col + BLOCK_SIZE_N more columns
    b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_N)[None, :]
    b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_N)[:, None] * N + b_cols[None, :]

    c = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_N], dtype=tl.float32)
    for k in range(K//BLOCK_SIZE_N):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        c += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_N
        b_ptrs += BLOCK_SIZE_N * N

    c = c.to(tl.float16)

    # Offsets of this block of C
    c_ptrs = a_rows[:, None] * N + b_cols[None, :]
    tl.store(c_ptr + c_ptrs, c)


def mm1(a, b):
    c = torch.empty([M, N], device=a.device, dtype=a.dtype)
    grid = lambda META: (
        triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    mm1_kernel[grid](a, b, c, N)
    return c

torch.manual_seed(0)
a = torch.randn((n, n), device='cuda', dtype=torch.float16)
b = torch.randn((n, n), device='cuda', dtype=torch.float16)
triton_output = mm1(a,b)

However, it shows:

Traceback (most recent call last):
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
               ^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 855, in visit_For
    ub = language.core._to_tensor(ub, self.builder)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/language/core.py", line 70, in _to_tensor
    assert False, f"cannot convert {x} of type {type(x)} to tensor"
AssertionError: cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/linj/matrix/opt-gemm/tilingtriton_strassen.py", line 66, in <module>
    triton_output = mm1(a,b)
                    ^^^^^^^^
  File "/home/linj/matrix/opt-gemm/tilingtriton_strassen.py", line 51, in mm1
    mm1_kernel[grid](a, b, c,N)
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
                              ^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 191, in compile
    module = src.make_ir(options)
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 117, in make_ir
    return ast_to_ttir(self.fn, self, options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/errors.py", line 29, in __init__
    self.message = self._format_message()
                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/errors.py", line 13, in _format_message
    source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:]
                                           ^^^^^^^^^^^
AttributeError: 'Constant' object has no attribute 'lineno'

Something similar also happens in #2211.

I am using Triton 2.3.0, torch 2.2.2, Python 3.11.8, and CUDA 12.1.

LukeLIN-web, Apr 16 '24

cc @bertmaher this might be a nice starter bug for someone

jlebar, Apr 16 '24

hey @LukeLIN-web, re:

AttributeError: 'Constant' object has no attribute 'lineno'.

It's already fixed in mainline by https://github.com/openai/triton/pull/3201.

I don't know about the original assert.

AssertionError: cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor

fkouteib, Apr 17 '24

I still have this problem: AttributeError: 'Constant' object has no attribute 'lineno'. I am using Triton for the first time, and I don't understand why it raises AssertionError: cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor.

LukeLIN-web, Apr 18 '24

I still have this problem: AttributeError: 'Constant' object has no attribute 'lineno'.

With which version? It was not patched in the Triton v2.3 release branch, which is what the binaries published to PyPI were built from. Mainline is tracking the future v3.0 release, and that is what my prior comment refers to. If you want to try a Triton build that has resolved it, see the project's GitHub README for how to install the Triton nightly build. Nightly builds for v3.0 are not available on PyPI; you can only get the official v2.2 and v2.3 releases from PyPI.

fkouteib, Apr 18 '24

Thank you for your reply. I installed the nightly Triton, but it still shows:

Traceback (most recent call last):
  File "/home/linj/matrix/try.py", line 56, in <module>
    triton_output = mm1(a,b)
                    ^^^^^^^^
  File "/home/linj/matrix/try.py", line 50, in mm1
    mm1_kernel[grid](a, b, c,N)
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 209, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 120, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/testing.py", line 103, in do_bench
    fn()
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 105, in kernel_call
    self.fn.run(
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/runtime/jit.py", line 526, in run
    kernel = self.compile(
             ^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 272, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linj/miniconda3/envs/condaexample/lib/python3.11/site-packages/triton/compiler/compiler.py", line 112, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 14:4:
    mid = tl.program_id(0)
    nid = tl.program_id(1)
    # Starting row + BLOCK_SIZE_N more rows
    a_rows = mid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Starting col + BLOCK_SIZE_N more columns
    b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_N)[None, :]
    b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_N)[:, None] * N + b_cols[None, :]

    c = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_N], dtype=tl.float32)
    for k in range(K//BLOCK_SIZE_N):
    ^
AssertionError("cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor")

LukeLIN-web, Apr 20 '24

cannot convert NotImplemented of type <class 'NotImplementedType'> to tensor

This is due to K not being a tl.constexpr. If you pass it in as an argument (whether as a constexpr or a tensor), the example succeeds. The compiler is essentially evaluating a.__floordiv__(b); if a is a plain int, it does not know how to divide by a constexpr, hence the NotImplementedType.
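
As an illustration of that suggestion, here is a sketch of the reproducer's kernel with K passed in as a tl.constexpr argument, reusing the imports and globals from the top of the thread. This is a sketch based on this comment, not a tested patch:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_N': 256}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=8),
    ],
    key=['N'],
)
@triton.jit
def mm1_kernel(a_ptr, b_ptr, c_ptr, N, K: tl.constexpr,
               BLOCK_SIZE_N: tl.constexpr):
    mid = tl.program_id(0)
    nid = tl.program_id(1)
    a_rows = mid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    b_cols = nid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    a_ptrs = a_ptr + a_rows[:, None] * K + tl.arange(0, BLOCK_SIZE_N)[None, :]
    b_ptrs = b_ptr + tl.arange(0, BLOCK_SIZE_N)[:, None] * N + b_cols[None, :]
    c = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_N], dtype=tl.float32)
    # K is now a kernel argument rather than a captured Python global,
    # so K // BLOCK_SIZE_N is well defined for the compiler.
    for k in range(K // BLOCK_SIZE_N):
        c += tl.dot(tl.load(a_ptrs), tl.load(b_ptrs))
        a_ptrs += BLOCK_SIZE_N
        b_ptrs += BLOCK_SIZE_N * N
    c_ptrs = a_rows[:, None] * N + b_cols[None, :]
    tl.store(c_ptr + c_ptrs, c.to(tl.float16))

# The call site then passes K explicitly:
# mm1_kernel[grid](a, b, c, N, K)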

The error message for not-implemented operations should definitely be better, but I'm wondering if we should auto-promote raw ints to constexprs / tensors here. I also noticed that if we reverse the order of the arguments, i.e. if in a // b we have a as a constexpr and b as a raw int, the compilation succeeds (a // b gets evaluated to another constexpr). Maybe that's fine, not sure...
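
To make the asymmetry concrete, here is a short plain-Python illustration of the dispatch, using a hypothetical stand-in class rather than Triton's actual tl.constexpr: an explicit a.__floordiv__(b) call returns the NotImplemented sentinel when a is a raw int and b is a type it does not recognize, while the reversed order succeeds because the left-hand object handles the division itself.

# 'Constexpr' is a hypothetical stand-in for tl.constexpr, for illustration only.
class Constexpr:
    def __init__(self, value):
        self.value = value

    def __floordiv__(self, other):
        # A constexpr-like object knows how to divide by raw ints (and peers).
        other = other.value if isinstance(other, Constexpr) else other
        return Constexpr(self.value // other)

K_raw, K_const = 512, Constexpr(512)
block = Constexpr(128)

# Raw int on the left: int.__floordiv__ does not recognize the right-hand type,
# so the explicit dunder call returns the NotImplemented sentinel as a value.
print(K_raw.__floordiv__(block))        # NotImplemented
# Constexpr on the left: succeeds and yields another constexpr-like value.
print(K_const.__floordiv__(128).value)  # 4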

int3, Apr 24 '24

@Jokeren @fkouteib do you have input on what the right behavior is for raw ints here? Should we raise an error when encountering them?

int3, Apr 26 '24