
[FEA] Added functionality to ElementwiseKernel

Open mnicely opened this issue 5 years ago • 13 comments

To begin, cuSignal has increased its use of CuPy's ElementwiseKernel functionality with great success!

I would like to request two additional features.

Performance

  1. It is known that adding the __restrict__ qualifier to pointer parameters allows the compiler to perform additional optimizations, as does adding const to read-only data: https://developer.nvidia.com/blog/cuda-pro-tip-optimize-pointer-aliasing/ It would be great if these two options were available: const and __restrict__ for input parameters, and __restrict__ only for output parameters.
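To make the request concrete, here is a small hypothetical helper (`build_params` is not part of CuPy) that spells out the qualifier pattern. Read-only inputs become `const T* __restrict__` and outputs become `T* __restrict__`. It only builds a signature string for a raw kernel. It is a sketch of the convention, not a description of how CuPy generates its code.

```python
def build_params(inputs, outputs):
    """Build a CUDA kernel parameter list from (ctype, name) pairs.

    Read-only inputs get `const` and `__restrict__`; outputs get only
    `__restrict__`, since the kernel writes through them.
    """
    params = [f"const {ctype}* __restrict__ {name}" for ctype, name in inputs]
    params += [f"{ctype}* __restrict__ {name}" for ctype, name in outputs]
    return ", ".join(params)


signature = build_params([("float", "x1"), ("int", "c")], [("float", "y")])
print(f'extern "C" __global__ void my_kernel({signature})')
```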

Functionality

  2. Passing in dtype for type inference. Currently, if a CuPy Elementwise Kernel can't infer a data type, the type must be hardcoded. I have found that it's faster not to create an empty output array and to just pass size= to an Elementwise Kernel. But then I have to hardcode the output data type (when there is no input array).

As an example,

import cupy as cp

_bohman_kernel = cp.ElementwiseKernel(
    "",
    "float64 w",
    """
    double fac { abs( start + delta * ( i - 1 ) ) };
    if ( i != 0 && i != ( _ind.size() - 1 ) ) {
        w = ( 1 - fac ) * cos( M_PI * fac ) + 1.0 / M_PI * sin( M_PI * fac );
    } else {
        w = 0;
    }
    """,
    "_bohman_kernel",
    options=("-std=c++11",),
    loop_prep="double delta { 2.0 / ( _ind.size() - 1 ) }; \
               double start { -1.0 + delta };",
)

w = _bohman_kernel(size=M)
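A plain-Python reference of the same arithmetic can help when checking the kernel's output, whichever dtype it uses. The sketch below mirrors `loop_prep` and the operation line by line. `bohman_reference` is a hypothetical name. Note that `start + delta * (i - 1)` simplifies to `-1 + 2 * i / (M - 1)`, so the interior samples appear to follow the same formula as SciPy's `scipy.signal.windows.bohman`.

```python
import math


def bohman_reference(M):
    """Pure-Python reference for the _bohman_kernel output (requires M >= 2).

    Mirrors the kernel: delta = 2 / (M - 1), start = -1 + delta, and
    fac = |start + delta * (i - 1)|. The first and last samples are
    forced to zero, as in the kernel's else-branch.
    """
    delta = 2.0 / (M - 1)
    start = -1.0 + delta
    w = []
    for i in range(M):
        fac = abs(start + delta * (i - 1))
        if i != 0 and i != M - 1:
            w.append((1 - fac) * math.cos(math.pi * fac)
                     + 1.0 / math.pi * math.sin(math.pi * fac))
        else:
            w.append(0.0)
    return w


print(bohman_reference(7))
```

With CuPy available, a check along the lines of `cp.allclose(_bohman_kernel(size=M), cp.asarray(bohman_reference(M)))` would compare the two.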

Therefore, if I want to support both float64 and float32, I need to create two kernels plus logic to select the correct one.
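Until something like a `dtype=` argument exists, the two-kernel workaround can at least avoid hand-copying the source. One approach is to generate the kernel arguments from a single template per dtype and cache them. The sketch below assumes only float32/float64 outputs are needed, and `bohman_kernel_args` is a hypothetical helper. The brace initialization used above rejects narrowing conversions: a `double` expression cannot brace-initialize a `float`. For that reason, the template adds explicit casts.

```python
import functools

# Assumed mapping from output dtype name to the C type used for the math.
_CTYPES = {"float32": "float", "float64": "double"}


@functools.lru_cache(maxsize=None)
def bohman_kernel_args(dtype_name):
    """Return (out_params, operation, loop_prep) for one output dtype.

    The explicit casts keep the brace initializers free of narrowing
    conversions (e.g. double -> float), which -std=c++11 rejects.
    """
    t = _CTYPES[dtype_name]  # raises KeyError for unsupported dtypes
    out_params = f"{dtype_name} w"
    operation = f"""
    {t} fac {{ {t}( fabs( start + delta * ( i - 1 ) ) ) }};
    if ( i != 0 && i != ( _ind.size() - 1 ) ) {{
        w = ( 1 - fac ) * cos( M_PI * fac ) + 1.0 / M_PI * sin( M_PI * fac );
    }} else {{
        w = 0;
    }}
    """
    loop_prep = (f"{t} delta {{ {t}( 2.0 / ( _ind.size() - 1 ) ) }}; "
                 f"{t} start {{ {t}( -1.0 + delta ) }};")
    return out_params, operation, loop_prep


# Usage with CuPy (not run here):
#   out_params, operation, loop_prep = bohman_kernel_args(cp.dtype(dtype).name)
#   kern = cp.ElementwiseKernel("", out_params, operation, "_bohman_kernel",
#                               options=("-std=c++11",), loop_prep=loop_prep)
#   w = kern(size=M)
```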

It would be great if I could pass dtype, maybe something like

_bohman_kernel = cp.ElementwiseKernel(
    "",
    "T w, C a",
    """
    T fac { abs( start + delta * ( i - 1 ) ) };
    if ( i != 0 && i != ( _ind.size() - 1 ) ) {
        w = ( 1 - fac ) * cos( M_PI * fac ) + 1.0 / M_PI * sin( M_PI * fac );
        a = C(0, w);
    } else {
        w = 0;
        a = C(w, 0);
    }
    """,
    "_bohman_kernel",
    options=("-std=c++11",),
    loop_prep="double delta { 2.0 / ( _ind.size() - 1 ) }; \
               double start { -1.0 + delta };",
    
)

w, a = _bohman_kernel(size=M, dtype=(("T", cp.float64), ("C", cp.complex128)))
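The proposed `dtype=` argument could work by binding each type placeholder explicitly and lowering that binding to `typedef`s placed in front of the kernel body. CuPy appears to do something similar for placeholders it infers from array arguments. The sketch below is hypothetical: `lower_type_placeholders` and the type table are assumptions, and the `complex<...>` spelling follows the complex type that CuPy-generated kernels appear to use.

```python
# Assumed mapping from NumPy dtype names to CUDA C++ type names.
_PLACEHOLDER_CTYPES = {
    "float32": "float",
    "float64": "double",
    "complex64": "complex<float>",
    "complex128": "complex<double>",
}


def lower_type_placeholders(bindings):
    """Turn (("T", "float64"), ("C", "complex128")) into typedef lines.

    Each placeholder is bound explicitly instead of being inferred from an
    input array, which a size-only call cannot provide.
    """
    seen = set()
    lines = []
    for name, dtype_name in bindings:
        if name in seen:
            raise ValueError(f"placeholder {name!r} bound twice")
        seen.add(name)
        lines.append(f"typedef {_PLACEHOLDER_CTYPES[dtype_name]} {name};")
    return "\n".join(lines)


print(lower_type_placeholders((("T", "float64"), ("C", "complex128"))))
```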

@z-ryan1 @awthomp @leofang

mnicely, Oct 13 '20 14:10

I think we should definitely add __restrict__ to both elementwise and reduction kernels. I will work on an implementation and do some benchmarking.

emcastillo, Oct 14 '20 02:10

Regarding (1), I've done some investigation.

First, all the generated elementwise kernels use CArray objects as parameters instead of plain pointers, so we would need to decouple the actual buffer from the CArray before __restrict__ and const could be used. This is quite a heavy change, so to motivate it, I tried to measure the potential speedup we could get.

I went ahead and tried a simple RawKernel:

import cupy

raw_kernel = cupy.RawKernel(
    r"""
extern "C" __global__
void my_add(const float* x1, const int* c, float* y) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    y[tid] = x1[c[tid]];
}
""",
    "my_add",
)

This is the same kernel as in your reference link. However, when measuring on an A100 and a P100, I saw no speedup. Apparently __restrict__ only helps when there is data reuse, so when I changed the kernel to do

    for (int i=0; i<64; i++)
        y[tid] = x1[c[tid]];

instead, the kernel time dropped from 40 µs to 7 µs.

However, ufuncs and most of our routines do not do this kind of heavy data reuse (I still have to check reductions, though).
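To rerun this A/B comparison, both kernel variants can be generated from one source so that the `__restrict__` qualifiers are the only difference. The sketch below assumes a recent CuPy in which `cupyx.profiler.benchmark` is available (it supersedes `cupyx.time.repeat`). `gather_source` and `run_benchmark` are hypothetical names. The `n` parameter and its bounds check are additions to the original kernel.

```python
def gather_source(use_restrict, reuse=64):
    """CUDA source for the gather kernel, with or without __restrict__.

    The loop repeats the same load/store `reuse` times: the data-reuse
    pattern in which __restrict__ made a measurable difference.
    """
    r = " __restrict__" if use_restrict else ""
    return f"""
extern "C" __global__
void my_add(const float*{r} x1, const int*{r} c, float*{r} y, int n) {{
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= n) return;  // guard for sizes not divisible by the block size
    for (int i = 0; i < {reuse}; i++)
        y[tid] = x1[c[tid]];
}}
"""


def run_benchmark(n=1 << 20, block=128):
    """Time both variants on the current GPU (requires CuPy and a GPU)."""
    import cupy
    from cupyx.profiler import benchmark

    x1 = cupy.random.rand(n, dtype=cupy.float32)
    c = cupy.random.randint(0, n, size=n, dtype=cupy.int32)
    y = cupy.empty(n, dtype=cupy.float32)
    grid = ((n + block - 1) // block,)
    for use_restrict in (False, True):
        kern = cupy.RawKernel(gather_source(use_restrict), "my_add")
        args = (grid, (block,), (x1, c, y, cupy.int32(n)))
        print("restrict" if use_restrict else "no restrict",
              benchmark(kern, args, n_repeat=100))
```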

emcastillo, Oct 14 '20 04:10

For CUB reduction, I actually tried the restrict trick earlier this week, but didn't find any performance boost. It's as simple as adding __restrict__ to this line: https://github.com/cupy/cupy/blob/ded0d3e8b4a3e4d44067a358bb9c9cf7fbfdabff/cupy/core/_cub_reduction.pyx#L347 so that we have const void* __restrict__ for the input array and void* __restrict__ for the output array. @mnicely, is it because the compiler does not see the type information until I cast void* to something like float*, and thus can't do any optimization?

leofang, Oct 14 '20 04:10

I added a horrible test implementation of restrict here

https://github.com/emcastillo/cupy/commit/e3562440bd61355aeb0551306d9a0a3143fc31b2

I haven't seen any performance improvement yet, though. @mnicely, can you test this and tell me whether you see anything significant?

Here is some simple test code:

import cupy
import cupyx


# The `restrict` modifier in in_params is only understood by the test branch above.
a_res = cupy.core.ElementwiseKernel(
    'restrict T x, restrict T y',
    'T z',
    'for(int i=0; i<100; i++) z = (x - y) * (x - y)',
    'squared_diff_generic')

a_nores = cupy.core.ElementwiseKernel(
    'T x, T y',
    'T z',
    'for(int i=0; i<100; i++) z = (x - y) * (x - y)',
    'squared_diff_generic')
x = cupy.arange(10000).astype(cupy.float32)
y = cupy.arange(10000).astype(cupy.float32)+1.1

print(cupyx.time.repeat(a_nores, (x, y)))
print(cupyx.time.repeat(a_res, (x, y)))

Edit: fixed the error in the snippet above.

emcastillo, Oct 14 '20 09:10

On Kepler-generation GPUs, adding __restrict__ improves performance, but on modern GPUs it rarely does.

anaruse, Oct 15 '20 10:10

All, I really appreciate the quick responses.

@emcastillo The example above was mainly a visual example for (2). But I have written a few elementwise kernels that require data reuse. I don't expect the improvements to always be significant; I guess it's more of a programming style I've developed over time, to give the compiler as much help as possible.

@leofang I'm not really sure. I'll ask the compiler folks.

I'll confirm with the compiler team whether this technique is no longer needed on GPUs and get back to you.

Any thoughts on (2)?

Lastly, I noticed that the block size defaults to 128. Do you use __launch_bounds__(128) to let the compiler optimize register usage?

mnicely, Oct 15 '20 13:10

@emcastillo I confirmed that what you saw in your example (little to no performance improvement) is expected. But if the elementwise kernel becomes more complex (adding for-loops and data reuse), adding __restrict__ DOES give the compiler more information to make further optimizations.

All that being said, because it's a heavy change, I see no need to push the idea further. Again, thanks for looking into it.

mnicely, Oct 15 '20 17:10

Further information for those interested.

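typedef float dtype;  // element type assumed for illustration; the original leaves dtype unspecified

__global__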
void copy(const dtype* const __restrict__ src, dtype* const __restrict__ dst, const size_t dlen) {
  #pragma unroll
  for (int i = 0; i < 2; i++) {
    const int idx = i*blockDim.x*gridDim.x + blockIdx.x*blockDim.x + threadIdx.x;
    dst[idx] = src[idx];
  }
}

will reorder loads and stores since it knows src and dst don't overlap:

        /*00a0*/                   LDG.E.CONSTANT.SYS R3, [R2];
        /*00b0*/                   LDG.E.CONSTANT.SYS R7, [R6];
        /*00c0*/                   IMAD.WIDE R4, R4, R9, c[0x0][0x168];
        /*00d0*/                   IMAD.WIDE R8, R0, R9, c[0x0][0x168];
        /*00e0*/                   STG.E.SYS [R4], R3;
        /*00f0*/                   STG.E.SYS [R8], R7;

without restrict, the loads and stores don't get reordered:

        /*0070*/                   LDG.E.SYS R3, [R2];
        /*0080*/                   IADD3 R0, R0, c[0x0][0xc], RZ;
        /*0090*/                   IMAD R0, R0, c[0x0][0x0], R5;
        /*00a0*/                   IMAD.WIDE R4, R4, R9, c[0x0][0x168];
        /*00b0*/                   IMAD.WIDE R6, R0, R9, c[0x0][0x160];
        /*00c0*/                   STG.E.SYS [R4], R3;
        /*00d0*/                   LDG.E.SYS R7, [R6];
        /*00e0*/                   IMAD.WIDE R8, R0, R9, c[0x0][0x168];
        /*00f0*/                   STG.E.SYS [R8], R7;

mnicely, Oct 15 '20 19:10

Quoting @mnicely: "I have written a few elementwise kernels that require data reuse."

Can you test those with my branch, or post a snippet so I can measure them?

Probably, if an elementwise kernel gets complicated enough to benefit from the restrict qualifier, it might be better to write it as a raw kernel :).

Also, @leofang suspects that all the indexing logic in elementwise kernels might be preventing the compiler from optimizing the code further.

emcastillo, Oct 16 '20 02:10

I'll try to test your branch next week! I have to teach a course today.

You're probably right. I guess I've gotten lazy with the ease of the Elementwise Kernel :joy:

mnicely, Oct 16 '20 12:10

@emcastillo I'm not seeing huge improvements in my testing: about 10% in some cases.

Question, how do Elementwise kernels handle inputs when they are also the output?

Scenario:

  1. Input is also the output: add only __restrict__.
  2. Input is different from the output: add const and __restrict__ to the input, and __restrict__ to the output.
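There is one caveat on scenario 1. `__restrict__` promises the compiler that no other pointer in the kernel reaches the same memory. A single pointer that is both read and written can carry that promise safely. An in-place elementwise call presumably passes the same array as both an input and an output parameter, though, and marking both of those parameters `__restrict__` would be undefined behavior. The sketch below decides qualifiers from buffer addresses. It is hypothetical: `choose_qualifiers` is not a CuPy function.

```python
from collections import Counter


def choose_qualifiers(inputs, outputs):
    """Map each parameter name to its pointer qualifiers.

    inputs/outputs are lists of (name, buffer_address). An address that
    appears on more than one parameter and is written through an output is
    aliased, so __restrict__ is dropped for every parameter pointing at it.
    Inputs stay `const` either way, since the kernel only reads them.
    """
    counts = Counter(addr for _, addr in inputs + outputs)
    written = {addr for _, addr in outputs}

    def aliased(addr):
        return counts[addr] > 1 and addr in written

    quals = {}
    for name, addr in inputs:
        quals[name] = "const" if aliased(addr) else "const __restrict__"
    for name, addr in outputs:
        quals[name] = "" if aliased(addr) else "__restrict__"
    return quals
```

Read-only aliasing, such as passing the same array as two inputs, keeps `__restrict__`. This is because the restrict rules only constrain objects that are modified.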

mnicely, Oct 19 '20 14:10

@mnicely Dumb question, probably unrelated to the main thread: I noticed from your example code and many cuSignal PRs that you seem to favor direct-list-initialization even for primary data types:

double fac { abs( start + delta * ( i - 1 ) ) };

Is there a performance reason, or just a matter of style taste?

leofang, Oct 20 '20 03:10

@leofang Not dumb at all :smile: it's just personal preference. I like how it catches illegal narrowing at compile time.

mnicely, Oct 20 '20 11:10