boda icon indicating copy to clipboard operation
boda copied to clipboard

OpenCL SGEMM Tuning for Qualcomm Snapdragon 820 / Adreno 530

Open moskewcz opened this issue 8 years ago • 13 comments

Currently, I'm limited to ~30% efficiency, and it's not clear why, or what I might do to improve from there. Peak compute was roughly determined from microbenchmarks as ~256 GF/s. Local memory seems hard to use -- when using local memory, I can get close to the variant that doesn't use local memory, but I can't even match it. It also seems like the 'no local memory' variant may actually be using local memory ("Scratch") to emulate extra registers (?), but it's not clear.

Timing is done via OpenCL events. The top-level driver for this testing is here: https://github.com/moskewcz/boda/blob/master/src/cnn-prof.cc The low-level OpenCL backend (which performs the timing) is here: https://github.com/moskewcz/boda/blob/master/src/ocl_util.cc There are other layers in between for codegen and such ...

See doc/sgemm-notes.txt for the general history and state of affairs. For reference, here are both:

  • my current fastest (~75GF/s) variant for 512x512, no local mem
  • with-local-mem version (~60GF/s)
// Generated OpenCL C preamble: fixed-width integer names used by the generated
// code are typedef'd here directly. U32_MAX is not referenced by the kernel below.
typedef unsigned uint32_t;
__constant uint32_t const U32_MAX = 0xffffffff;
typedef int int32_t;
// GNU-style line marker: diagnostics for the following lines are reported
// against "out_0.cl", starting at line 72.
#72 "out_0.cl"
// SGEMM for fixed K = M = N = 512, no local memory, float8 vector accesses.
//
// For every m, n in [0, 512) this computes
//   c[m*512 + n] = sum_{k<512} a[k*512 + m] * b[k*512 + n]
// i.e. a is indexed with k as the row (a "transposed" A), while b and c are
// row-major with row length 512. c is overwritten (plain store), not accumulated.
//
// Work decomposition (1-D NDRange):
//   - work-group g handles block-row g/4 and block-col g%4 of c;
//   - work-item l handles the 8x8 tile at tile-row l/16, tile-col l%16 of that block.
// NOTE(review): the /16 and %16 indexing assumes a local size of 16*16 = 256,
// which makes each group cover a 128x128 block and requires 4*4 = 16 groups
// (global size 4096) -- confirm against the host launch code.
// NOTE(review): the float8 pointer casts assume a, b and c are aligned for
// float8 access; confirm if offset pointers / sub-buffers are ever passed in.
kernel void sgemm_simd__K_512__M_512__N_512__Mg_4__Ng_4__Mb_16__Nb_16__Kb_1__Mt_8__Nt_8__prof_variant_0__use_local_mem_2__vw_8( global float const * const a,
       global float const * const b,
       global float * const c )

