NNlib.jl

Slow ∇softmax! compared with generic version.

Open · jumerckx opened this issue 3 years ago · 15 comments

The specialized backward pass for softmax takes a lot longer than the generic implementation from NNlib.jl. The effect seems especially pronounced when the batch dimension is large.

Here's the code to reproduce this issue.
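
(The reproduction code itself was posted as a gist and isn't reproduced here. Roughly, the comparison being made is the following; this is an editorial sketch using the shapes, variable names and helpers that appear later in the thread, and the exact signatures depend on the NNlib/NNlibCUDA versions of the time.)

using CUDA, NNlib, NNlibCUDA, BenchmarkTools

b     = CUDA.rand(Float32, 32, 10, 256)   # input with a fairly large batch dimension
y_b   = softmax(b; dims=1)                # forward pass
Δ_b   = CUDA.rand(Float32, size(b)...)    # incoming gradient
out_b = similar(b)                        # buffer for the computed input gradient

# Specialized path, which dispatches to cudnnSoftmaxBackward for CuArrays:
@btime CUDA.@sync NNlib.∇softmax($Δ_b, $b, $y_b; dims=1)

# Generic broadcasting implementation, bypassing CUDNN:
@btime CUDA.@sync NNlibCUDA._∇softmax!($out_b, $Δ_b, $b, $y_b; dims=1)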

Below are the results of my benchmarks (not reproduced here):

my CUDA versioninfo:

CUDA toolkit 11.4.1, artifact installation
CUDA driver 11.2.0
NVIDIA driver 461.9.0

Libraries: 
- CUBLAS: 11.5.4
- CURAND: 10.2.5
- CUFFT: 10.5.1
- CUSOLVER: 11.2.0
- CUSPARSE: 11.6.0
- CUPTI: 14.0.0
- NVML: 11.0.0+461.9
- CUDNN: 8.20.2 (for CUDA 11.4.0)
- CUTENSOR: 1.3.0 (for CUDA 11.2.0)

Toolchain:
- Julia: 1.6.1
- LLVM: 11.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80

1 device:
  0: GeForce RTX 2070 with Max-Q Design (sm_75, 5.873 GiB / 8.000 GiB available)

jumerckx (Oct 09 '21)

As a debugging step, we could try forcing the use of the CUDNN kernel directly, sidestepping the contiguous-dimensions case.

DhairyaLGandhi (Oct 09 '21)

Would you be able to post the CUDA.versioninfo() as well?

DhairyaLGandhi (Oct 09 '21)

> As a debugging step, we could try forcing the use of the CUDNN kernel directly, sidestepping the contiguous-dimensions case.

Indeed, just using _∇softmax! directly, thereby avoiding CUDNN, runs in ~50 µs instead of ~2 ms.

> Would you be able to post the CUDA.versioninfo() as well?

Updated in OP.

jumerckx (Oct 09 '21)

What happens if we use the CUDNN kernel instead?

I'd love a profile of the CUDNN run; that should tell us what's going on.

DhairyaLGandhi (Oct 09 '21)

I can't run NVIDIA Nsight profiling on my machine, but I timed all the lines, which confirms it is the kernel evaluation that causes the slowdown.

# ... after having run the code from the OP ...
using NNlibCUDA
s = NNlibCUDA.softmaxdims(b, 1) # 1.030 μs (3 allocations: 112 bytes)
xDesc = NNlibCUDA.cudnnTensorDescriptor(reshape(b,s)) # 1.500 μs (5 allocations: 256 bytes)
alpha, beta = NNlibCUDA.scalingParameter(Float32,1), NNlibCUDA.scalingParameter(Float32,0) # 1.010 μs (4 allocations: 96 bytes)
NNlibCUDA.cudnnSoftmaxBackward(NNlibCUDA.handle(), NNlibCUDA.softmaxalgo(), NNlibCUDA.CUDNN_SOFTMAX_MODE_CHANNEL, alpha, xDesc, y_b, xDesc, Δ_b, beta, xDesc, out_b) # 1.998 ms (98 allocations: 9.55 KiB)

jumerckx (Oct 09 '21)

What is the softmax algorithm being selected, and what happens if you force it to use the fast one? https://github.com/FluxML/NNlibCUDA.jl/blob/5eed446f379183e59bad583790fb7491b507743d/src/cudnn/softmax.jl#L66

ToucheSir (Oct 09 '21)

By default it's CUDNN_SOFTMAX_ACCURATE, but using CUDA.math_mode!(CUDA.FAST_MATH) to switch to CUDNN_SOFTMAX_FAST doesn't lead to a discernible difference in timing.
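
For reference, a quick way to check which algorithm ends up being passed to CUDNN (an editorial sketch based on the selection logic linked above; softmaxalgo is assumed to key off CUDA's math mode):

using CUDA, NNlibCUDA

NNlibCUDA.softmaxalgo()          # CUDNN_SOFTMAX_ACCURATE by default
CUDA.math_mode!(CUDA.FAST_MATH)
NNlibCUDA.softmaxalgo()          # expected to return CUDNN_SOFTMAX_FAST now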

jumerckx (Oct 09 '21)

Here's a benchmark for PyTorch.

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   aten::softmax         5.26%       5.573ms        24.98%      26.467ms      26.467us       4.671ms         4.16%      25.506ms      25.506us          1000  
                  aten::_softmax        16.06%      17.020ms        19.72%      20.895ms      20.895us      20.836ms        18.56%      20.836ms      20.836us          1000  
                aten::empty_like         4.22%       4.467ms        11.29%      11.959ms       3.986us       0.000us         0.00%       0.000us       0.000us          3000  
                     aten::empty         8.82%       9.345ms         8.82%       9.345ms       2.336us       0.000us         0.00%       0.000us       0.000us          4000  
                       aten::sum        15.01%      15.900ms        17.23%      18.255ms      18.255us      27.518ms        24.51%      27.518ms      27.518us          1000  
                aten::as_strided         1.96%       2.081ms         1.96%       2.081ms       1.040us       0.000us         0.00%       0.000us       0.000us          2000  
                 aten::ones_like         4.94%       5.230ms        14.20%      15.042ms      15.042us       7.787ms         6.94%      13.961ms      13.961us          1000  
             aten::empty_strided         2.15%       2.283ms         2.15%       2.283ms       2.283us       0.000us         0.00%       0.000us       0.000us          1000  
                     aten::fill_         5.70%       6.045ms         5.70%       6.045ms       6.045us       6.173ms         5.50%       6.173ms       6.173us          1000  
                    SumBackward0         4.35%       4.611ms         9.20%       9.745ms       9.745us       9.016ms         8.03%       9.016ms       9.016us          1000  
                    aten::expand         3.70%       3.922ms         4.84%       5.134ms       5.134us       0.000us         0.00%       0.000us       0.000us          1000  
                 SoftmaxBackward         5.65%       5.989ms        34.40%      36.453ms      36.453us       5.843ms         5.21%      36.254ms      36.254us          1000  
    aten::_softmax_backward_data        12.96%      13.734ms        28.75%      30.464ms      30.464us      16.071ms        14.32%      30.411ms      30.411us          1000  
                       aten::mul         9.21%       9.762ms        11.71%      12.413ms      12.413us      14.340ms        12.77%      14.340ms      14.340us          1000  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 105.961ms
Self CUDA time total: 112.255ms

Benchmark code:

import torch
import torch.utils.benchmark as benchmark

x = torch.rand(32, 10, 256, requires_grad=True).cuda()
# profile 1000 forward + backward passes of softmax over the last dimension on the GPU
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for _ in range(1000):
        sx = x.softmax(-1).sum()
        torch.autograd.grad([sx], [x])

print(prof.key_averages().table())
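
For a rough Julia-side counterpart of this loop (an editorial sketch, assuming Zygote provides the gradient; not part of the original comment):

using CUDA, NNlib, Zygote, BenchmarkTools

x = CUDA.rand(Float32, 32, 10, 256)
# forward + backward through softmax over the last dimension, as in the PyTorch code above
@btime CUDA.@sync Zygote.gradient(z -> sum(softmax(z; dims=3)), $x)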

As the table shows, the times are very similar to those of the "backup" implementation at https://github.com/FluxML/NNlibCUDA.jl/blob/06ccd9f5b0fa6d3bfc9c9d52dbf865a78d76a576/src/cudnn/softmax.jl#L33. I was curious to see how PyTorch managed to make cudnnSoftmaxBackward run faster, and to my surprise they don't seem to use it at all! Looking back at https://github.com/JuliaGPU/CUDA.jl/issues/599, it seems only the forward pass was thoroughly benchmarked, not the backward function. If we can show that the native Julia version beats the CuDNN one for a reasonably large number of sizes, dims, eltypes and math modes, I'd vote for using it by default and removing the CuDNN path.
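
For context, the "backup"/generic path is just the standard softmax vector-Jacobian product done with broadcasting; a minimal sketch (hypothetical name, not the exact NNlibCUDA source):

# dx_i = y_i * (dy_i - Σ_j dy_j * y_j), with the sum taken over the softmax dimension
function generic_∇softmax!(dx, dy, y; dims = 1)
    tmp = dy .* y
    dx .= tmp .- y .* sum(tmp; dims = dims)
    return dx
end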

ToucheSir (Oct 09 '21)

It's so weird that CUDNN is so slow. Anyway, yes, if @jumerckx can compare performance on a few more sizes and the results are the same, we can remove the cudnn path.

CarloLucibello (Oct 10 '21)

We don't need to remove the cudnn path, but we could make the Julia version the default if it's accurate enough for our use cases.

DhairyaLGandhi (Oct 10 '21)

Looks like it really is CUDNN being slow:

(screenshot omitted)

maleadt (Oct 11 '21)

> It's so weird that CUDNN is so slow. Anyway, yes, if @jumerckx can compare performance on a few more sizes and the results are the same, we can remove the cudnn path.

I've run these timings across a bunch of sizes and haven't encountered a single config in which CuDNN beat generic Julia (please do check my benchmarking code to make sure I'm not doing anything stupid, as this is the first time I've done this kind of benchmarking). One thing that's really striking is the difference between enlarging the first versus the second dimension of the input. When the second dimension is large, CuDNN scales particularly badly (which is also what I ran into in the OP); note the logarithmic axes in the plots. (plots omitted)

I've updated the GitHub gist with my benchmarking code, but here's the (incomplete) essence:

using CUDA, NNlib, NNlibCUDA, BenchmarkTools  # softmax from NNlib, @belapsed from BenchmarkTools

# case for growing second dimension:
timings = []
for i in 1:12
  sz = (32, 2^i)
  
  x = CUDA.rand(sz...)
  dy = CUDA.rand(size(x)...)
  y = softmax(x)
  dx = similar(x)
  @show i
  time_generic = @belapsed CUDA.@sync NNlibCUDA._∇softmax!($dx, $dy, $x, $y; dims=1)
  
  s = NNlibCUDA.softmaxdims(x, 1)
  xDesc = NNlibCUDA.cudnnTensorDescriptor(reshape(x,s))
  alpha, beta = NNlibCUDA.scalingParameter(Float32,1), NNlibCUDA.scalingParameter(Float32,0)
  time_cudnn = @belapsed CUDA.@sync NNlibCUDA.cudnnSoftmaxBackward($NNlibCUDA.handle(), $NNlibCUDA.softmaxalgo(), $NNlibCUDA.CUDNN_SOFTMAX_MODE_CHANNEL, $alpha, $xDesc, $y, $xDesc, $dy, $beta, $xDesc, $dx)
  push!(timings, (sz, time_generic, time_cudnn))
end
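
(Not from the original gist: the companion "growing first dimension" case is the same loop with sz = (2^i, 32), and the collected timings can be summarised with something like the following.)

for (sz, tg, tc) in timings
    println(rpad(string(sz), 14), " generic: ", round(tg * 1e6; digits=1),
            " µs    cudnn: ", round(tc * 1e6; digits=1), " µs")
end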

jumerckx (Oct 11 '21)

Those results look pretty conclusive, thanks @jumerckx! Does anyone want to do the honours?

ToucheSir (Oct 15 '21)

FWIW, CUDA.jl uses CUDNN 8.2 while 8.3 has been released, so we should probably re-evaluate at some point.

maleadt (Nov 17 '21)

Testing today, the generic version is usually slower than the NNlibCUDA one.

In the tests below, the sole exception is ∇softmax when size(x) = (4096, 32).

These are all with dims=1. Perhaps a bit more timing is justified, but on the whole it seems that what's presented above is no longer a good representation of the current status.

using NNlib, CUDA, NNlibCUDA, Test
function simple_softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    max_ = fast_maximum(x; dims)
    # if all(isfinite, max_)
    #     @fastmath out .= exp.(x .- max_)
    # else
        @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
    # end
    tmp = dims isa Colon ? sum(out) : sum!(max_, out)
    out ./= tmp
end
fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))

x = CUDA.randn(10, 20);
y = similar(x);
@test simple_softmax!(y, x) ≈ softmax(x)

dy = CUDA.randn(10, 20);
@test NNlib.∇softmax(dy, x, y; dims=1) ≈ NNlib.∇softmax_data(dy, y; dims=1)

# @less NNlib.∇softmax_data(dy, y; dims=1)  # check this really is the simple version
### My tests

julia> for n in (10, 10^4, 10^6)  # Forward
           @show 100, 10, n
           x = CUDA.randn(100, 10, n)
           y = similar(x)
           @btime CUDA.@sync softmax!($y, $x)         # uses CUDNN
           @btime CUDA.@sync simple_softmax!($y, $x)  # simplest Julia version
       end
(100, 10, n) = (100, 10, 10)
  14.367 μs (11 allocations: 384 bytes)
  88.398 μs (167 allocations: 11.80 KiB)
(100, 10, n) = (100, 10, 10000)
  121.820 μs (11 allocations: 384 bytes)
  2.082 ms (216 allocations: 14.69 KiB)
(100, 10, n) = (100, 10, 1000000)
  11.155 ms (67 allocations: 3.70 KiB)
  223.190 ms (226 allocations: 14.88 KiB)

julia> for n in (10, 10^4, 10^5)  # Gradient, stopping 10x smaller
           GC.gc(true); CUDA.reclaim();
           @show 100, 10, n
           x = CUDA.randn(100, 10, n)
           y = softmax(x)
           dy = CUDA.randn(100, 10, n)
           @btime CUDA.@sync NNlib.∇softmax($dy, $x, $y; dims=1)   # uses CUDNN
           @btime CUDA.@sync NNlib.∇softmax_data($dy, $y; dims=1)  # just broadcasting
       end;
(100, 10, n) = (100, 10, 10)
  17.308 μs (21 allocations: 704 bytes)
  53.043 μs (107 allocations: 5.20 KiB)
(100, 10, n) = (100, 10, 10000)
  168.519 μs (21 allocations: 704 bytes)
  1.321 ms (159 allocations: 8.28 KiB)
(100, 10, n) = (100, 10, 100000)
  1.816 ms (76 allocations: 3.98 KiB)
  12.140 ms (162 allocations: 8.33 KiB)


### Same sizes as https://github.com/FluxML/NNlib.jl/issues/513

julia> for size in [(256, 10, 32), (32, 10, 256),     # first post
                    (32, 32), (32, 4096), (4096, 32)] # some points of graph
           @show size
           x = CUDA.randn(size...)
           y = similar(x)
           @btime CUDA.@sync softmax!($y, $x)         # uses CUDNN
           @btime CUDA.@sync simple_softmax!($y, $x)  # simplest Julia version
       end
size = (256, 10, 32)
  14.484 μs (11 allocations: 384 bytes)
  89.187 μs (167 allocations: 11.80 KiB)
size = (32, 10, 256)
  14.803 μs (11 allocations: 384 bytes)
  92.655 μs (169 allocations: 11.83 KiB)
size = (32, 32)
  15.233 μs (11 allocations: 384 bytes)
  89.899 μs (167 allocations: 10.67 KiB)
size = (32, 4096)
  15.240 μs (11 allocations: 384 bytes)
  91.431 μs (169 allocations: 10.70 KiB)
size = (4096, 32)
  18.150 μs (11 allocations: 384 bytes)
  164.365 μs (309 allocations: 17.38 KiB)

julia> for size in [(256, 10, 32), (32, 10, 256),  # Gradient
                    (32, 32), (32, 4096), (4096, 32)]
           @show size
           x = CUDA.randn(size...)
           y = softmax(x)
           dy = CUDA.randn(size...)
           @btime CUDA.@sync NNlib.∇softmax($dy, $x, $y; dims=1)   # old, uses CUDNN
           @btime CUDA.@sync NNlib.∇softmax_data($dy, $y; dims=1)  # new, just broadcasting
       end;
size = (256, 10, 32)
  18.666 μs (21 allocations: 704 bytes)
  55.459 μs (107 allocations: 5.20 KiB)
size = (32, 10, 256)
  19.019 μs (21 allocations: 704 bytes)
  55.353 μs (108 allocations: 5.22 KiB)
size = (32, 32)
  18.163 μs (21 allocations: 688 bytes)
  52.654 μs (107 allocations: 4.92 KiB)
size = (32, 4096)
  19.000 μs (21 allocations: 688 bytes)
  56.144 μs (108 allocations: 4.94 KiB)
size = (4096, 32)
  210.951 μs (21 allocations: 688 bytes)
  79.711 μs (167 allocations: 7.64 KiB)

(@v1.10) pkg> st CUDA
Status `~/.julia/environments/v1.10/Project.toml`
  [052768ef] CUDA v3.12.1

julia> CUDA.versioninfo()
CUDA toolkit 11.7, artifact installation
NVIDIA driver 510.47.3, for CUDA 11.6
CUDA driver 11.6

Libraries: 
- CUBLAS: 11.10.1
- CURAND: 10.2.10
- CUFFT: 10.7.2
- CUSOLVER: 11.3.5
- CUSPARSE: 11.7.3
- CUPTI: 17.0.0
- NVML: 11.0.0+510.47.3
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)

Toolchain:
- Julia: 1.10.0-DEV.220
- LLVM: 14.0.6

mcabbott (Jan 10 '23)