CINN
Here are some fusion examples.
The following example illustrates that the current fused kernels do not support multiple outputs. Two contrasting networks are constructed.
1. The intermediate result is not referenced externally
x0 = builder.create_input(Float(32), [32, 32], "x0")
x1 = builder.create_input(Float(32), [32, 32], "x1")
y0 = builder.elementwise_add(x0, x1, axis=-1)
y1 = builder.relu(y0)
The generated CUDA kernel is:
__global__
void fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel(const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ max_Out)
{
if ((threadIdx.x < 1024)) {
max_Out[threadIdx.x] = cinn_nvgpu_max_fp32((x0[threadIdx.x] + x1[threadIdx.x]), 0);
};
}
Observation: a network composed of add+relu can be fused into a single kernel.
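For reference, a host-side launch for this fused kernel might look like the sketch below. This is my own sketch, not CINN's actual host code: the device buffers d_x0, d_x1, d_max_out and the single-block launch configuration are assumptions inferred from the threadIdx.x < 1024 guard (the 32 x 32 output has exactly 1024 elements).
#include <cuda_runtime.h>

// Forward declaration of the generated kernel shown above.
__global__ void fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel(
    const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ max_Out);

// d_x0, d_x1, d_max_out are assumed to be device buffers of 32 * 32 floats.
void launch_fused_add_relu(const float* d_x0, const float* d_x1, float* d_max_out)
{
  // One block of 1024 threads matches the "threadIdx.x < 1024" guard.
  fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel<<<1, 1024>>>(
      d_x0, d_x1, d_max_out);
  cudaDeviceSynchronize();  // wait for the kernel so max_Out can be read back
}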
2. The intermediate result is referenced externally
In training scenarios, intermediate results computed by the forward subgraph are very likely to be used by the backward pass. This network is constructed mainly to simulate such a training scenario.
x0 = builder.create_input(Float(32), [32, 32], "x0")
x1 = builder.create_input(Float(32), [32, 32], "x1")
y0 = builder.elementwise_add(x0, x1, axis=-1)
y1 = builder.relu(y0)
# Insert an op that cannot be fused.
y2 = builder.reduce_sum(y0, dim=[0, 1])
The generated CUDA kernels are:
__global__
void fn_elementwise_add_0_kernel(const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ elementwise_add_Out)
{
if ((threadIdx.x < 1024)) {
elementwise_add_Out[threadIdx.x] = (x0[threadIdx.x] + x1[threadIdx.x]);
};
}
__global__
void fn_reduce_sum_4_kernel(const float* __restrict__ var_1, float* __restrict__ reduce_sum_out)
{
float* reduce_sum_out__reduce_init = reduce_sum_out;
reduce_sum_out__reduce_init[0] = 0;
for (int32_t kk = 0; kk < 32; kk += 1) {
for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
reduce_sum_out[0] = (reduce_sum_out[0] + var_1[((32 * kk) + kk_0)]);
};
};
}
__global__
void fn_const_scalar_1_broadcast_to_2_max_3_fused_kernel(const float* __restrict__ var_1, float* __restrict__ max_Out)
{
if ((threadIdx.x < 1024)) {
max_Out[threadIdx.x] = cinn_nvgpu_max_fp32(var_1[threadIdx.x], 0);
};
}
Observation: because the result y0 of add must be consumed by the subsequent reduce_sum op, add+relu can no longer be fused.
If fused kernels supported multiple outputs, the generated code should look like this:
__global__
void fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel(const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ elementwise_add_Out, float* __restrict__ max_Out)
{
if ((threadIdx.x < 1024)) {
elementwise_add_Out[threadIdx.x] = (x0[threadIdx.x] + x1[threadIdx.x]);
max_Out[threadIdx.x] = cinn_nvgpu_max_fp32((x0[threadIdx.x] + x1[threadIdx.x]), 0);
};
}
__global__
void fn_reduce_sum_4_kernel(const float* __restrict__ var_1, float* __restrict__ reduce_sum_out)
{
float* reduce_sum_out__reduce_init = reduce_sum_out;
reduce_sum_out__reduce_init[0] = 0;
for (int32_t kk = 0; kk < 32; kk += 1) {
for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
reduce_sum_out[0] = (reduce_sum_out[0] + var_1[((32 * kk) + kk_0)]);
};
};
}
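Note that the proposed fused kernel above computes the sum x0[threadIdx.x] + x1[threadIdx.x] twice, once per output. A minimal hand-written variant (my sketch, not CINN output; the kernel name is hypothetical) would keep the sum in a register and reuse it for both outputs:
__global__
void fused_add_relu_multi_output(const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ elementwise_add_Out, float* __restrict__ max_Out)
{
  if (threadIdx.x < 1024) {
    // Compute the add once, then reuse it for both outputs.
    float sum = x0[threadIdx.x] + x1[threadIdx.x];
    elementwise_add_Out[threadIdx.x] = sum;
    max_Out[threadIdx.x] = cinn_nvgpu_max_fp32(sum, 0);
  }
}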
Fusing two reduce_sum ops with identical configurations
- Program
var_0, var_1 = reduce_sum(x, dim=[1], keep_dim=false)
var_2 = identity(x)
var_3 = elementwise_mul(x, var_2, axis=-1)
var_4, var_5 = reduce_sum(var_3, dim=[1], keep_dim=false)
- Currently generated X86 code:
I1124 04:51:11.166685 29146 graph_compiler.cc:622] [X86] C Code is:
#include <cinn_runtime.h>
#include <stdio.h>
void fn_reduce_sum_0(void* _args, int32_t num_args)
{
const cinn_buffer_t* _x = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
cinn_buffer_t* _reduce_sum_out = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
cinn_buffer_malloc((void*)(0), _reduce_sum_out);
float* reduce_sum_out = ((float*)(_reduce_sum_out->memory));
float* reduce_sum_out__reduce_init = ((float*)(_reduce_sum_out->memory));
const float* x = ((const float*)(_x->memory));
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
reduce_sum_out__reduce_init[((32 * i) + j)] = 0;
for (int32_t kk = 0; kk < 32; kk += 1) {
reduce_sum_out[((32 * i) + j)] = (reduce_sum_out[((32 * i) + j)] + x[((1024 * i) + ((32 * kk) + j))]);
};
};
};
cinn_buffer_free((void*)(0), _reduce_sum_out);
}
void fn_identity_1_elementwise_mul_2_reduce_sum_3_fused(void* _args, int32_t num_args)
{
const cinn_buffer_t* _x = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
cinn_buffer_t* _reduce_sum_out_0 = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
cinn_buffer_malloc((void*)(0), _reduce_sum_out_0);
float* reduce_sum_out_0 = ((float*)(_reduce_sum_out_0->memory));
float* reduce_sum_out_0__reduce_init = ((float*)(_reduce_sum_out_0->memory));
const float* x = ((const float*)(_x->memory));
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
reduce_sum_out_0__reduce_init[((32 * i) + j)] = 0;
for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
reduce_sum_out_0[((32 * i) + j)] = fma(x[((1024 * i) + ((32 * kk_0) + j))], x[((1024 * i) + ((32 * kk_0) + j))], reduce_sum_out_0[((32 * i) + j)]);
};
};
};
cinn_buffer_free((void*)(0), _reduce_sum_out_0);
}
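Both generated functions share an identical loop nest over x, so in principle they could be emitted as a single fused function that reads each element of x only once. A hand-written sketch of such a fused function follows (hypothetical, not CINN output; it takes raw pointers for brevity instead of unpacking cinn_pod_value_t args):
#include <math.h>
#include <stdint.h>

// Hypothetical fused version of fn_reduce_sum_0 and
// fn_identity_1_elementwise_mul_2_reduce_sum_3_fused: both reductions
// share the loop nest, so x is loaded once per element.
void fn_reduce_sum_both_fused(const float* x, float* reduce_sum_out, float* reduce_sum_out_0)
{
  for (int32_t i = 0; i < 32; i += 1) {
    for (int32_t j = 0; j < 32; j += 1) {
      reduce_sum_out[((32 * i) + j)] = 0;
      reduce_sum_out_0[((32 * i) + j)] = 0;
      for (int32_t kk = 0; kk < 32; kk += 1) {
        float v = x[((1024 * i) + ((32 * kk) + j))];
        reduce_sum_out[((32 * i) + j)] = reduce_sum_out[((32 * i) + j)] + v;
        reduce_sum_out_0[((32 * i) + j)] = fma(v, v, reduce_sum_out_0[((32 * i) + j)]);
      }
    }
  }
}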
- Currently generated CUDA code:
I1124 04:58:30.537396 33664 compiler.cc:80] [CUDA] source code:
extern "C" {
#include "cinn_cuda_runtime_source.cuh"
#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif
__global__
void __launch_bounds__(32) fn_reduce_sum_0_kernel(const float* __restrict__ x, float* __restrict__ reduce_sum_out)
{
float* reduce_sum_out__reduce_init = reduce_sum_out;
if (((int)blockIdx.x < 32)) {
if (((int)threadIdx.x < 32)) {
{
reduce_sum_out__reduce_init[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = 0;
for (int32_t kk = 0; kk < 32; kk += 1) {
reduce_sum_out[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = (reduce_sum_out[((32 * (int)blockIdx.x) + (int)threadIdx.x)] + x[((1024 * (int)blockIdx.x) + ((32 * kk) + (int)threadIdx.x))]);
};
}
};
};
}
__global__
void __launch_bounds__(32) fn_identity_1_elementwise_mul_2_reduce_sum_3_fused_kernel(const float* __restrict__ x, float* __restrict__ reduce_sum_out_0)
{
float* reduce_sum_out_0__reduce_init = reduce_sum_out_0;
if (((int)blockIdx.x < 32)) {
if (((int)threadIdx.x < 32)) {
{
reduce_sum_out_0__reduce_init[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = 0;
for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
reduce_sum_out_0[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = (reduce_sum_out_0[((32 * (int)blockIdx.x) + (int)threadIdx.x)] + (x[((1024 * (int)blockIdx.x) + ((32 * kk_0) + (int)threadIdx.x))] * x[((1024 * (int)blockIdx.x) + ((32 * kk_0) + (int)threadIdx.x))]));
};
}
};
};
}
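The two CUDA kernels likewise use the same launch bounds and the same indexing, so a multi-output fused kernel could read x once per element and accumulate both reductions together. A hypothetical sketch (my own, not CINN output; unlike the generated code above, it accumulates in registers rather than in global memory):
__global__
void __launch_bounds__(32) fused_reduce_sum_both_kernel(const float* __restrict__ x, float* __restrict__ reduce_sum_out, float* __restrict__ reduce_sum_out_0)
{
  if ((int)blockIdx.x < 32 && (int)threadIdx.x < 32) {
    int out_idx = 32 * (int)blockIdx.x + (int)threadIdx.x;
    float acc = 0.0f;     // sum of x (the plain reduce_sum path)
    float acc_sq = 0.0f;  // sum of x*x (the identity+elementwise_mul+reduce_sum path)
    for (int32_t kk = 0; kk < 32; kk += 1) {
      float v = x[1024 * (int)blockIdx.x + 32 * kk + (int)threadIdx.x];
      acc += v;
      acc_sq += v * v;
    }
    reduce_sum_out[out_idx] = acc;
    reduce_sum_out_0[out_idx] = acc_sq;
  }
}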