{
  // Per-work-item 8x8 accumulator tile: c_r[i*8 + j] is row i, column j of
  // this work-item's tile of c. a_r / b_r each hold one float8.
  float c_r[8*8] = {0};
  float8 a_r[8/8];
  float8 b_r[8/8];

  // float8 index of this work-item's first column in row k = 0 of a (resp. b).
  // The trailing "*8/8*1" is generated arithmetic: tile size 8 divided by
  // vector width 8.
  int const a_off_thr = ( (get_group_id(0)/4)*16 + (get_local_id(0)/16) )*8/8*1;
  int const b_off_thr = ( (get_group_id(0)%4)*16 + (get_local_id(0)%16) )*8/8*1;

  int32_t a_off = a_off_thr;
  int32_t b_off = b_off_thr;
  // One rank-1 update per k: load 8 values of row k of a and of b, then add
  // their 8x8 outer product (fully unrolled below) into c_r.
  for( int32_t k = 0; k < 512; k += 1 ) {

   a_r[0] = ((global float8 const *)a)[a_off+0];
   b_r[0] = ((global float8 const *)b)[b_off+0];
   c_r[0] += a_r[0].s0*b_r[0].s0;
   c_r[1] += a_r[0].s0*b_r[0].s1;
   c_r[2] += a_r[0].s0*b_r[0].s2;
   c_r[3] += a_r[0].s0*b_r[0].s3;
   c_r[4] += a_r[0].s0*b_r[0].s4;
   c_r[5] += a_r[0].s0*b_r[0].s5;
   c_r[6] += a_r[0].s0*b_r[0].s6;
   c_r[7] += a_r[0].s0*b_r[0].s7;
   c_r[8] += a_r[0].s1*b_r[0].s0;
   c_r[9] += a_r[0].s1*b_r[0].s1;
   c_r[10] += a_r[0].s1*b_r[0].s2;
   c_r[11] += a_r[0].s1*b_r[0].s3;
   c_r[12] += a_r[0].s1*b_r[0].s4;
   c_r[13] += a_r[0].s1*b_r[0].s5;
   c_r[14] += a_r[0].s1*b_r[0].s6;
   c_r[15] += a_r[0].s1*b_r[0].s7;
   c_r[16] += a_r[0].s2*b_r[0].s0;
   c_r[17] += a_r[0].s2*b_r[0].s1;
   c_r[18] += a_r[0].s2*b_r[0].s2;
   c_r[19] += a_r[0].s2*b_r[0].s3;
   c_r[20] += a_r[0].s2*b_r[0].s4;
   c_r[21] += a_r[0].s2*b_r[0].s5;
   c_r[22] += a_r[0].s2*b_r[0].s6;
   c_r[23] += a_r[0].s2*b_r[0].s7;
   c_r[24] += a_r[0].s3*b_r[0].s0;
   c_r[25] += a_r[0].s3*b_r[0].s1;
   c_r[26] += a_r[0].s3*b_r[0].s2;
   c_r[27] += a_r[0].s3*b_r[0].s3;
   c_r[28] += a_r[0].s3*b_r[0].s4;
   c_r[29] += a_r[0].s3*b_r[0].s5;
   c_r[30] += a_r[0].s3*b_r[0].s6;
   c_r[31] += a_r[0].s3*b_r[0].s7;
   c_r[32] += a_r[0].s4*b_r[0].s0;
   c_r[33] += a_r[0].s4*b_r[0].s1;
   c_r[34] += a_r[0].s4*b_r[0].s2;
   c_r[35] += a_r[0].s4*b_r[0].s3;
   c_r[36] += a_r[0].s4*b_r[0].s4;
   c_r[37] += a_r[0].s4*b_r[0].s5;
   c_r[38] += a_r[0].s4*b_r[0].s6;
   c_r[39] += a_r[0].s4*b_r[0].s7;
   c_r[40] += a_r[0].s5*b_r[0].s0;
   c_r[41] += a_r[0].s5*b_r[0].s1;
   c_r[42] += a_r[0].s5*b_r[0].s2;
   c_r[43] += a_r[0].s5*b_r[0].s3;
   c_r[44] += a_r[0].s5*b_r[0].s4;
   c_r[45] += a_r[0].s5*b_r[0].s5;
   c_r[46] += a_r[0].s5*b_r[0].s6;
   c_r[47] += a_r[0].s5*b_r[0].s7;
   c_r[48] += a_r[0].s6*b_r[0].s0;
   c_r[49] += a_r[0].s6*b_r[0].s1;
   c_r[50] += a_r[0].s6*b_r[0].s2;
   c_r[51] += a_r[0].s6*b_r[0].s3;
   c_r[52] += a_r[0].s6*b_r[0].s4;
   c_r[53] += a_r[0].s6*b_r[0].s5;
   c_r[54] += a_r[0].s6*b_r[0].s6;
   c_r[55] += a_r[0].s6*b_r[0].s7;
   c_r[56] += a_r[0].s7*b_r[0].s0;
   c_r[57] += a_r[0].s7*b_r[0].s1;
   c_r[58] += a_r[0].s7*b_r[0].s2;
   c_r[59] += a_r[0].s7*b_r[0].s3;
   c_r[60] += a_r[0].s7*b_r[0].s4;
   c_r[61] += a_r[0].s7*b_r[0].s5;
   c_r[62] += a_r[0].s7*b_r[0].s6;
   c_r[63] += a_r[0].s7*b_r[0].s7;

    // Advance to the next k: one row = 512 floats = 64 float8.
    a_off += 1*512/8;
    b_off += 1*512/8;
  }

  // float8 index in c of this work-item's first output row:
  // (first row index) * 512 / 8 + (first column index) / 8.
  int32_t c_off =
    ((get_group_id(0)/4)*16+(get_local_id(0)/16))*8*512/8 +
    ((get_group_id(0)%4)*16+(get_local_id(0)%16))*8*1/8;

  // Store the tile one row (one float8) at a time, reusing b_r[0] as the
  // staging register. The switch keeps every c_r index a compile-time
  // constant -- presumably so c_r can stay in registers rather than being
  // indexed dynamically (unverified).
  for( int32_t Mt = 0; Mt < 8; ++Mt ) {

   switch(Mt) {
   case 0:
   b_r[0].s0 = c_r[0];
   b_r[0].s1 = c_r[1];
   b_r[0].s2 = c_r[2];
   b_r[0].s3 = c_r[3];
   b_r[0].s4 = c_r[4];
   b_r[0].s5 = c_r[5];
   b_r[0].s6 = c_r[6];
   b_r[0].s7 = c_r[7];
   break;
   case 1:
   b_r[0].s0 = c_r[8];
   b_r[0].s1 = c_r[9];
   b_r[0].s2 = c_r[10];
   b_r[0].s3 = c_r[11];
   b_r[0].s4 = c_r[12];
   b_r[0].s5 = c_r[13];
   b_r[0].s6 = c_r[14];
   b_r[0].s7 = c_r[15];
   break;
   case 2:
   b_r[0].s0 = c_r[16];
   b_r[0].s1 = c_r[17];
   b_r[0].s2 = c_r[18];
   b_r[0].s3 = c_r[19];
   b_r[0].s4 = c_r[20];
   b_r[0].s5 = c_r[21];
   b_r[0].s6 = c_r[22];
   b_r[0].s7 = c_r[23];
   break;
   case 3:
   b_r[0].s0 = c_r[24];
   b_r[0].s1 = c_r[25];
   b_r[0].s2 = c_r[26];
   b_r[0].s3 = c_r[27];
   b_r[0].s4 = c_r[28];
   b_r[0].s5 = c_r[29];
   b_r[0].s6 = c_r[30];
   b_r[0].s7 = c_r[31];
   break;
   case 4:
   b_r[0].s0 = c_r[32];
   b_r[0].s1 = c_r[33];
   b_r[0].s2 = c_r[34];
   b_r[0].s3 = c_r[35];
   b_r[0].s4 = c_r[36];
   b_r[0].s5 = c_r[37];
   b_r[0].s6 = c_r[38];
   b_r[0].s7 = c_r[39];
   break;
   case 5:
   b_r[0].s0 = c_r[40];
   b_r[0].s1 = c_r[41];
   b_r[0].s2 = c_r[42];
   b_r[0].s3 = c_r[43];
   b_r[0].s4 = c_r[44];
   b_r[0].s5 = c_r[45];
   b_r[0].s6 = c_r[46];
   b_r[0].s7 = c_r[47];
   break;
   case 6:
   b_r[0].s0 = c_r[48];
   b_r[0].s1 = c_r[49];
   b_r[0].s2 = c_r[50];
   b_r[0].s3 = c_r[51];
   b_r[0].s4 = c_r[52];
   b_r[0].s5 = c_r[53];
   b_r[0].s6 = c_r[54];
   b_r[0].s7 = c_r[55];
   break;
   case 7:
   b_r[0].s0 = c_r[56];
   b_r[0].s1 = c_r[57];
   b_r[0].s2 = c_r[58];
   b_r[0].s3 = c_r[59];
   b_r[0].s4 = c_r[60];
   b_r[0].s5 = c_r[61];
   b_r[0].s6 = c_r[62];
   b_r[0].s7 = c_r[63];
   break;
   }


   ((global float8 *)c)[c_off+0] = b_r[0];

    // Next output row: 512 floats = 64 float8.
    c_off += 512/8;
  }

}

