32 minute compile time for max_pool2d_with_indices
This example takes 32 minutes to compile, while typical kernels take seconds (not minutes). I suspect it is hitting some sort of pathological case in Halide.
repro.py
import halide as hl
from torch._inductor.runtime import halide_helpers
from math import inf, nan
@hl.generator(name="kernel")
class Kernel:
    """Halide generator for the ``max_pool2d_with_indices`` repro.

    Nine taps of ``in_ptr0`` are folded into a running maximum, which is
    written to ``out_ptr0``.  In parallel the number of the winning tap is
    tracked and converted into a flat int64 index, written to ``out_ptr2``.
    Every intermediate value is its own ``hl.Func`` named ``tmp0`` ..
    ``tmp52``, numbered in creation order.
    """

    # 5-D float32 input, 4-D float32 maxima, 4-D int64 indices.
    in_ptr0 = hl.InputBuffer(hl.Float(32), 5)
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 4)
    out_ptr2 = hl.OutputBuffer(hl.Int(64), 4)

    def generate(g):
        """Define the pipeline, then pin input layout and size estimates."""
        in_ptr0, out_ptr0, out_ptr2 = g.in_ptr0, g.out_ptr0, g.out_ptr2
        h0, h1, h2, h3 = (hl.Var(name) for name in ("h0", "h1", "h2", "h3"))
        # Coordinate tuple shared by every 4-D definition below.
        pt = (h0, h1, h2, h3)

        # stages[i] is the Func named "tmp{i}"; names are handed out in
        # creation order so the pipeline keeps the tmp0..tmp52 numbering.
        stages = []

        def new_stage():
            fn = hl.Func(f"tmp{len(stages)}")
            stages.append(fn)
            return fn

        # (leading index, dim-1 index, dim-2 index) of each input tap, in
        # the order they are folded into the maximum.
        taps = (
            (0, h0, h1),
            (1, h0, h1),
            (0, 1 + h0, h1),
            (1, 13 + h0, h1),
            (0, 14 + h0, h1),
            (1, 14 + h0, h1),
            (0, h0, 1 + h1),
            (1, h0, 1 + h1),
            (0, 1 + h0, 1 + h1),
        )

        def load_tap(tap):
            lead, idx0, idx1 = tap
            fn = new_stage()
            fn[pt] = in_ptr0[lead, idx0, idx1, h2, h3]
            return fn

        def running_max(cand, prev):
            # For float types the candidate wins when it is larger OR when
            # it is NaN; other types use a plain max.
            fn = new_stage()
            prev_val = hl.cast(cand.type(), prev[pt])
            if cand.type().is_float():
                fn[pt] = hl.select(
                    (cand[pt] > prev_val) | hl.is_nan(cand[pt]),
                    cand[pt],
                    prev_val,
                )
            else:
                fn[pt] = hl.max(cand[pt], prev_val)
            return fn

        # ---- value chain: tmp0 .. tmp16 --------------------------------
        loads = [load_tap(taps[0])]
        # best[k] is the maximum over taps 0..k; best[0] is just tap 0.
        best = [loads[0]]
        for tap in taps[1:]:
            cand = load_tap(tap)
            loads.append(cand)
            best.append(running_max(cand, best[-1]))
        out_ptr0[pt] = hl.cast(hl.Float(32), best[-1][pt])

        # ---- winning-tap chain: tmp17 .. tmp41 -------------------------
        winner = None
        for k in range(1, len(taps)):
            beats = new_stage()
            beats[pt] = loads[k][pt] > best[k - 1][pt]
            label = new_stage()
            label[()] = hl.cast(hl.Int(8), k)
            if winner is None:
                # The first link falls back to a scalar constant 0.
                zero = new_stage()
                zero[()] = hl.cast(hl.Int(8), 0)
                fallback = zero[()]
            else:
                fallback = winner[pt]
            winner = new_stage()
            winner[pt] = hl.select(
                beats[pt], label[()], hl.cast(label.type(), fallback)
            )

        # ---- flat index: tmp42 .. tmp52 --------------------------------
        # out = (2*h1 + floor(winner / 3)) * 27
        #       + (2*h0 + (winner - floor(winner / 3) * 3))
        three = new_stage()
        three[()] = hl.cast(hl.Int(32), 3)
        quotient = new_stage()
        quotient[pt] = hl.floor(
            hl.cast(hl.Float(max(32, winner.type().bits())), winner[pt])
            / three[()]
        )
        scaled = new_stage()
        scaled[pt] = quotient[pt] * three[()]
        remainder = new_stage()
        remainder[pt] = winner[pt] - scaled[pt]
        base1 = new_stage()
        base1[h1] = 2 * h1
        coord1 = new_stage()
        coord1[pt] = base1[h1] + quotient[pt]
        base0 = new_stage()
        base0[h0] = 2 * h0
        coord0 = new_stage()
        coord0[pt] = base0[h0] + remainder[pt]
        pitch = new_stage()
        pitch[()] = hl.cast(hl.Int(64), 27)
        scaled1 = new_stage()
        scaled1[pt] = coord1[pt] * pitch[()]
        flat = new_stage()
        flat[pt] = scaled1[pt] + coord0[pt]
        out_ptr2[pt] = hl.cast(hl.Int(64), flat[pt])

        assert g.using_autoscheduler()

        # ---- input layout and autoscheduler estimates ------------------
        # (stride, extent) for each input dimension; every min is 0.
        layout = ((1, 2), (2, 13), (54, 13), (729, 192), (139968, 128))
        for axis, (stride, extent) in enumerate(layout):
            in_ptr0.dim(axis).set_min(0)
            in_ptr0.dim(axis).set_stride(stride)
            in_ptr0.dim(axis).set_extent(extent)
        in_ptr0.set_estimates([hl.Range(0, extent) for _, extent in layout])
        for out_buf in (out_ptr0, out_ptr2):
            out_buf.set_estimates(
                [hl.Range(0, size) for size in (13, 13, 192, 128)]
            )
