AMDMIGraphX
AMDMIGraphX copied to clipboard
BF16 fused_reduce compile fail
Repro
# fuse_reduce.py
import numpy as np
import migraphx
# Build a MIGraphX program by hand; every instruction is added to its main module.
p = migraphx.program()
m = p.get_main_module()
# Both inputs share one float_type shape; the last axis (length 128) is the one reduced below.
s1 = migraphx.shape(type="float_type", lens=[1, 24, 4608, 128])
x0 = m.add_parameter("x0", s1)
x1 = m.add_parameter("x1", s1)
# 0.0078125 == 1/128, i.e. the reciprocal of the reduced axis length.
c1 = m.add_literal(np.array(0.0078125, dtype=np.float32))
c1 = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [c1])
# Exponent 2 for the element-wise pow.
c2 = m.add_literal(np.array(2, dtype=np.float32))
c2 = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [c2])
# Small constant added to the reduced value before rsqrt; it is broadcast later,
# once the shape of the reduction result is known.
c4 = m.add_literal(np.array(9.98378e-07, dtype=np.float32))
# NOTE: the name `pow` shadows the Python builtin for the rest of the script.
pow = m.add_instruction(migraphx.op("pow"), [x0, c2])
mul = m.add_instruction(migraphx.op("mul"), [pow, c1])
# Sum of x0^2 / 128 along axis 3.
red = m.add_instruction(migraphx.op("reduce_sum", axes=[3]), [mul])
c4 = m.add_instruction(migraphx.op("multibroadcast", out_lens=red.shape().lens()), [c4])
add = m.add_instruction(migraphx.op("add"), [red, c4])
rsqrt = m.add_instruction(migraphx.op("rsqrt"), [add])
# Broadcast the per-row rsqrt back to the full input shape.
rsqrt_mb = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [rsqrt])
mul2 = m.add_instruction(migraphx.op("mul"), [x0, rsqrt_mb])
# Final multiply by the second input; per the report, dropping this instruction
# makes the compile succeed.
# NOTE(review): the overall pattern looks like an RMS-norm with x1 as the scale - confirm.
mul3 = m.add_instruction(migraphx.op("mul"), [mul2, x1])
Run:
migraphx-driver compile fuse_reduce.py --bf16
Error:
/long_pathname_so_that_rpms_can_package_the_debug_info/src/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp:1143: void VerifySDNode(llvm::SDNode*, const llvm::TargetLowering*): Assertion `(Op.getValueType() == EltVT || (EltVT.isInteger() && Op.getValueType().isInteger() && EltVT.bitsLE(Op.getValueType()))) && "Wrong operand type!"' failed.
It turns out that the issue does not occur when the last mul (mul3) is removed. For reference, here are the generated sources for both cases:
No mul3 (compile works)
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
// Generated kernel source for the variant WITHOUT the trailing mul3 (reported as compiling).
// The "// @param", "// @literal" and "// <op> -> <type>, {lens}, {strides}" comments
// are emitted by the generator and describe the instruction each statement came from.
namespace migraphx {
// pointwise0: converts its input element to bf16 (the convert is applied twice, nested).
template<class Tx0>
__device__ __attribute__((const)) auto pointwise0(Tx0 x0) {
// @param:x0 -> float_type, {1}, {0}
// convert[target_type=15] -> bf16_type, {1}, {0}
auto zz1 = migraphx::convert<bf16>(migraphx::convert<bf16>(x0));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz1;
return zzreturn;
}
// pointwise1: computes pow(x0, 2) * 0.0078125, converting to bf16 after each step.
template<class Tx0>
__device__ __attribute__((const)) auto pointwise1(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(0.0078125);
// @literal -> bf16_type, {1}, {0}
auto zz1 = bf16(2);
// @param:x0 -> bf16_type, {1}, {0}
// pow -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::pow(x0, zz1));
// mul -> bf16_type, {1}, {0}
auto zz4 = migraphx::convert<bf16>(zz3 * zz0);
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz4;
return zzreturn;
}
// pointwise2: computes rsqrt(x0 + 9.98378e-07), converting to bf16 after each step.
template<class Tx0>
__device__ __attribute__((const)) auto pointwise2(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(9.9837779998779297e-07);
// @param:x0 -> bf16_type, {1}, {0}
// add -> bf16_type, {1}, {0}
auto zz2 = migraphx::convert<bf16>(x0 + zz0);
// rsqrt -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::rsqrt(zz2));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz3;
return zzreturn;
}
// pointwise3: multiplies two bf16 inputs and converts the product to float.
template<class Tx0, class Tx1>
__device__ __attribute__((const)) auto pointwise3(Tx0 x0,Tx1 x1) {
// @param:x1 -> bf16_type, {1}, {0}
// @param:x0 -> bf16_type, {1}, {0}
// mul -> bf16_type, {1}, {0}
auto zz2 = migraphx::convert<bf16>(x0 * x1);
// convert[target_type=2] -> float_type, {1}, {0}
auto zz3 = migraphx::convert<float>(migraphx::convert<float>(zz2));
// @return -> float_type, {1}, {0}
auto zzreturn = zz3;
return zzreturn;
}
// fused_reduce_op: chains the pointwise stages around a sum reduction using the
// reducer object `r`, and returns the final result wrapped in a one-element tuple.
template<class Tx0, class Tr, class Tout_idx>
__device__ __attribute__((const)) auto fused_reduce_op(Tx0 x0,Tr r,Tout_idx out_idx) {
// out_idx is unused in this kernel; the cast silences the unused-parameter warning.
(void)out_idx;
// @param:x0 -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz1 = r.inner([=](auto x0_lambda_param) { return pointwise0(x0_lambda_param); })(x0);
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz2 = r.lazy_inner([=](auto zz1_lambda_param) { return pointwise1(zz1_lambda_param); })(zz1);
// reduce_sum[axes={3}] -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz3 = op::id{}(r.reduce(op::sum{}, 0, op::id{})(zz2));
// pointwise -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz4 = pointwise2(zz3);
// multibroadcast[out_lens={1, 24, 4608, 128},out_dyn_dims={}] -> bf16_type, {1, 24, 4608, 128}, {110592, 4608, 1, 0}
// The broadcast is emitted as a plain copy of the reduced value.
auto zz5 = zz4;
// pointwise -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// zz5 is captured by the lambda; only zz1 is passed through r.inner.
auto zz6 = r.inner([=](auto zz1_lambda_param) { return pointwise3(zz1_lambda_param, zz5); })(zz1);
// @return -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zzreturn = make_tuple(zz6);
return zzreturn;
}
extern "C" {
// Kernel entry point: forwards the two raw pointers through transform_args
// (make_tensors, vectorize<4, 1>, rotate_and_pack_last<1>) and invokes fused_reduce
// with reduce::block and fused_reduce_op.
MIGRAPHX_GLOBAL void convert_pow_mul_reduce_sum_add_rsqrt_mul_convert_kernel(void * private_p0,void * private_p1)
{
transform_args(make_tensors(), vectorize<4, 1>(), rotate_and_pack_last<1>())(private_p0,private_p1)([](auto y, auto... xs) {
fused_reduce<reduce::block, decltype(make_shape(index_ints<110592, 1>{}, index_ints<1, 1>{}))>(y, assign_none{}, partial(MIGRAPHX_LIFT(fused_reduce_op))(xs...));
});
}
}
} // namespace migraphx
With mul3 (compile fails):
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
// Generated kernel source for the variant WITH the trailing mul3 (reported as failing
// to compile with the LLVM "Wrong operand type!" assertion).
// The "// @param", "// @literal" and "// <op> -> <type>, {lens}, {strides}" comments
// are emitted by the generator and describe the instruction each statement came from.
namespace migraphx {
// pointwise0: converts its input element to bf16 (the convert is applied twice, nested).
template<class Tx0>
__device__ __attribute__((const)) auto pointwise0(Tx0 x0) {
// @param:x0 -> float_type, {1}, {0}
// convert[target_type=15] -> bf16_type, {1}, {0}
auto zz1 = migraphx::convert<bf16>(migraphx::convert<bf16>(x0));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz1;
return zzreturn;
}
// pointwise1: computes pow(x0, 2) * 0.0078125, converting to bf16 after each step.
template<class Tx0>
__device__ __attribute__((const)) auto pointwise1(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(0.0078125);
// @literal -> bf16_type, {1}, {0}
auto zz1 = bf16(2);
// @param:x0 -> bf16_type, {1}, {0}
// pow -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::pow(x0, zz1));
// mul -> bf16_type, {1}, {0}
auto zz4 = migraphx::convert<bf16>(zz3 * zz0);
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz4;
return zzreturn;
}
// pointwise2: computes rsqrt(x0 + 9.98378e-07), converting to bf16 after each step.
template<class Tx0>
__device__ __attribute__((const)) auto pointwise2(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(9.9837779998779297e-07);
// @param:x0 -> bf16_type, {1}, {0}
// add -> bf16_type, {1}, {0}
auto zz2 = migraphx::convert<bf16>(x0 + zz0);
// rsqrt -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::rsqrt(zz2));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz3;
return zzreturn;
}
// pointwise3: converts the float input x0 to bf16, multiplies x1 * x2 in bf16,
// multiplies that product by the converted x0, then converts the result to float.
// Unlike the two-input pointwise3 of the compiling variant, this one takes a float_type input.
template<class Tx0, class Tx1, class Tx2>
__device__ __attribute__((const)) auto pointwise3(Tx0 x0,Tx1 x1,Tx2 x2) {
// @param:x0 -> float_type, {1}, {0}
// convert[target_type=15] -> bf16_type, {1}, {0}
auto zz1 = migraphx::convert<bf16>(migraphx::convert<bf16>(x0));
// @param:x2 -> bf16_type, {1}, {0}
// @param:x1 -> bf16_type, {1}, {0}
// mul -> bf16_type, {1}, {0}
auto zz4 = migraphx::convert<bf16>(x1 * x2);
// mul -> bf16_type, {1}, {0}
auto zz5 = migraphx::convert<bf16>(zz4 * zz1);
// convert[target_type=2] -> float_type, {1}, {0}
auto zz6 = migraphx::convert<float>(migraphx::convert<float>(zz5));
// @return -> float_type, {1}, {0}
auto zzreturn = zz6;
return zzreturn;
}
// fused_reduce_op: chains the pointwise stages around a sum reduction using the
// reducer object `r`, and returns the final result wrapped in a one-element tuple.
template<class Tx0, class Tx1, class Tr, class Tout_idx>
__device__ __attribute__((const)) auto fused_reduce_op(Tx0 x0,Tx1 x1,Tr r,Tout_idx out_idx) {
// out_idx is unused in this kernel; the cast silences the unused-parameter warning.
(void)out_idx;
// @param:x0 -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz1 = r.inner([=](auto x0_lambda_param) { return pointwise0(x0_lambda_param); })(x0);
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz2 = r.lazy_inner([=](auto zz1_lambda_param) { return pointwise1(zz1_lambda_param); })(zz1);
// reduce_sum[axes={3}] -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz3 = op::id{}(r.reduce(op::sum{}, 0, op::id{})(zz2));
// pointwise -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz4 = pointwise2(zz3);
// multibroadcast[out_lens={1, 24, 4608, 128},out_dyn_dims={}] -> bf16_type, {1, 24, 4608, 128}, {110592, 4608, 1, 0}
// The broadcast is emitted as a plain copy of the reduced value.
auto zz5 = zz4;
// @param:x1 -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// pointwise -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// NOTE(review): this is the only r.inner call in either variant that is handed inputs of
// two different element types (x1 is float_type, zz1 is bf16_type per the annotations above).
// Possibly related to the "Wrong operand type!" assertion - unconfirmed.
auto zz7 = r.inner([=](auto x1_lambda_param, auto zz1_lambda_param) { return pointwise3(x1_lambda_param, zz1_lambda_param, zz5); })(x1, zz1);
// @return -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zzreturn = make_tuple(zz7);
return zzreturn;
}
extern "C" {
// Kernel entry point: forwards the three raw pointers through transform_args
// (make_tensors, vectorize<4, 1>, rotate_and_pack_last<1>) and invokes fused_reduce
// with reduce::block and fused_reduce_op.
MIGRAPHX_GLOBAL void convert_pow_mul_reduce_sum_add_rsqrt_convert_mul_mul_convert_kernel(void * private_p0,void * private_p1,void * private_p2)
{
transform_args(make_tensors(), vectorize<4, 1>(), rotate_and_pack_last<1>())(private_p0,private_p1,private_p2)([](auto y, auto... xs) {
fused_reduce<reduce::block, decltype(make_shape(index_ints<110592, 1>{}, index_ints<1, 1>{}))>(y, assign_none{}, partial(MIGRAPHX_LIFT(fused_reduce_op))(xs...));
});
}
}
} // namespace migraphx
What is the compiler error? I don't see any error.
The error is
/long_pathname_so_that_rpms_can_package_the_debug_info/src/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp:1143: void VerifySDNode(llvm::SDNode*, const llvm::TargetLowering*): Assertion `(Op.getValueType() == EltVT || (EltVT.isInteger() && Op.getValueType().isInteger() && EltVT.bitsLE(Op.getValueType()))) && "Wrong operand type!"' failed.
Is this using our 6.3 docker? I ask because I can't reproduce this on develop.