and similarly, the with-local-mem version (~60GF/s):

// Generated OpenCL C preamble: fixed-width integer names used by the generated
// code are typedef'd here directly. U32_MAX is not referenced by the kernel below.
typedef unsigned uint32_t;
__constant uint32_t const U32_MAX = 0xffffffff;
typedef int int32_t;
// GNU-style line marker: diagnostics for the following lines are reported
// against "out_0.cl", starting at line 74.
#74 "out_0.cl"
// SGEMM for fixed K = M = N = 512, staging operands through local memory,
// float4 vector accesses.
//
// For every m, n in [0, 512) this computes
//   c[m*512 + n] = sum_{k<512} a[k*512 + m] * b[k*512 + n]
// i.e. a is indexed with k as the row (a "transposed" A), while b and c are
// row-major with row length 512. c is overwritten (plain store), not accumulated.
//
// Work decomposition (1-D NDRange):
//   - work-group g handles block-row g/8 and block-col g%8 of c;
//   - work-item l handles the 4x4 tile at tile-row l/16, tile-col l%16 of that block.
// Each k-loop iteration stages a 4-row slab (Kb = 4) of the group's 64 columns
// of a and of b into local memory, synchronizes, and then consumes it.
// NOTE(review): the /16 and %16 indexing assumes a local size of 16*16 = 256,
// which makes each group cover a 64x64 block and requires 8*8 = 64 groups
// (global size 16384) -- confirm against the host launch code.
kernel void sgemm_simd_local__K_512__M_512__N_512__Mg_8__Ng_8__Mb_16__Nb_16__Kb_4__Mt_4__Nt_4__prof_variant_0__use_local_mem_3__vw_4( global float const * const a,
       global float const * const b,
       global float * const c )

