
[Bug] Meta Schedule Layout Rewrite Failure

Open zxybazh opened this issue 3 years ago • 4 comments

After #12720 the RewriteLayout postprocessor seems to fail during tuning. An example to reproduce is here: https://gist.github.com/zxybazh/6bff29ae4e7cb273d57bb30599790008. The failure message looks like:

[11:43:50] /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/../utils.h:289: Warning: ThreadedTraceApply::Apply failed with error [11:43:50] /home/zxybazh/tvm-tensorir/include/tvm/runtime/container/map.h:376: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (itr.index < size_) is false: IndexError: key is not in Map

CC @vinx13 @junrushao @Lunderberg

zxybazh avatar Sep 20 '22 22:09 zxybazh

I can reproduce the error, and made a simplified test script below. Playing around a bit with the shape, it looks like it only occurs when there's a dimension of size 1.

#!/usr/bin/env python3

import tvm
import tvm.meta_schedule
import tvm.testing
from tvm.script import tir as T

shape = tvm.testing.parameter(
    (1, 1, 256, 512),
    (1, 2, 256, 512),
    (2, 1, 256, 512),
    (2, 2, 256, 512),
)


def test_rewrite_layout(shape):
    n, c, h, w = shape

    @tvm.script.ir_module
    class mod:
        @T.prim_func
        def main(a: T.handle):
            A = T.match_buffer(a, shape, "float32")
            T.func_attr({"layout_free_buffers": [0]})
            for ax0, ax1, ax2, ax3 in T.grid(n, c, h, w):
                with T.block():
                    v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.evaluate(A[v0, v1, v2, v3])

    sch = tvm.tir.Schedule(mod)
    tvm.meta_schedule.postproc.RewriteLayout().apply(sch)


if __name__ == "__main__":
    tvm.testing.main()

Lunderberg avatar Sep 21 '22 14:09 Lunderberg

It looks like the bug follows the steps below:

  1. SuggestIndexMap first flattens the indices
  2. SplitExprCollector::Collect unpacks the flattened index, but doesn't pass the simplify_trivial_iterators = false argument. As a result, any iterator with an extent of 1 is replaced by a constant.
  3. RewriteLayout generates an index mapping lambda a,b,c,d: [c,d], where it should generate lambda a,b,c,d: [a,b,c,d].
  4. LayoutTransform calls IndexMap::NonSurjectiveInverse to determine if the specified transformation introduces any padding. (New behavior in #12720, triggering the bug.)
  5. The mapping function is non-invertible, causing an error.
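The collapse in steps 2–3 can be illustrated with a toy sketch (plain Python, not the actual TVM code; the function names are hypothetical):

```python
# Toy illustration: when simplify_trivial_iterators is left at its default,
# extent-1 axes are replaced by constants, so the recovered mapping drops
# them entirely.
def collapsed_map(a, b, c, d):
    # What RewriteLayout ends up generating for shape (1, 1, 256, 512):
    return (c, d)

def full_map(a, b, c, d):
    # What it should generate:
    return (a, b, c, d)

# The full map is trivially invertible; the collapsed one loses a and b,
# so NonSurjectiveInverse cannot reconstruct them and raises an error.
assert full_map(0, 0, 5, 7) == (0, 0, 5, 7)
assert collapsed_map(0, 0, 5, 7) == (5, 7)
```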

Lunderberg avatar Sep 21 '22 15:09 Lunderberg

Ah, I think I found a better fix. There's some copy/paste duplication between Inverse and NonSurjectiveInverse that I should have cleaned up when implementing it, and Inverse has better handling that never made its way into NonSurjectiveInverse. In NonSurjectiveInverse, the following section should be used instead. This way, for any extent = 1 ranges, it still appropriately generates the inverse.

With this change, my test script now passes.

// Unpack the map to an array, maintaining the same parameter order.
Array<PrimExpr> inverse_exprs;
for (int i = 0, n = (*this)->initial_indices.size(); i < n; ++i) {
  Var index = (*this)->initial_indices[i];
  if (is_one(initial_ranges[i]->extent) && !inverse_exprs_map.count(index)) {
    inverse_exprs.push_back(initial_ranges[i]->min);
  } else {
    inverse_exprs.push_back(inverse_exprs_map.at(index));
  }
}
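The backfill logic above can be sketched in plain Python (hypothetical names, not the actual TVM sources): for each initial index whose extent is 1 and which the analysis failed to solve, substitute the range minimum instead of a solved expression.

```python
# Toy sketch of the fix: when building the inverse expression list, fall back
# to the range minimum for any extent-1 index missing from the solved map.
def build_inverse_exprs(initial_indices, initial_ranges, inverse_exprs_map):
    exprs = []
    for index, (minimum, extent) in zip(initial_indices, initial_ranges):
        if extent == 1 and index not in inverse_exprs_map:
            exprs.append(minimum)  # extent-1 axis: its only value is its min
        else:
            exprs.append(inverse_exprs_map[index])
    return exprs

# e.g. axes "a" and "b" have extent 1 and were dropped by the analysis:
assert build_inverse_exprs(
    ["a", "b", "c", "d"],
    [(0, 1), (0, 1), (0, 256), (0, 512)],
    {"c": "j", "d": "k"},
) == [0, 0, "j", "k"]
```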

Lunderberg avatar Sep 21 '22 15:09 Lunderberg

In general, IndexMap::Inverse doesn’t guarantee to generate the inverse map even if the index map is (theoretically) bijective, because the underlying DetectIterMap can only handle limited cases. That’s why I added an interface to pre-define the inverse map: in SuggestIndexMap we know how to construct the inverse without arithmetic analysis.

https://github.com/apache/tvm/blob/main/include/tvm/tir/index_map.h#L83

Is there any chance we can use it in NonSurjectiveInverse? I think when the inverse map is pre-defined, we can assume it is bijective and can be inverted; correctness is guaranteed by the user.
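The "pre-defined inverse" idea can be sketched as follows (a minimal Python stand-in, not the real tvm::tir::IndexMap): carry the inverse alongside the forward map, so inverting requires no arithmetic analysis at all.

```python
# Hypothetical sketch: an index map that trusts a user-supplied inverse
# instead of deriving one through iterator analysis.
class IndexMapSketch:
    def __init__(self, forward, inverse=None):
        self.forward = forward
        self._inverse = inverse

    def inverse(self):
        if self._inverse is not None:
            return self._inverse  # trust the pre-defined inverse
        raise NotImplementedError("would fall back to iterator analysis")

# A packed-layout transform with its inverse supplied up front:
m = IndexMapSketch(
    forward=lambda i, j: (i // 4, j, i % 4),
    inverse=lambda io, j, ii: (io * 4 + ii, j),
)
assert m.forward(11, 5) == (2, 5, 3)
assert m.inverse()(2, 5, 3) == (11, 5)
```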

vinx13 avatar Sep 21 '22 15:09 vinx13

Here is another test case that prints out a bunch of ThreadedTraceApply errors:

import tvm
import tvm.meta_schedule
from tvm.script import tir as T

@T.prim_func
def matmul(
    A: T.Buffer[(512, 512), "float32"],
    B: T.Buffer[(512, 512), "float32"],
    C: T.Buffer[(512, 512), "float32"],
) -> None:
    for i in range(512):
        for j in range(512):
            for k in range(512):
                with T.block("update"):
                    C[i, j] = C[i, j] + A[i, k] * B[j, k]

s = tvm.meta_schedule.tune_tir(matmul, "llvm --num-cores 1", tvm.meta_schedule.TuneConfig(100, 32), "tmp")

I tested it with the commit before #12720 and still got a bunch of errors. Not sure if it is related to this.

tkonolige avatar Sep 23 '22 21:09 tkonolige

In general, IndexMap::Inverse doesn’t guarantee to generate the inverse map even if the index map is (theoretically) bijective because the underlying DetectIterMap can only handle limited cases.

In principle, this should have been handled by the check on the return value of DetectIterMap. However, because trivial iterators were removed, the check for DetectIterMap returning a non-empty vector wasn't sufficient to know that all iterators were present.
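The inadequacy of the non-empty check can be illustrated with a toy stand-in (hypothetical function, not the real DetectIterMap):

```python
# Toy stand-in for DetectIterMap: drop extent-1 iterators when simplifying,
# mimicking the trivial-iterator removal described above.
def detect_iters(indices, extents, simplify_trivial=True):
    if simplify_trivial:
        return [i for i, e in zip(indices, extents) if e != 1]
    return list(indices)

iters = detect_iters(["a", "b", "c", "d"], [1, 1, 256, 512])
assert iters             # non-empty, so the old check passes...
assert len(iters) != 4   # ...even though two iterators are missing
```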

Is there any chance we can use it in NonSurjectiveInverse?

Possibly, though I think that's unrelated to the current bug. That said, I think its use will come along for free when merging the implementations of Inverse and NonSurjectiveInverse.

Lunderberg avatar Sep 26 '22 13:09 Lunderberg

@tkonolige I tested it and saw the same behavior. I was able to get better error messages by changing this line from tir::ScheduleErrorRenderLevel::kNone to tir::ScheduleErrorRenderLevel::kDetail, and it looks like the error is thrown from this check when trying to parallelize the function. That's as far as I got for now, without digging too much into the meta scheduler.

Lunderberg avatar Sep 26 '22 13:09 Lunderberg

@tkonolige I tested it and saw the same behavior. I was able to get better error messages by changing this line from tir::ScheduleErrorRenderLevel::kNone to tir::ScheduleErrorRenderLevel::kDetail, and it looks like the error is thrown from this check when trying to parallelize the function. That's as far as I got for now, without digging too much into the meta scheduler.

This is indeed an orthogonal but serious problem. @tkonolige would you like to create a separate issue and bisect to determine which commit causes this bug?

junrushao avatar Sep 26 '22 15:09 junrushao

The original issue has been fixed now.

masahi avatar Dec 08 '22 08:12 masahi