if __name__ == "__main__":
    import os
    import sys
    import tempfile

    # Locate the Anderson2021 autoscheduler plugin inside the installed
    # halide package instead of hard-coding one developer's environment,
    # so the repro can be run from any machine. "lib64" is tried first
    # because that is where the original report found the plugin.
    halide_root = os.path.dirname(os.path.abspath(hl.__file__))
    plugin_name = "libautoschedule_anderson2021.so"
    candidates = [
        os.path.join(halide_root, libdir, plugin_name)
        for libdir in ("lib64", "lib")
    ]
    # Fall back to the first candidate so a missing plugin is reported by
    # Halide itself with the path that was attempted.
    plugin_path = next(
        (path for path in candidates if os.path.exists(path)), candidates[0]
    )

    # hl.main() reads its options from sys.argv, so emulate the generator
    # command line. Outputs go to a temporary directory that is removed
    # on exit; only the (slow) compile itself is of interest.
    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            "repro.py",
            "-g",
            "kernel",
            "-o",
            out,
            "-f",
            "halide_kernel",
            "-e",
            "static_library,h,schedule",
            "-p",
            plugin_path,
            "target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts",
            "autoscheduler=Anderson2021",
            "autoscheduler.parallelism=82",
        ]
        hl.main()
cc @alexreinking — this example comes from running:
python benchmarks/dynamo/microbenchmarks/operatorbench.py --inductor-config autotune --inductor-config halide --op aten.max_pool2d_with_indices.default --max-samples 1 --start-idx 4
on https://github.com/pytorch/pytorch/pull/136809
So it takes 32 minutes... but still successfully compiles? Interesting. Maybe there's a lurking pass with exponential complexity for this example.
Yeah, it finishes and runs correctly.
It looks like the time is not spent in compilation proper, but in the Anderson2021 autoscheduler, which gets stuck enumerating a combinatorial number of tiling options. That is a bit absurd, given that the entire pipeline seems to be elementwise apart from its accesses to the input buffer.
A workaround would be to ask the autoscheduler to do a lot less by generating an Expr instead of a Func for anything that has no update definition and is either consumed elementwise or is an op that is cheaper than a load (e.g. tmp48).