tvm
tvm copied to clipboard
[Bug] [Relax] Build fails when applying `dlight.gpu.GeneralReduction` to `R.nn.group_norm` with dynamic shapes and `R.reshape`
Actual behavior
When building the TVMScript below using dlight.gpu.GeneralReduction(), the build fails with the following error:
InternalError: Check failed: (!divisor.is_const(0)) is false: Find divide by zero
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
# Reproducer: group_norm applied to the output of a dynamic-shape reshape.
# (Indentation restored — it was flattened when the snippet was pasted.)
@I.ir_module
class Module:
    @R.function
    def reshape_norm(
        inp_0: R.Tensor((1, 512, "w", "h"), dtype="float16"),
        inp_1: R.Tensor((512,), dtype="float16"),
        inp_2: R.Tensor((512,), dtype="float16"),
    ) -> R.Tensor((1, 512, "w * h"), dtype="float16"):
        # Symbolic dimensions referenced by the reshape target shape below.
        w = T.int64()
        h = T.int64()
        with R.dataflow():
            # Fuse the two trailing dynamic axes into a single axis (w * h).
            lv = R.reshape(inp_0, R.shape([1, 512, w * h]))
            # group_norm over the fused dynamic axis is what triggers the
            # "divide by zero" check in dlight.gpu.GeneralReduction (see report).
            lv1 = R.nn.group_norm(data=lv, gamma=inp_1, beta=inp_2, num_groups=32, channel_axis=1, axes=[2], epsilon=9.9999999999999995e-07, center=True, scale=True)
            R.output(lv1)
        return lv1
- However, if I modify the input tensor `inp_0` and the output tensor shape to `(1, 512, "n")` and remove the `R.reshape` operation, the build completes successfully without errors.
- It works well with other `dlight` schedules. If I remove `dl.gpu.GeneralReduction()`, the build also completes with the other dlight schedules.
Environment
- TVM Version: v0.18.0
- Commit Hash: 30b7b1c75
Steps to reproduce
import tvm
from tvm import relax
import tvm.dlight as dl
@tvm.transform.module_pass(opt_level=0)
def dynshape_build_pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
    """Lowering pipeline used to reproduce the bug.

    Legalizes the Relax module, applies the default dlight GPU schedules
    (including GeneralReduction, the schedule under suspicion), then runs the
    standard VM build passes. Indentation restored — the pasted snippet had
    been flattened and was not valid Python as shown.
    """
    seq = tvm.transform.Sequential(
        [
            relax.backend.DispatchSampling(),
            relax.backend.DispatchSortScan(),
            relax.transform.LegalizeOps(),
            # GeneralReduction is the schedule that triggers the failure;
            # removing it (per the report) lets the build succeed.
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
            relax.transform.RewriteDataflowReshape(),
            relax.transform.ToNonDataflow(),
            relax.transform.RemovePurityChecking(),
            relax.transform.CallTIRRewrite(),
            relax.transform.StaticPlanBlockMemory(),
            relax.transform.RewriteCUDAGraph(),
            relax.transform.LowerAllocTensor(),
            relax.transform.KillAfterLastUse(),
            relax.transform.LowerRuntimeBuiltin(),
            relax.transform.ComputePrimValue(),
            relax.transform.VMShapeLower(),
            relax.transform.AttachGlobalSymbol(),
        ],
    )
    mod = seq(mod)
    return mod
# `Module` as TVMScript in 'Actual behavior'
# Run the default Relax pipeline first, then build with the custom pipeline.
mod = relax.get_pipeline()(Module)
target = tvm.target.Target("cuda")
ex = relax.build(mod, target=target, pipeline=dynshape_build_pipeline)
cc @junrushao