
Give some fusion examples.

Xreki opened this issue • 2 comments

Here is an example constructed to show that the current fused kernels do not support multiple outputs. Two networks are built for comparison.

1. The intermediate result is not referenced externally

x0 = builder.create_input(Float(32), [32, 32], "x0")
x1 = builder.create_input(Float(32), [32, 32], "x1")
y0 = builder.elementwise_add(x0, x1, axis=-1)
y1 = builder.relu(y0)

The generated CUDA kernel is as follows:

__global__
void fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel(const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ max_Out)
{
  if ((threadIdx.x < 1024)) {
    max_Out[threadIdx.x] = cinn_nvgpu_max_fp32((x0[threadIdx.x] + x1[threadIdx.x]), 0);
  };
}

Observation: the network composed of add + relu can be fused into a single kernel.

2. The intermediate result is referenced externally

In training, the intermediate results computed by the forward subgraph are very likely to be needed by the backward pass. This network is constructed mainly to simulate that training scenario.

x0 = builder.create_input(Float(32), [32, 32], "x0")
x1 = builder.create_input(Float(32), [32, 32], "x1")
y0 = builder.elementwise_add(x0, x1, axis=-1)
y1 = builder.relu(y0)
# Insert an op that cannot be fused.
y2 = builder.reduce_sum(y0, dim=[0, 1])

The generated CUDA kernels are as follows:

__global__
void fn_elementwise_add_0_kernel(const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ elementwise_add_Out)
{
  if ((threadIdx.x < 1024)) {
    elementwise_add_Out[threadIdx.x] = (x0[threadIdx.x] + x1[threadIdx.x]);
  };
}
__global__
void fn_reduce_sum_4_kernel(const float* __restrict__ var_1, float* __restrict__ reduce_sum_out)
{
  float* reduce_sum_out__reduce_init = reduce_sum_out;
  reduce_sum_out__reduce_init[0] = 0;
  for (int32_t kk = 0; kk < 32; kk += 1) {
    for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
      reduce_sum_out[0] = (reduce_sum_out[0] + var_1[((32 * kk) + kk_0)]);
    };
  };
}
__global__
void fn_const_scalar_1_broadcast_to_2_max_3_fused_kernel(const float* __restrict__ var_1, float* __restrict__ max_Out)
{
  if ((threadIdx.x < 1024)) {
    max_Out[threadIdx.x] = cinn_nvgpu_max_fp32(var_1[threadIdx.x], 0);
  };
}

Observation: because the add result y0 is consumed by the subsequent reduce_sum op, add + relu cannot be fused.

If fused kernels supported multiple outputs, the generated code should look like this:

__global__
void fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel(const float* __restrict__ x0, const float* __restrict__ x1, float* __restrict__ elementwise_add_Out, float* __restrict__ max_Out)
{
  if ((threadIdx.x < 1024)) {
    elementwise_add_Out[threadIdx.x] = (x0[threadIdx.x] + x1[threadIdx.x]);
    max_Out[threadIdx.x] = cinn_nvgpu_max_fp32((x0[threadIdx.x] + x1[threadIdx.x]), 0);
  };
}
__global__
void fn_reduce_sum_4_kernel(const float* __restrict__ var_1, float* __restrict__ reduce_sum_out)
{
  float* reduce_sum_out__reduce_init = reduce_sum_out;
  reduce_sum_out__reduce_init[0] = 0;
  for (int32_t kk = 0; kk < 32; kk += 1) {
    for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
      reduce_sum_out[0] = (reduce_sum_out[0] + var_1[((32 * kk) + kk_0)]);
    };
  };
}
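
For reference, here is a minimal host-side launch sketch for the proposed scheme. This is hypothetical driver code, not CINN's actual runtime: the function run, the buffer names, and the launch configurations are assumptions inferred from the kernels above. It highlights the only ABI change required: the fused kernel takes one extra output pointer, and fn_reduce_sum_4_kernel then consumes that buffer as its var_1 input.

#include <cuda_runtime.h>

// Declarations of the kernels above; assumes their definitions are compiled
// and linked into this program (illustrative setup only).
__global__ void fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel(
    const float* __restrict__ x0, const float* __restrict__ x1,
    float* __restrict__ elementwise_add_Out, float* __restrict__ max_Out);
__global__ void fn_reduce_sum_4_kernel(const float* __restrict__ var_1,
                                       float* __restrict__ reduce_sum_out);

void run(const float* d_x0, const float* d_x1) {
  float *d_add_out, *d_max_out, *d_sum_out;
  cudaMalloc(&d_add_out, 1024 * sizeof(float));  // y0: [32, 32], now also a fused-kernel output
  cudaMalloc(&d_max_out, 1024 * sizeof(float));  // y1: relu result
  cudaMalloc(&d_sum_out, sizeof(float));         // y2: scalar reduce_sum result
  // One block of 1024 threads covers the whole [32, 32] tensor.
  fn_const_scalar_1_elementwise_add_0_broadcast_to_2_max_3_fused_kernel<<<1, 1024>>>(
      d_x0, d_x1, d_add_out, d_max_out);
  // The generated reduce_sum kernel is serial, so a single thread suffices.
  fn_reduce_sum_4_kernel<<<1, 1>>>(d_add_out, d_sum_out);
  cudaDeviceSynchronize();  // results are left in d_max_out and d_sum_out
}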

Xreki, Nov 08 '21

Thanks for your contribution!

paddle-bot-old[bot], Nov 08 '21

Fusing two reduce_sum ops with the same configuration

  • Program
var_0, var_1 = reduce_sum(x, dim=[1], keep_dim=false)
var_2 = identity(x)
var_3 = elementwise_mul(x, var_2, axis=-1)
var_4, var_5 = reduce_sum(var_3, dim=[1], keep_dim=false)
  • The currently generated X86 code is as follows:
I1124 04:51:11.166685 29146 graph_compiler.cc:622] [X86] C Code is:
#include <cinn_runtime.h>
#include <stdio.h>

void fn_reduce_sum_0(void* _args, int32_t num_args)
{
  const cinn_buffer_t* _x = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
  cinn_buffer_t* _reduce_sum_out = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
  cinn_buffer_malloc((void*)(0), _reduce_sum_out);
  float* reduce_sum_out = ((float*)(_reduce_sum_out->memory));
  float* reduce_sum_out__reduce_init = ((float*)(_reduce_sum_out->memory));
  const float* x = ((const float*)(_x->memory));
  for (int32_t i = 0; i < 32; i += 1) {
    for (int32_t j = 0; j < 32; j += 1) {
      reduce_sum_out__reduce_init[((32 * i) + j)] = 0;
      for (int32_t kk = 0; kk < 32; kk += 1) {
        reduce_sum_out[((32 * i) + j)] = (reduce_sum_out[((32 * i) + j)] + x[((1024 * i) + ((32 * kk) + j))]);
      };
    };
  };
  cinn_buffer_free((void*)(0), _reduce_sum_out);
}

void fn_identity_1_elementwise_mul_2_reduce_sum_3_fused(void* _args, int32_t num_args)
{
  const cinn_buffer_t* _x = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
  cinn_buffer_t* _reduce_sum_out_0 = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
  cinn_buffer_malloc((void*)(0), _reduce_sum_out_0);
  float* reduce_sum_out_0 = ((float*)(_reduce_sum_out_0->memory));
  float* reduce_sum_out_0__reduce_init = ((float*)(_reduce_sum_out_0->memory));
  const float* x = ((const float*)(_x->memory));
  for (int32_t i = 0; i < 32; i += 1) {
    for (int32_t j = 0; j < 32; j += 1) {
      reduce_sum_out_0__reduce_init[((32 * i) + j)] = 0;
      for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
        reduce_sum_out_0[((32 * i) + j)] = fma(x[((1024 * i) + ((32 * kk_0) + j))], x[((1024 * i) + ((32 * kk_0) + j))], reduce_sum_out_0[((32 * i) + j)]);
      };
    };
  };
  cinn_buffer_free((void*)(0), _reduce_sum_out_0);
}
  • The currently generated CUDA code is as follows (a hand-fused sketch is given after the listing):
I1124 04:58:30.537396 33664 compiler.cc:80] [CUDA] source code:
extern "C" {

#include "cinn_cuda_runtime_source.cuh"

#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif



__global__
void __launch_bounds__(32) fn_reduce_sum_0_kernel(const float* __restrict__ x, float* __restrict__ reduce_sum_out)
{
  float* reduce_sum_out__reduce_init = reduce_sum_out;
  if (((int)blockIdx.x < 32)) {
    if (((int)threadIdx.x < 32)) {
    {
      reduce_sum_out__reduce_init[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = 0;
      for (int32_t kk = 0; kk < 32; kk += 1) {
        reduce_sum_out[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = (reduce_sum_out[((32 * (int)blockIdx.x) + (int)threadIdx.x)] + x[((1024 * (int)blockIdx.x) + ((32 * kk) + (int)threadIdx.x))]);
      };
    }
    };
  };
}
__global__
void __launch_bounds__(32) fn_identity_1_elementwise_mul_2_reduce_sum_3_fused_kernel(const float* __restrict__ x, float* __restrict__ reduce_sum_out_0)
{
  float* reduce_sum_out_0__reduce_init = reduce_sum_out_0;
  if (((int)blockIdx.x < 32)) {
    if (((int)threadIdx.x < 32)) {
    {
      reduce_sum_out_0__reduce_init[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = 0;
      for (int32_t kk_0 = 0; kk_0 < 32; kk_0 += 1) {
        reduce_sum_out_0[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = (reduce_sum_out_0[((32 * (int)blockIdx.x) + (int)threadIdx.x)] + (x[((1024 * (int)blockIdx.x) + ((32 * kk_0) + (int)threadIdx.x))] * x[((1024 * (int)blockIdx.x) + ((32 * kk_0) + (int)threadIdx.x))]));
      };
    }
    };
  };
}
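
Since the two reduce_sum kernels share the same reduction axis and the same launch configuration, they are natural candidates for horizontal fusion into a single kernel that traverses x only once. Below is a hand-written sketch of what such a fused kernel could look like; it is hypothetical, not current CINN output, and it also folds identity + elementwise_mul(x, x) down to x * x:

__global__
void __launch_bounds__(32) fused_reduce_sum_kernel(const float* __restrict__ x, float* __restrict__ reduce_sum_out, float* __restrict__ reduce_sum_out_0)
{
  if (((int)blockIdx.x < 32) && ((int)threadIdx.x < 32)) {
    float sum = 0.0f;     // accumulator for reduce_sum(x, dim=[1])
    float sum_sq = 0.0f;  // accumulator for reduce_sum(x * x, dim=[1])
    for (int kk = 0; kk < 32; kk += 1) {
      float v = x[((1024 * (int)blockIdx.x) + ((32 * kk) + (int)threadIdx.x))];
      sum = sum + v;
      sum_sq = sum_sq + (v * v);
    }
    reduce_sum_out[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = sum;
    reduce_sum_out_0[((32 * (int)blockIdx.x) + (int)threadIdx.x)] = sum_sq;
  }
}

Compared with the two kernels above, this version reads x once instead of twice and writes both outputs in one pass.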

Xreki, Nov 24 '21