{

  // Staged slabs: a_sm[r*16 + j] holds float4 column j (0..15) of slab row r
  // (0..3) for this group's columns of a; b_sm likewise for b.
  local float4 a_sm[64];
  local float4 b_sm[64];

  // Per-work-item 4x4 accumulator tile: c_r[i*4 + j] is row i, column j.
  float c_r[4*4] = {0};
  float4 a_r[4/4];
  float4 b_r[4/4];

  // float4 offset used by this work-item's global load: the group's first
  // column in row k = 0 ((block index) * 16 float4) plus the local id.
  int32_t a_off = (get_group_id(0)/8)*16*4*1/4 + get_local_id(0);
  int32_t b_off = (get_group_id(0)%8)*16*4*1/4 + get_local_id(0);

  // This work-item's column within the staged slabs; slab row r is read at
  // offset 16*r from here.
  local float4 * const a_sm_off = a_sm + (get_local_id(0)/16)*4/4;
  local float4 * const b_sm_off = b_sm + (get_local_id(0)%16)*4/4;

  for( int32_t k = 0; k < 512; k += 4 ) {

   // Work-items 0..63 each copy one float4 (slab row lid/16, column lid%16).
   // a_off already contains +lid = (lid/16)*16 + lid%16, so the extra
   // (lid/16)*112 makes the per-row step 16 + 112 = 128 float4 = 512 floats.
   if( (get_local_id(0)+0) < 64 ) {
   a_sm[get_local_id(0)+0] = ((global float4 const *)(a))[a_off+0+(get_local_id(0)+0)/16*112];}
   if( (get_local_id(0)+0) < 64 ) {
   b_sm[get_local_id(0)+0] = ((global float4 const *)(b))[b_off+0+(get_local_id(0)+0)/16*112];}

    // Make the staged slab visible to every work-item before it is read.
    barrier(CLK_LOCAL_MEM_FENCE);

   // Four rank-1 updates, one per slab row (local offsets 0, 16, 32, 48).
   a_r[0] = a_sm_off[0];
   b_r[0] = b_sm_off[0];
   c_r[0] += a_r[0].s0*b_r[0].s0;
   c_r[1] += a_r[0].s0*b_r[0].s1;
   c_r[2] += a_r[0].s0*b_r[0].s2;
   c_r[3] += a_r[0].s0*b_r[0].s3;
   c_r[4] += a_r[0].s1*b_r[0].s0;
   c_r[5] += a_r[0].s1*b_r[0].s1;
   c_r[6] += a_r[0].s1*b_r[0].s2;
   c_r[7] += a_r[0].s1*b_r[0].s3;
   c_r[8] += a_r[0].s2*b_r[0].s0;
   c_r[9] += a_r[0].s2*b_r[0].s1;
   c_r[10] += a_r[0].s2*b_r[0].s2;
   c_r[11] += a_r[0].s2*b_r[0].s3;
   c_r[12] += a_r[0].s3*b_r[0].s0;
   c_r[13] += a_r[0].s3*b_r[0].s1;
   c_r[14] += a_r[0].s3*b_r[0].s2;
   c_r[15] += a_r[0].s3*b_r[0].s3;
   a_r[0] = a_sm_off[16];
   b_r[0] = b_sm_off[16];
   c_r[0] += a_r[0].s0*b_r[0].s0;
   c_r[1] += a_r[0].s0*b_r[0].s1;
   c_r[2] += a_r[0].s0*b_r[0].s2;
   c_r[3] += a_r[0].s0*b_r[0].s3;
   c_r[4] += a_r[0].s1*b_r[0].s0;
   c_r[5] += a_r[0].s1*b_r[0].s1;
   c_r[6] += a_r[0].s1*b_r[0].s2;
   c_r[7] += a_r[0].s1*b_r[0].s3;
   c_r[8] += a_r[0].s2*b_r[0].s0;
   c_r[9] += a_r[0].s2*b_r[0].s1;
   c_r[10] += a_r[0].s2*b_r[0].s2;
   c_r[11] += a_r[0].s2*b_r[0].s3;
   c_r[12] += a_r[0].s3*b_r[0].s0;
   c_r[13] += a_r[0].s3*b_r[0].s1;
   c_r[14] += a_r[0].s3*b_r[0].s2;
   c_r[15] += a_r[0].s3*b_r[0].s3;
   a_r[0] = a_sm_off[32];
   b_r[0] = b_sm_off[32];
   c_r[0] += a_r[0].s0*b_r[0].s0;
   c_r[1] += a_r[0].s0*b_r[0].s1;
   c_r[2] += a_r[0].s0*b_r[0].s2;
   c_r[3] += a_r[0].s0*b_r[0].s3;
   c_r[4] += a_r[0].s1*b_r[0].s0;
   c_r[5] += a_r[0].s1*b_r[0].s1;
   c_r[6] += a_r[0].s1*b_r[0].s2;
   c_r[7] += a_r[0].s1*b_r[0].s3;
   c_r[8] += a_r[0].s2*b_r[0].s0;
   c_r[9] += a_r[0].s2*b_r[0].s1;
   c_r[10] += a_r[0].s2*b_r[0].s2;
   c_r[11] += a_r[0].s2*b_r[0].s3;
   c_r[12] += a_r[0].s3*b_r[0].s0;
   c_r[13] += a_r[0].s3*b_r[0].s1;
   c_r[14] += a_r[0].s3*b_r[0].s2;
   c_r[15] += a_r[0].s3*b_r[0].s3;
   a_r[0] = a_sm_off[48];
   b_r[0] = b_sm_off[48];
   c_r[0] += a_r[0].s0*b_r[0].s0;
   c_r[1] += a_r[0].s0*b_r[0].s1;
   c_r[2] += a_r[0].s0*b_r[0].s2;
   c_r[3] += a_r[0].s0*b_r[0].s3;
   c_r[4] += a_r[0].s1*b_r[0].s0;
   c_r[5] += a_r[0].s1*b_r[0].s1;
   c_r[6] += a_r[0].s1*b_r[0].s2;
   c_r[7] += a_r[0].s1*b_r[0].s3;
   c_r[8] += a_r[0].s2*b_r[0].s0;
   c_r[9] += a_r[0].s2*b_r[0].s1;
   c_r[10] += a_r[0].s2*b_r[0].s2;
   c_r[11] += a_r[0].s2*b_r[0].s3;
   c_r[12] += a_r[0].s3*b_r[0].s0;
   c_r[13] += a_r[0].s3*b_r[0].s1;
   c_r[14] += a_r[0].s3*b_r[0].s2;
   c_r[15] += a_r[0].s3*b_r[0].s3;

    // Advance 4 rows of a / b: 4*512 floats = 512 float4.
    a_off += 4*512/4;
    b_off += 4*512/4;
    // Keep the next iteration's loads from overwriting a_sm / b_sm while other
    // work-items are still reading the current slab.
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // float4 index in c of this work-item's first output row:
  // (first row index) * 512 / 4 + (first column index) / 4.
  int32_t c_off =
    ((get_group_id(0)/8)*16+(get_local_id(0)/16))*4*512/4 +
    ((get_group_id(0)%8)*16+(get_local_id(0)%16))*4*1/4;

  // Store the tile one row (one float4) at a time, reusing b_r[0] as the
  // staging register; the switch keeps every c_r index a compile-time constant.
  for( int32_t Mt = 0; Mt < 4; ++Mt ) {

   switch(Mt) {
   case 0:
   b_r[0].s0 = c_r[0];
   b_r[0].s1 = c_r[1];
   b_r[0].s2 = c_r[2];
   b_r[0].s3 = c_r[3];
   break;
   case 1:
   b_r[0].s0 = c_r[4];
   b_r[0].s1 = c_r[5];
   b_r[0].s2 = c_r[6];
   b_r[0].s3 = c_r[7];
   break;
   case 2:
   b_r[0].s0 = c_r[8];
   b_r[0].s1 = c_r[9];
   b_r[0].s2 = c_r[10];
   b_r[0].s3 = c_r[11];
   break;
   case 3:
   b_r[0].s0 = c_r[12];
   b_r[0].s1 = c_r[13];
   b_r[0].s2 = c_r[14];
   b_r[0].s3 = c_r[15];
   break;
   }


   ((global float4 *)c)[c_off+0] = b_r[0];

    // Next output row: 512 floats = 128 float4.
    c_off += 512/4;
  }

}

moskewcz avatar May 05 '16 21:05 moskewcz

Hi,

The 'scratch' you mentioned refers to the "stack memory space", which is allocated in the "Private Memory".

As can be seen in your kernel source code, each work item has

float c_r[8*8] = {0}; float8 a_r[8/8]; float8 b_r[8/8];

They are "private variables", and the compiler will generate a bunch of load/store instructions on the private memory, which I believe is the performance bottleneck.

I guess you can somehow shrink "float c_r[8*8]" down to "float c_r[8]", then the compiler could optimize and move it into the Local Memory for all work items, which should be faster than the Private Memory access.

Or, you can directly use "local xxx" to declare it as the local variables.

Please keep in mind that there is a hardware limit on the total size of the Local Memory, and it is normally not big. The OpenCL spec's minimum requirement is 32 KB, I believe.

In a word, your code uses lots of private variables requesting lots of memory access on the private memory. That is the performance bottleneck.

ncybmh99 avatar May 24 '16 19:05 ncybmh99

Hello, I've recently been working on hgemm (fp16) on the Adreno 530 too. I found:

  1. It's faster when we don't use local mem compared with using local mem, which supports your experiments.
  2. Ways of utilizing L1 cache reads, such as using images as input, DO improve performance. I achieved 80 GFLOPS for m=n=k=1024 using fp16. You can try this in your program, too.

I'll be working on your fp32 75 GFLOPS kernel, and will try to use fp16 to reach about 120 GFLOPS.

Ps:

  1. I tested the fp32 75 GFLOPS kernel and measured about 80 GFLOPS. But when I simply replace float with half to get an efficient fp16 kernel, performance doesn't seem to improve (82 GFLOPS). However, in terms of throughput, fp16 is double fp32, so we should expect an fp16 kernel to be 1.5~1.8x faster than the fp32 kernel. It's very confusing.

Thu-Chris avatar Feb 24 '18 01:02 Thu-Chris

sounds interesting. often times it's not simple/obvious how to actually get fp16 to give perf. gains -- i'm interested to know what you learn/figure out. i did some work on supporting half precision, but i don't think i went too far with it; i think for my initial experiments i only supported load/store (not compute) being 16-bit, since that's all that opencl supports without extensions -- that is, without exts, you can't actually do 16-bit float operations (mul/add/etc). you can see some of the primitive support in the sgemm code generators, where they use vstore_half() and vload_half() when the input/output data types are 16-bit. it did compute correct s/hgemm results from what i remember, i forget if the perf was much different from normal 32-bit mode -- if anything you might expect it to be slower if the load/store half operations add overhead. i suppose one might also expect some possible speedup due to less memory access as well. https://github.com/moskewcz/boda/blob/master/src/cnn_codegen.cc#L457

moskewcz avatar Mar 02 '18 20:03 moskewcz

@Thu-Chris "I simply replace float with half to get an efficient fp16 kernel" I also made some experiments, at first I repeat result in ~3.6ms in FP32. Then I replace 1D grid with 2D(just for convenience). To get good result in fp16 I replace all float to half and use local size 32x32 instead 16x16(256 in 1D case). And then I get ~2.5ms against ~3.6. But I don't like to make accumulators (c_r) fp16. And I don't know is it possible to preserve c_r in fp32 for HGEMM and keep performance.

roserg avatar Mar 03 '18 10:03 roserg

@moskewcz @roserg What were the optimal global & local sizes you used to reproduce the results for FP32? I would like to check the performance on Adreno 330. I'm able to reproduce the results on a 1080 Ti using global_size[]={4096} & local_size[]={256}.

I would like to hear your opinion about modifying this kernel for rectangular matrices (something like m-16, n-200704, k-27). Has anyone tried such sizes? Any pointer on how to do that will be helpful.

sivagnanamn avatar Mar 09 '18 05:03 sivagnanamn

@moskewcz Thanks for replying. I have achieved about 15ms in fp16 gemm(m=n=k=1024)with the kernel below. It's about 143GFlops and modified from your fp32 kernel.

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// HGEMM (fp16) for fixed K = M = N = 1024, no local memory, half8 accesses.
// The pragma above enables cl_khr_fp16 so half arithmetic is allowed.
//
// For every m, n in [0, 1024) this computes
//   c[m*1024 + n] = sum_{k<1024} a[k*1024 + m] * b[k*1024 + n]
// i.e. a is indexed with k as the row (a "transposed" A), while b and c are
// row-major with row length 1024. c is overwritten, not accumulated into.
//
// Work-item gid handles the 8x8 tile of c at tile-row gid/128, tile-col gid%128,
// so 128*128 = 16384 work-items cover c (the reported launch is
// global = {16384}, local = {256}).
// NOTE(review): products are accumulated in half precision across all 1024
// k terms; rounding error grows with K -- confirm this is acceptable.
// NOTE(review): the half8 pointer casts assume a and b are aligned for half8
// (16-byte) access.
__kernel void gemm( global const half* a,
       global const half* b,
       global half* c )
{
  // c_r[i] is row i of this work-item's 8x8 tile (one half8 per row).
  half8 c_r[8] = {0,0,0,0,0,0,0,0};
  half8 a_r;
  half8 b_r;

  // half8 index of this work-item's first column in row k = 0 of a (resp. b).
  int const a_off_thr = get_global_id(0)/128;
  int const b_off_thr = get_global_id(0)%128;

  int a_off = a_off_thr;
  int b_off = b_off_thr;
  for( int k = 0; k < 1024; k += 1 ) {
    a_r = ((global const half8*)a)[a_off+0];
    b_r = ((global const half8*)b)[b_off+0];
    // Rank-1 update: tile row i += a_r.s<i> * b_r (scalar times vector).
    c_r[0] += a_r.s0*b_r;
    c_r[1] += a_r.s1*b_r;
    c_r[2] += a_r.s2*b_r;
    c_r[3] += a_r.s3*b_r;
    c_r[4] += a_r.s4*b_r;
    c_r[5] += a_r.s5*b_r;
    c_r[6] += a_r.s6*b_r;
    c_r[7] += a_r.s7*b_r;
    // Next k: one row = 1024 halves = 128 half8.
    a_off += 128;
    b_off += 128;
  }
  // Element offset in c of the tile's first row: row (gid/128)*8, column (gid%128)*8.
  int c_off = (get_global_id(0)/128)*1024*8 + (get_global_id(0)%128)*8;
  // Write the 8 tile rows; consecutive rows are 1024 halves apart.
  vstore8(c_r[0], 0, c+c_off);
  vstore8(c_r[1], 0, c+c_off+1024);
  vstore8(c_r[2], 0, c+c_off+1024*2);
  vstore8(c_r[3], 0, c+c_off+1024*3);
  vstore8(c_r[4], 0, c+c_off+1024*4);
  vstore8(c_r[5], 0, c+c_off+1024*5);
  vstore8(c_r[6], 0, c+c_off+1024*6);
  vstore8(c_r[7], 0, c+c_off+1024*7);
}

global = {16384} (128*128), local = {256} (16*16)

It's the best performance I ever got in Adreno 530. But it's very confusing when I simply changes some lines, the performance reduces much. Like below:

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// HGEMM (fp16) for fixed K = M = N = 1024, no local memory; identical math to
// the half8-pointer-cast form, but loads a and b with vload8.
// The pragma above enables cl_khr_fp16 so half arithmetic is allowed.
//
// For every m, n in [0, 1024) this computes
//   c[m*1024 + n] = sum_{k<1024} a[k*1024 + m] * b[k*1024 + n]
// i.e. a is indexed with k as the row (a "transposed" A), while b and c are
// row-major with row length 1024. c is overwritten, not accumulated into.
//
// Work-item gid handles the 8x8 tile of c at tile-row gid/128, tile-col gid%128,
// so 128*128 = 16384 work-items cover c.
// NOTE(review): products are accumulated in half precision across all 1024
// k terms; rounding error grows with K -- confirm this is acceptable.
__kernel void gemm( global const half* a,
       global const half* b,
       global half* c )
{
  // c_r[i] is row i of this work-item's 8x8 tile (one half8 per row).
  half8 c_r[8] = {0,0,0,0,0,0,0,0};
  half8 a_r;
  half8 b_r;

  // half8 index of this work-item's first column in row k = 0 of a (resp. b).
  int const a_off_thr = get_global_id(0)/128;
  int const b_off_thr = get_global_id(0)%128;

  int a_off = a_off_thr;
  int b_off = b_off_thr;
  for( int k = 0; k < 1024; k += 1 ) {
    // vload8(off, p) reads p[8*off .. 8*off+7], the same data as
    // ((global const half8*)p)[off].
    // NOTE(review): contrary to the expectation that vload8 improves
    // performance, this form was measured at ~28.6 ms versus ~15 ms for the
    // pointer-cast form on Adreno 530 (m = n = k = 1024). vloadn only requires
    // element (half) alignment, which may stop the compiler from emitting a
    // single aligned 16-byte load -- plausible cause, unverified.
    a_r = vload8(a_off, a);
    b_r = vload8(b_off, b);
    // Rank-1 update: tile row i += a_r.s<i> * b_r (scalar times vector).
    c_r[0] += a_r.s0*b_r;
    c_r[1] += a_r.s1*b_r;
    c_r[2] += a_r.s2*b_r;
    c_r[3] += a_r.s3*b_r;
    c_r[4] += a_r.s4*b_r;
    c_r[5] += a_r.s5*b_r;
    c_r[6] += a_r.s6*b_r;
    c_r[7] += a_r.s7*b_r;
    // Next k: one row = 1024 halves = 128 half8.
    a_off += 128;
    b_off += 128;
  }
  // Element offset in c of the tile's first row: row (gid/128)*8, column (gid%128)*8.
  int c_off = (get_global_id(0)/128)*1024*8 + (get_global_id(0)%128)*8;
  // Write the 8 tile rows; consecutive rows are 1024 halves apart.
  vstore8(c_r[0], 0, c+c_off);
  vstore8(c_r[1], 0, c+c_off+1024);
  vstore8(c_r[2], 0, c+c_off+1024*2);
  vstore8(c_r[3], 0, c+c_off+1024*3);
  vstore8(c_r[4], 0, c+c_off+1024*4);
  vstore8(c_r[5], 0, c+c_off+1024*5);
  vstore8(c_r[6], 0, c+c_off+1024*6);
  vstore8(c_r[7], 0, c+c_off+1024*7);
} 

The slightly modified kernel only got 75GFlops, 28.6ms, compared to 15ms before.

That's very confusing. It seems some other people also meet this kind of problem.

Thu-Chris avatar Mar 11 '18 09:03 Thu-Chris

But it's very confusing when I simply changes some lines, the performance reduces much. Like below:

@Thu-Chris Yes, I too faced similar problems when modifying certain lines of the kernel. The performance degraded drastically, when I modified @moskewcz kernel. Not sure about the root cause.

sivagnanamn avatar Mar 12 '18 00:03 sivagnanamn

@sivagnanamn I think the reason is the poor design of Qualcomm's OpenCL compiler... Lines with the same meaning are compiled to different instructions, which gives different performance. It would help a lot to have a tool like CUDA PTX. Does anyone know of something like this kind of assembler?

Thu-Chris avatar Mar 12 '18 02:03 Thu-Chris

@Thu-Chris It's not only Qualcomm's compiler; even the ARM Mali T764 shows similar degradation.

sivagnanamn avatar Mar 12 '18 08:03 sivagnanamn

@roserg

But I don't like to make accumulators (c_r) fp16. And I don't know is it possible to preserve c_r in fp32 for HGEMM and keep performance.

Try loading with vload_halfn (which converts half data to float on load) before accumulating into c_r, so the accumulators can stay fp32?

Thu-Chris avatar Mar 15 '18 02:03 Thu-Chris

@Thu-Chris Any possible suggestions (other than using L1 cache - with image as input) to improve the performance?

sivagnanamn avatar May 17 '18 01:05 sivagnanamn

@sivagnanamn Here are what I got:

  1. Transposing A and implementing GEMM_Atrans(A_trans, B, C) is always faster than computing GEMM(A, B, C) directly, as @moskewcz did here.
  2. For different set of (m,n,k), select the most efficient tiling method for(m,n,k). But it's hard to determine the (m,n,k) tile size. So I implemented a series of kernels and use auto tuning.
  3. When certain lines of the kernel are modified, performance changes drastically. It's kind of tricky, but I found that reading data the way @moskewcz's code does: a_r[0] = ((global float8 const *)a)[a_off+0]; gives the best performance on Adreno 530.
  4. Using L1 cache is a good idea because the sampler helps for dealing boundary elements and reduce L2 utilization. It helps especially "very rectangular" matrixs like (m = 10, n = 1000, k = 20).
  5. Some code causes a segmentation fault in clCreateKernel or error -14 in clEnqueueNDRangeKernel, but I'm not clear on the reasons.

What do you get recently?

Thu-Chris avatar May 17 '18 02:05 Thu-Chris

@Thu-Chris From your previous comments, I already tried 1, 2, 3 & 4.

For different set of (m,n,k), select the most efficient tiling method for(m,n,k). But it's hard to determine the (m,n,k) tile size. So I implemented a series of kernels and use auto tuning.

For now, I'm sticking onto square kernels for all M,N,K sizes (be it a square or rect matrix)

a_r[0] = ((global float8 const *)a)[a_off+0];

I tested on Adreno 506: vload8 and ((global float8 const *)a)[a_off+0]; provide almost the same performance (I didn't notice a considerable difference).

Using L1 cache is a good idea because the sampler helps for dealing boundary elements and reduce L2 utilization. It helps especially "very rectangular" matrixs like (m = 10, n = 1000, k = 20).

Agreed. I used the B matrix as an image & it improved performance ~1.5x on Adreno 506. Now I'm looking for other possible options to improve the performance.

sivagnanamn avatar May 17 '18 02:05 sivagnanamn