[USABILITY] PrintScriptWithAnnotation in StructuralEquality Check
Speaking from my own experience debugging some of the passes: it would be very helpful to have a printer that can print out a fragment of the IR and highlight it.
Example
The code below shows an example of a structural equality check. Structural equality checking is our basic tool for writing unit tests for passes. Usually we write or generate the IR before and after a transformation and assert that the results are structurally equal.
Normally, when a pass works as expected, the structural equality check passes. However, when two pieces of IR are not structurally equal, it is usually useful to diff the two IRs (desired output and actual output) to find where the mismatch happens.
```python
from __future__ import annotations  # so Python does not evaluate the Tensor[...] annotations

import tvm
from tvm.script import tir as T, relax as R


def test_example():
    # lhs: x is matched against a 2-D symbolic shape (n, m).
    @R.function
    def f0(x: Tensor[(_, _), "float32"]):
        x0 = R.match_shape(x, (n, m))
        return (x0, (n + 1, m))

    # rhs: x is matched against a 1-D symbolic shape (n,).
    @R.function
    def f1(x: Tensor[(_, _), "float32"]):
        x0 = R.match_shape(x, (n,))
        return (x0, (n + 1,))

    # Fails: the match_shape patterns (and hence the returned shapes) differ.
    tvm.ir.assert_structural_equal(f0, f1)


# The error message we would like to see (carets mark the mismatch in each function):
wish_error_message = """
StructuralEqual check failed.
See the mismatched location highlighted:
# lhs
@R.function
def f0(x: Tensor[(_, _), "float32"]):
    x0 = R.match_shape(x, (n, m))
                          ^^^^^^
    return (x0, (n + 1, m))
# rhs
@R.function
def f1(x: Tensor[(_, _), "float32"]):
    x0 = R.match_shape(x, (n,))
                          ^^^^
    return (x0, (n + 1,))
"""

test_example()
```
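Until such a printer exists, the manual workflow described above (diffing the printed desired and actual IR) can at least be scripted. A minimal sketch, assuming both functions have already been printed to plain strings; the helper name `diff_scripts` is hypothetical and only Python's standard `difflib` is used:

```python
import difflib


def diff_scripts(lhs: str, rhs: str, lhs_name: str = "lhs", rhs_name: str = "rhs") -> str:
    """Return a unified diff of two printed IR scripts given as plain strings."""
    diff = difflib.unified_diff(
        lhs.splitlines(),
        rhs.splitlines(),
        fromfile=lhs_name,
        tofile=rhs_name,
        lineterm="",  # lines from splitlines() carry no newline of their own
    )
    return "\n".join(diff)


lhs_text = '''@R.function
def f0(x: Tensor[(_, _), "float32"]):
    x0 = R.match_shape(x, (n, m))
    return (x0, (n + 1, m))'''

rhs_text = '''@R.function
def f1(x: Tensor[(_, _), "float32"]):
    x0 = R.match_shape(x, (n,))
    return (x0, (n + 1,))'''

print(diff_scripts(lhs_text, rhs_text))
```

A plain text diff reports every line that differs, including the function-name line here, so it cannot tell which difference structural equality actually rejected. That is the gap a structure-aware highlight would close.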
Right now, the structural equality check only prints out the (first) mismatched object.
However, it is usually more helpful to highlight that object within its global context. The `wish_error_message` in the example above shows what such output could look like.
To support this kind of printing, we would ideally provide a function

```cpp
String PrintRelaxScriptWithAnnotation(ObjectRef input, Array<ObjectRef> annotations);
```
This would print out the script with the annotated locations (i.e. the IR objects of interest) highlighted. Structural equality checking would then only need to call into this function to print any mismatched IR fragment with highlighting. Note that this would also be useful for other error-reporting cases.
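As a rough illustration of the output format, the sketch below assumes the printer can already map each annotated object to a location in the printed text. `Span` and `print_with_annotation` are hypothetical names, not TVM APIs; the function handles only the text side, inserting a caret line under each annotated column range.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Span:
    """Hypothetical location: 1-based line number, 0-based half-open column range."""
    line: int
    col_begin: int
    col_end: int


def print_with_annotation(script: str, spans: List[Span]) -> str:
    """Insert a caret line under every script line covered by an annotation."""
    by_line: Dict[int, List[Span]] = {}
    for span in spans:
        by_line.setdefault(span.line, []).append(span)

    out: List[str] = []
    for lineno, text in enumerate(script.splitlines(), start=1):
        out.append(text)
        if lineno in by_line:
            marker = [" "] * len(text)
            for span in by_line[lineno]:
                # Clamp to the line length so a bad span cannot overrun the text.
                for col in range(span.col_begin, min(span.col_end, len(text))):
                    marker[col] = "^"
            out.append("".join(marker).rstrip())
    return "\n".join(out)


script = "def f0(x):\n    x0 = R.match_shape(x, (n, m))\n    return x0"
col = script.splitlines()[1].index("(n, m)")
print(print_with_annotation(script, [Span(2, col, col + len("(n, m)"))]))
```

Keeping the caret rendering separate from the object-to-location mapping means the same helper could serve other error reporters as well.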
Possible ways to implement the support
There are a few possible ways to implement this. One way is, of course, to build it into the printer itself. Notably, not all IR objects are printable (e.g. those stored in the metadata), so we also need to be able to find the most relevant printable parent that holds the information, and print that parent's context together with the location information (such as the type).
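The "nearest printable parent" search could look roughly like the toy model below. `Node`, its `printable` flag and `nearest_printable` are hypothetical; the toy stores explicit parent pointers for brevity, whereas a real implementation would more likely record the path while traversing the IR from the root.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Node:
    """Toy IR node; printable=False stands for objects that only live in metadata."""
    kind: str
    printable: bool = True
    parent: Optional["Node"] = None


def nearest_printable(node: Node) -> Tuple[Optional[Node], List[str]]:
    """Walk up the parent chain to the closest printable node.

    Returns that ancestor (None if no ancestor is printable) and the kinds of the
    unprintable nodes skipped on the way, innermost first, so they can be
    reported as extra location information next to the printed context.
    """
    skipped: List[str] = []
    current: Optional[Node] = node
    while current is not None and not current.printable:
        skipped.append(current.kind)
        current = current.parent
    return current, skipped


func = Node("Function")
binding = Node("MatchShape", parent=func)
const = Node("NDArray", printable=False, parent=binding)
anchor, skipped = nearest_printable(const)
print(anchor.kind, skipped)
```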
Span information is normally used to preserve source context in the frontend, so one possibility is to print the IR and then use the parser to re-populate the span information. However, dedicated printer support may still be desirable for the reasons mentioned above: some objects may not have a span of their own if they are part of the metadata or of syntactic sugar.
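The print-then-re-parse idea can be tried out cheaply because TVMScript uses Python syntax. The sketch below uses Python's own `ast` module as a stand-in for TVM's parser (an assumption for illustration, not how TVM re-populates spans) to recover the line and column range of the second `match_shape` argument from printed text. `locate_call_arg` is a hypothetical helper.

```python
import ast
from typing import Optional, Tuple


def locate_call_arg(source: str, func_attr: str, arg_index: int) -> Optional[Tuple[int, int, int]]:
    """Re-parse printed source and return (line, col_begin, col_end) of a call argument.

    Only single-line arguments are reported; returns None if no matching call exists.
    """
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == func_attr
            and len(node.args) > arg_index
        ):
            arg = node.args[arg_index]
            if arg.lineno == arg.end_lineno:
                return arg.lineno, arg.col_offset, arg.end_col_offset
    return None


printed = '''@R.function
def f0(x: Tensor[(_, _), "float32"]):
    x0 = R.match_shape(x, (n, m))
    return (x0, (n + 1, m))
'''
print(locate_call_arg(printed, "match_shape", 1))
```

Objects with no syntactic counterpart in the printed text, such as metadata entries, would not be found this way, which matches the caveat above.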
CC: @yelite