[Bug] Meta Schedule Layout Rewrite Failure
After #12720, the `RewriteLayout` postprocessor seems to fail during tuning. An example to reproduce is here: https://gist.github.com/zxybazh/6bff29ae4e7cb273d57bb30599790008. The failure message looks like:
```
[11:43:50] /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/../utils.h:289: Warning: ThreadedTraceApply::Apply failed with error [11:43:50] /home/zxybazh/tvm-tensorir/include/tvm/runtime/container/map.h:376:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (itr.index < size_) is false: IndexError: key is not in Map
```
CC @vinx13 @junrushao @Lunderberg
I can reproduce the error, and made a simplified test script below. Playing around a bit with the shape, it looks like it only occurs when there's a dimension of size 1.
```python
#!/usr/bin/env python3

import tvm
import tvm.testing
from tvm.script import tir as T

shape = tvm.testing.parameter(
    (1, 1, 256, 512),
    (1, 2, 256, 512),
    (2, 1, 256, 512),
    (2, 2, 256, 512),
)


def test_rewrite_layout(shape):
    n, c, h, w = shape

    @tvm.script.ir_module
    class mod:
        @T.prim_func
        def main(a: T.handle):
            A = T.match_buffer(a, shape, "float32")
            T.func_attr({"layout_free_buffers": [0]})
            for ax0, ax1, ax2, ax3 in T.grid(n, c, h, w):
                with T.block():
                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.evaluate(A[v0, v1, v2, v3])

    sch = tvm.tir.Schedule(mod)
    tvm.meta_schedule.postproc.RewriteLayout().apply(sch)


if __name__ == "__main__":
    tvm.testing.main()
```
It looks like the bug follows the steps below:
1. `SuggestIndexMap` first flattens the indices.
2. `SplitExprCollector::Collect` unpacks the flattened index, but doesn't pass the `simplify_trivial_iterators = false` argument. As a result, any iterator with an extent of 1 is replaced by a constant.
3. `RewriteLayout` generates the index mapping `lambda a, b, c, d: [c, d]`, where it should generate `lambda a, b, c, d: [a, b, c, d]`.
4. `LayoutTransform` calls `IndexMap::NonSurjectiveInverse` to determine whether the specified transformation introduces any padding. (New behavior in #12720, triggering the bug.)
5. The mapping function is non-invertible, causing an error (see the sketch below).
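A minimal sketch of steps 3-5, written for illustration with the Python `IndexMap` bindings (not part of the original repro):

```python
# Rough sketch of steps 3-5 (illustrative; shows pre-fix behavior).
from tvm.tir import IndexMap

buggy = IndexMap.from_func(lambda a, b, c, d: [c, d])          # what RewriteLayout generates
correct = IndexMap.from_func(lambda a, b, c, d: [a, b, c, d])  # what it should generate

# The correct map is bijective over the full shape, so inversion succeeds.
inverse, padding_predicate = correct.non_surjective_inverse([1, 1, 256, 512])

# The buggy map drops the extent-1 axes a and b, so before the fix discussed
# below this call died with:
#   Check failed: (itr.index < size_) ... key is not in Map
buggy.non_surjective_inverse([1, 1, 256, 512])
```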
Ah, I think I found a better fix. There's some copy/paste duplication between `Inverse` and `NonSurjectiveInverse` that I should have cleaned up when implementing it, and there's better handling in `Inverse` that didn't make its way into `NonSurjectiveInverse`. In `NonSurjectiveInverse`, the following section should be used instead; that way, for any extent-1 ranges, it still appropriately generates the inverse:

```cpp
// Unpack the map to an array, maintaining the same parameter order.
Array<PrimExpr> inverse_exprs;
for (int i = 0, n = (*this)->initial_indices.size(); i < n; ++i) {
  Var index = (*this)->initial_indices[i];
  if (is_one(initial_ranges[i]->extent) && !inverse_exprs_map.count(index)) {
    inverse_exprs.push_back(initial_ranges[i]->min);
  } else {
    inverse_exprs.push_back(inverse_exprs_map.at(index));
  }
}
```

With this change, my test script now passes.
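For intuition, here is a pure-Python analogue of that loop (an illustrative sketch, not the actual TVM implementation): any extent-1 index whose iterator was simplified away gets its range's min substituted as a constant, instead of a failing map lookup.

```python
# Pure-Python analogue of the C++ fix above (illustrative sketch only).
def build_inverse_exprs(initial_indices, initial_ranges, inverse_exprs_map):
    """Unpack the inverse map in parameter order, defaulting extent-1 axes."""
    inverse_exprs = []
    for index, (range_min, extent) in zip(initial_indices, initial_ranges):
        if extent == 1 and index not in inverse_exprs_map:
            # Iterator was simplified to a constant; recover it from the range.
            inverse_exprs.append(range_min)
        else:
            inverse_exprs.append(inverse_exprs_map[index])
    return inverse_exprs

# For shape (1, 1, 256, 512), where only c and d survive detection:
print(build_inverse_exprs(
    ["a", "b", "c", "d"],
    [(0, 1), (0, 1), (0, 256), (0, 512)],
    {"c": "c_out", "d": "d_out"},
))  # -> [0, 0, 'c_out', 'd_out']
```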
In general, `IndexMap::Inverse` doesn't guarantee to generate the inverse map even if the index map is (theoretically) bijective, because the underlying `DetectIterMap` can only handle limited cases. That's why I added an interface to pre-define the inverse map: in `SuggestIndexMap` we know how to construct the inverse without arithmetic analysis.
https://github.com/apache/tvm/blob/main/include/tvm/tir/index_map.h#L83
Is there any chance we can use it in `NonSurjectiveInverse`? I think when the inverse map is pre-defined, we can assume it is bijective and can be inverted; correctness is guaranteed by the user.
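For reference, here's a hedged sketch of that pre-defined-inverse interface from the Python side (the `inverse_index_map` argument of `IndexMap.from_func` is assumed from the current bindings; the linked header shows the C++ side):

```python
# Sketch of pre-defining the inverse when constructing an IndexMap
# (assumes the `inverse_index_map` argument of IndexMap.from_func).
from tvm.tir import IndexMap

packed = IndexMap.from_func(
    lambda i, j: [i // 16, j // 16, i % 16, j % 16],
    inverse_index_map=lambda io, jo, ii, ji: [io * 16 + ii, jo * 16 + ji],
)
# With the inverse supplied up front, no DetectIterMap-based analysis is
# needed; correctness is the caller's responsibility.
print(packed.inverse([512, 512]))
```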
Here is another test case that prints out a bunch of ThreadedTraceApply errors:
```python
import tvm
from tvm.script import tir as T


@T.prim_func
def matmul(
    A: T.Buffer[(512, 512), "float32"],
    B: T.Buffer[(512, 512), "float32"],
    C: T.Buffer[(512, 512), "float32"],
) -> None:
    for i in range(512):
        for j in range(512):
            for k in range(512):
                with T.block("update"):
                    C[i, j] = C[i, j] + A[i, k] * B[j, k]


s = tvm.meta_schedule.tune_tir(
    matmul, "llvm --num-cores 1", tvm.meta_schedule.TuneConfig(100, 32), "tmp"
)
```
I tested it with the commit before #12720 and still got a bunch of errors. Not sure if it is related to this.
> In general, `IndexMap::Inverse` doesn't guarantee to generate the inverse map even if the index map is (theoretically) bijective, because the underlying `DetectIterMap` can only handle limited cases.
In principle, this should have been handled by the check on the return value of `DetectIterMap`. However, because trivial iterators were removed, checking that `DetectIterMap` returned a non-empty vector wasn't sufficient to know that all iterators were present.
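As a concrete sketch (assuming the current `tvm.arith.detect_iter_map` Python bindings; argument names and return types differ across versions), detection reports success even though the extent-1 iterators never appear in the result:

```python
# Sketch: detection "succeeds" while the extent-1 iterators a and b are
# simplified away, so not every initial index gets an inverse expression.
import tvm
from tvm import arith, tir

a, b, c, d = (tir.Var(name, "int32") for name in "abcd")
input_iters = {
    a: tvm.ir.Range(0, 1),
    b: tvm.ir.Range(0, 1),
    c: tvm.ir.Range(0, 256),
    d: tvm.ir.Range(0, 512),
}
# The flattened index that SuggestIndexMap would produce for (a, b, c, d).
flattened = ((a * 1 + b) * 256 + c) * 512 + d
res = arith.detect_iter_map([flattened], input_iters)
print(res.indices)  # non-empty, but binds neither a nor b
```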
> Is there any chance we can use it in `NonSurjectiveInverse`?
Possibly, though I think that's unrelated to the current bug. That said, I think its use will come along for free when merging the implementations of `Inverse` and `NonSurjectiveInverse`.
@tkonolige I tested it and had the same behavior. I was able to get better error messages by changing this line from `tir::ScheduleErrorRenderLevel::kNone` to `tir::ScheduleErrorRenderLevel::kDetail`, and it looks like the error is thrown from this check when trying to parallelize the function. That's as far as I got for now, without digging too much into the meta scheduler.
This is indeed an orthogonal but serious problem. @tkonolige would you like to create a separate issue and bisect to determine which commit causes this bug?
The original issue has been fixed now.