Halide
SkipStage doesn't understand mux
A `mux` is equivalent to a tree of `select`s, but the SkipStages pass doesn't recognize that, for a given value of the selector, the operands that are not chosen are unused. Consider the following:
#include "Halide.h"
using namespace Halide;
int main(int argc, char **argv) {
Func f, g;
Var x, c;
f(x) = sqrt(sqrt(sqrt(sqrt(sqrt(x)))));
f.compute_root();
// g(x, c) = select(c == 0, f(x), 0.f);
g(x, c) = mux(c, {f(x), 0.f});
g.reorder(x, c).vectorize(x, 16, TailStrategy::RoundUp).bound(c, 0, 2).unroll(c);
f.compute_at(g, x).vectorize(x);
g.compile_jit();
return 0;
}
`f` is only used when `c == 0`, but the generated code computes it for `c == 1` as well:
produce g {
let t3 = (g.extent.0 + 15)/16
for (g.s0.x.x, 0, t3) {
allocate f[float32 * 16]
produce f {
f[ramp(0, 1, 16)] = (float32x16)sqrt_f32((float32x16)sqrt_f32((float32x16)sqrt_f32((float32x16)sqrt_f32((float32x16)sqrt_f32(float32x16(ramp((g.s0.x.x*16) + g.min.0, 1, 16)))))))
}
consume f {
g[ramp(g.s0.x.x*16, 1, 16) aligned(16, 0)] = f[ramp(0, 1, 16)]
}
free f
}
let t4 = (g.extent.0 + 15)/16
for (g.s0.x.x, 0, t4) {
allocate f[float32 * 16]
produce f {
f[ramp(0, 1, 16)] = (float32x16)sqrt_f32((float32x16)sqrt_f32((float32x16)sqrt_f32((float32x16)sqrt_f32((float32x16)sqrt_f32(float32x16(ramp((g.s0.x.x*16) + g.min.0, 1, 16)))))))
}
free f
consume f {
g[ramp((g.s0.x.x*16) + g.stride.1, 1, 16)] = x16(0.000000f)
}
}
}
LLVM does not always eliminate this redundant computation of `f` as dead code.