tvm
[Bug] `from __future__ import annotations` breaks type annotation containing local variable
A simple reproducer is attached below. For detailed background, see https://github.com/tile-ai/tilelang/issues/1079
Expected behavior
The test program exits without error.
Actual behavior
$ TVM_BACKTRACE=1 python test.py
error: Unexpected type for TIR Arg: ffi.String
 --> /home/yyc/repo/tvm/test.py:9:5
  |
9 |     def f(A: T.Buffer((M,), "float32")):
  |     ^^^^^^^^
Traceback (most recent call last):
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 309, in _wrapper
    return func(self, node)
           ^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/parser.py", line 409, in visit_function_def
    param = T.arg(arg.arg, ann)
            ^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/ir_builder/tir/ir.py", line 200, in arg
    return _ffi_api.Arg(name, obj) # type: ignore[attr-defined] # pylint: disable=no-member
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
ValueError: Unexpected type for TIR Arg: ffi.String

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/yyc/repo/tvm/test.py", line 17, in <module>
    f()
  File "/home/yyc/repo/tvm/test.py", line 8, in f
    @T.prim_func
     ^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/entry.py", line 72, in prim_func
    return decorator_wrapper(func)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/entry.py", line 65, in decorator_wrapper
    f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/entry.py", line 103, in parse
    parser.parse(extra_vars=extra_vars)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 379, in parse
    self.visit(node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 638, in visit
    self.report_error(node, err)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 587, in report_error
    raise err
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 636, in visit
    func(node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/doc.py", line 256, in generic_visit
    self.visit(value)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 624, in visit
    self.visit(item)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 638, in visit
    self.report_error(node, err)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 587, in report_error
    raise err
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 636, in visit
    func(node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 684, in visit_FunctionDef
    _dispatch_wrapper(func)(self, node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 311, in _wrapper
    self.report_error(node, err)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 607, in report_error
    raise diag_err
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 309, in _wrapper
    return func(self, node)
           ^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/parser.py", line 409, in visit_function_def
    param = T.arg(arg.arg, ann)
            ^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/ir_builder/tir/ir.py", line 200, in arg
    return _ffi_api.Arg(name, obj) # type: ignore[attr-defined] # pylint: disable=no-member
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
Environment
Steps to reproduce
from __future__ import annotations
from tvm.script import tir as T

def f(M=1):
    @T.prim_func
    def f(A: T.Buffer((M,), "float32")):
        pass

    return f

f()
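The following standard-library-only sketch (no TVM required) mirrors the reproducer to show what `from __future__ import annotations` (PEP 563) does to the inner function. It assumes, without confirmation from the traceback, that the parser ends up resolving the annotation from the function's globals and closure; the traceback does show `utils.inspect_function_capture(func)` being passed to `parse`, but its exact behavior is an assumption here. The class `T` below is a hypothetical stand-in for `tvm.script.tir`, used only so the annotation string can be re-evaluated without TVM installed.

```python
from __future__ import annotations

import inspect


class T:
    """Hypothetical stand-in for `tvm.script.tir`; not the real TVM API."""

    @staticmethod
    def Buffer(shape, dtype):
        # Just return the arguments so evaluation has something to produce.
        return (shape, dtype)


def outer(M=1):
    # Mirrors the reproducer: `M` is a local of the enclosing function and
    # is referenced only inside the parameter annotation.
    def inner(A: T.Buffer((M,), "float32")):
        pass

    return inner


g = outer()
ann = g.__annotations__["A"]

# PEP 563: the annotation is stored as unevaluated source text.
print(type(ann).__name__, ann)

# `M` is not a free variable of `inner`, so its closure is empty.
print(g.__code__.co_freevars, g.__closure__)  # prints: () None
nonlocals = inspect.getclosurevars(g).nonlocals
print(nonlocals)  # prints: {}

# Re-evaluating the string with only the function's globals and closure
# fails: `M` lived in the frame of `outer`, which has already returned.
try:
    eval(ann, g.__globals__, nonlocals)
    error = None
except NameError as exc:
    error = exc
print(repr(error))
```

Under PEP 563 the annotation survives only as a string, and because `M` appears only in a parameter annotation it is never stored in the inner function's closure, so neither the string nor the captured variables carry the value of `M`. This is consistent with `T.arg` receiving an `ffi.String`, but it illustrates Python semantics only and is not a diagnosis of the TVM parser.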
Triage
- needs-triage