
Generic `softmax`

MikeInnes opened this issue on Feb 15 '18 • 15 comments

Aside from `adapt`, this is probably the biggest blocker to running Flux with CLArrays right now.

MikeInnes avatar Feb 15 '18 11:02 MikeInnes

Should we create GPUNNlib for that?

SimonDanisch avatar Feb 15 '18 11:02 SimonDanisch

Or maybe just make GPUArrayMath.jl be a library of generic math functions implemented for GPUArrays and have it in there? That would be a good place to put other things like stencil implementations.

ChrisRackauckas avatar Feb 15 '18 12:02 ChrisRackauckas

Would the problem be solved if the naive softmax from https://github.com/JuliaGPU/CuArrays.jl/issues/45 were provided in NNlib for AbstractArrays, rather than a CPU-specialized multithreaded implementation? That could then be overridden for ::GPUArray etc., with the multithreaded one used for ::Array. None of this is going to really make sense until Cassette lets us dispatch on "yeah, this is an abstract array of duals, but it's on the GPU!"

jekbradbury avatar Feb 16 '18 05:02 jekbradbury

Yeah, that's probably the way to do it, and it would delay having to write a kernel for it. It's annoying that we can't dispatch on CPU-ness, though, so we have to explicitly opt in types like Array that support multithreading.
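
To make that concrete, here's a rough sketch of the opt-in dispatch (the names are illustrative, not actual NNlib API; the Matrix method assumes dim = 1, like the CPU implementation being discussed):

# Generic fallback: any AbstractArray, including GPU arrays, can run this,
# since it only uses broadcast and reductions.
function naive_softmax!(y::AbstractArray, x::AbstractArray; dim = 1)
    tmp = maximum(x, dim)   # subtract the per-slice max for numerical stability
    y .= exp.(x .- tmp)
    y ./= sum(y, dim)
end

# Opt-in override: plain matrices take a multithreaded, column-wise path instead.
function naive_softmax!(y::Matrix, x::Matrix; dim = 1)
    @assert dim == 1   # this sketch only handles the column-wise case
    Threads.@threads for j in 1:size(x, 2)
        m = maximum(view(x, :, j))
        s = zero(m)
        @inbounds for i in 1:size(x, 1)
            y[i, j] = exp(x[i, j] - m)
            s += y[i, j]
        end
        @inbounds for i in 1:size(x, 1)
            y[i, j] /= s
        end
    end
    return y
end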

MikeInnes avatar Feb 16 '18 09:02 MikeInnes

@SimonDanisch perhaps PyTorch's version can be ported as a generic kernel?

MikeInnes avatar Feb 24 '18 20:02 MikeInnes

Can you help me with some tests? I have a version that isn't erroring anymore, but the values seem to be wrong.

HostSoftMaxForward(
        input::AbstractArray{T}, output::AbstractArray{T},
        outer_size::Int, dim_size::Int, inner_size::Int, dim::Int,
        Epilogue = SoftMaxForwardEpilogue
) 

I'm not entirely sure what would be good values for this...

SimonDanisch avatar Feb 26 '18 12:02 SimonDanisch

Here's a basic implementation:

function softmax!(y, x; dim=1)
    tmp = maximum(x, dim)  # per-slice max, subtracted for numerical stability
    y .= exp.(x .- tmp)
    sum!(tmp, y)           # reuse tmp to hold the per-slice sums
    y ./= tmp
end

softmax(x; dim = 1) = softmax!(similar(x), x, dim = dim)

It's not totally obvious to me what the sizes are supposed to be, though, if that's what you mean. I guess there'll be wrapper code for it somewhere.
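
As for tests, a quick sanity check might be enough to start with (assuming the column-wise dim = 1 convention above):

x = rand(Float32, 5, 3)
y = softmax(x)
@assert all(isapprox.(sum(y, 1), 1))        # each column should sum to 1
@assert indmax(y[:, 1]) == indmax(x[:, 1])  # softmax preserves the argmax
# with a GPU array (e.g. via CuArrays' cu), the result should match the CPU one:
# @assert Array(softmax(cu(x))) ≈ softmax(x)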

MikeInnes avatar Feb 26 '18 15:02 MikeInnes

Yeah, I meant those - so they're GPU-specific, I guess...

SimonDanisch avatar Feb 26 '18 16:02 SimonDanisch

I don't remember exactly, but one of the three sizes definitely refers to the size of the dimension being reduced over, and the other two might be the sizes of the (single, implicitly collapsed) dimensions major and minor to the reduction dimension. In that case, if the input is rand(2, 3, 4, 5, 6) and the softmax dimension is dim=3, these would be 5×6=30, 4, and 2×3=6?
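
For example, a small sketch of that mapping (which of the two products counts as "outer" vs "inner" in the ported kernel might well be flipped, since PyTorch is row-major and Julia is column-major):

x = rand(2, 3, 4, 5, 6)
dim = 3
dim_size = size(x, dim)               # 4: length of the softmax dimension
minor_sz = prod(size(x)[1:dim-1])     # 2*3 = 6: dims minor to (before) dim
major_sz = prod(size(x)[dim+1:end])   # 5*6 = 30: dims major to (after) dim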

jekbradbury avatar Feb 26 '18 16:02 jekbradbury

Ok, I've run out of time debugging this, but anyone can take a look at it here: https://github.com/JuliaGPU/GPUArrays.jl/pull/106

SimonDanisch avatar Feb 26 '18 22:02 SimonDanisch

I was curious whether the generic, broadcast-based implementation could be made as fast as a handwritten softmax, even in principle, so I started by running the following benchmark (times are kernel times from nvprof):

using CuArrays, CUDAdrv

# times are for the last benchmark in this file, on a GT 750M

function softmax!(out::AbstractArray, x::AbstractArray; dim=1)
  tmp = maximum(x, dim) # mapreducedim_kernel 70.6 ms
  out .= exp.(x .- tmp) # broadcast_kernel    36.1 ms
  sum!(tmp, out)        # mapreducedim_kernel 72.2 ms
  out ./= tmp           # broadcast_kernel    27.8 ms
end                     # total              207.3 ms
#                      vs cudnnSoftmaxForward 19.6 ms

function bench(sizes...)
    x = cu(rand(sizes...)); y = similar(x);
    CUDAdrv.synchronize()
    for _ in 1:3
        softmax!(y, x, dim=1);
        CUDAdrv.synchronize()
        CuArrays.CUDNN.cudnnSoftmaxForward(x, y)
        CUDAdrv.synchronize()
    end
end
# attention use case
bench(64, 64 * 32)
# classification use cases
bench(16384, 32)
bench(16384, 64 * 32) # times are using these sizes

The results are the same for the GPUArrays mapreducedim kernel, which is essentially identical to the CuArrays one. This suggests that there's a lot of performance left on the table with the default schedule used for broadcast_kernel and mapreducedim_kernel (at least on this tiny GPU), and my initial question isn't really relevant yet because each of the four sub-kernels takes longer than the overall cuDNN softmax.

jekbradbury avatar Mar 15 '18 16:03 jekbradbury

Thanks for those numbers, that's pretty interesting :) I hope to have some time to look into this!

SimonDanisch avatar Mar 15 '18 18:03 SimonDanisch

V100 numbers:

using CuArrays, CUDAdrv

# times are for the last benchmark in this file, on GT 750M and V100

macro gtime(expr)
    quote
        t = @elapsed begin
            $(esc(expr))
            CUDAdrv.synchronize()  # wait for the queued kernels to finish before stopping the timer
        end
        println("time for (", $(QuoteNode(expr)), "): ", t * 1000, "ms")
    end
end

function softmax!(out::AbstractArray, x::AbstractArray; dim=1)
                              # kernel name          GT 750M   V100
 @gtime tmp = maximum(x, dim) # mapreducedim_kernel  70.6 ms  3.81 ms
 @gtime out .= exp.(x .- tmp) # broadcast_kernel     36.1 ms  0.65 ms
 @gtime sum!(tmp, out)        # mapreducedim_kernel  72.2 ms  3.76 ms
 @gtime out ./= tmp           # broadcast_kernel     27.8 ms  0.45 ms
end                           # total               207.3 ms  9.35 ms
#                            vs cudnnSoftmaxForward  19.6 ms  0.82 ms

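# cuDNN's softmax expects a 4D tensor, so pad the 2D input with leading singleton dims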
add_dims(x::AbstractArray) = reshape(x, (1, 1, size(x)...))

function bench(sizes...)
   x = cu(rand(sizes...)); y = similar(x);
   for _ in 1:3
       @gtime softmax!(y, x, dim=1);
       xx = add_dims(x); yy = add_dims(y)
       @gtime CuArrays.CUDNN.cudnnSoftmaxForward(xx, yy)
   end
end
# attention use case
bench(64, 64 * 32)
# classification use cases
bench(16384, 32)
bench(16384, 64 * 32) # times are using these sizes

jekbradbury avatar Mar 15 '18 22:03 jekbradbury

I'm getting 350 μs with this CUDAnative kernel (so more than twice as fast as cuDNN!)

jekbradbury avatar Jun 10 '18 03:06 jekbradbury

(whoops, that was blatantly buggy; after correcting it and updating the gist I'm matching cuDNN's time almost exactly)

jekbradbury avatar Jun 11 '18 03:06 jekbradbury