NNlib.jl
Slow ∇softmax! compared with generic version.
The specialized backward pass for softmax takes much longer than the generic implementation from NNlib.jl. The effect is especially pronounced when the batch dimension is large.
Here's the code to reproduce this issue.
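(The script itself lives in a gist that isn't reproduced here; a minimal sketch in the spirit of the benchmarks further down the thread, using the same array size and the NNlib/NNlibCUDA entry points that appear below, would be:)

# Hypothetical minimal reproduction, not the original gist: the size (32, 10, 256)
# and the two calls are taken from the benchmarks later in this thread.
using CUDA, NNlib, NNlibCUDA, BenchmarkTools

x  = CUDA.randn(Float32, 32, 10, 256)
y  = softmax(x; dims=1)
dy = CUDA.randn(Float32, size(x)...)

@btime CUDA.@sync NNlib.∇softmax($dy, $x, $y; dims=1)    # specialized path, dispatches to CUDNN
@btime CUDA.@sync NNlib.∇softmax_data($dy, $y; dims=1)   # generic broadcasting implementation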
Below are the results of my benchmarks:
My CUDA versioninfo:
CUDA toolkit 11.4.1, artifact installation
CUDA driver 11.2.0
NVIDIA driver 461.9.0
Libraries:
- CUBLAS: 11.5.4
- CURAND: 10.2.5
- CUFFT: 10.5.1
- CUSOLVER: 11.2.0
- CUSPARSE: 11.6.0
- CUPTI: 14.0.0
- NVML: 11.0.0+461.9
- CUDNN: 8.20.2 (for CUDA 11.4.0)
- CUTENSOR: 1.3.0 (for CUDA 11.2.0)
Toolchain:
- Julia: 1.6.1
- LLVM: 11.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80
1 device:
0: GeForce RTX 2070 with Max-Q Design (sm_75, 5.873 GiB / 8.000 GiB available)
We could try forcing the use of the CUDNN kernel directly as a debugging step, sidestepping the contiguous-dimensions case.
Would you be able to post the CUDA.versioninfo() as well?
> We could try forcing the use of the CUDNN kernel directly as a debugging step, sidestepping the contiguous-dimensions case.

Indeed, just using _∇softmax! directly, and thereby avoiding CUDNN, runs in ~50 µs instead of ~2 ms.
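A sketch of that direct call, reusing the arrays from the reproduction sketch above (_∇softmax! is an internal NNlibCUDA function, so this is purely a debugging device):

dx = similar(x)
# Call the generic GPU kernel directly, bypassing the CUDNN dispatch:
@btime CUDA.@sync NNlibCUDA._∇softmax!($dx, $dy, $x, $y; dims=1)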
> Would you be able to post the CUDA.versioninfo() as well?
Updated in OP.
What happens if we use the CUDNN kernel instead?
I'd love a profile of the cudnn run; that should tell us what's going on.
I can't run NVIDIA Nsight profiling on my machine, but I timed all the lines, which confirms that it is the kernel evaluation that causes the slowdown.
# ... after having run the code from the OP ...
using NNlibCUDA
s = NNlibCUDA.softmaxdims(b, 1) # 1.030 μs (3 allocations: 112 bytes)
xDesc = NNlibCUDA.cudnnTensorDescriptor(reshape(b,s)) # 1.500 μs (5 allocations: 256 bytes)
alpha, beta = NNlibCUDA.scalingParameter(Float32,1), NNlibCUDA.scalingParameter(Float32,0) # 1.010 μs (4 allocations: 96 bytes)
NNlibCUDA.cudnnSoftmaxBackward(NNlibCUDA.handle(), NNlibCUDA.softmaxalgo(), NNlibCUDA.CUDNN_SOFTMAX_MODE_CHANNEL, alpha, xDesc, y_b, xDesc, Δ_b, beta, xDesc, out_b) # 1.998 ms (98 allocations: 9.55 KiB)
What is the softmax algorithm being selected, and what happens if you force it to use the fast one? https://github.com/FluxML/NNlibCUDA.jl/blob/5eed446f379183e59bad583790fb7491b507743d/src/cudnn/softmax.jl#L66
By default it's CUDNN_SOFTMAX_ACCURATE, but using CUDA.math_mode!(CUDA.FAST_MATH) to switch to CUDNN_SOFTMAX_FAST doesn't lead to a discernible difference in timing.
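For context, a small sketch of how the algorithm gets switched, assuming (as the linked softmax.jl suggests) that softmaxalgo() keys off CUDA.jl's math mode:

using CUDA, NNlibCUDA
NNlibCUDA.softmaxalgo()            # CUDNN_SOFTMAX_ACCURATE under the default math mode
CUDA.math_mode!(CUDA.FAST_MATH)    # switch CUDA.jl to fast-math mode
NNlibCUDA.softmaxalgo()            # should now return CUDNN_SOFTMAX_FAST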
Here's a benchmark for PyTorch.
| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| aten::softmax | 5.26% | 5.573ms | 24.98% | 26.467ms | 26.467us | 4.671ms | 4.16% | 25.506ms | 25.506us | 1000 |
| aten::_softmax | 16.06% | 17.020ms | 19.72% | 20.895ms | 20.895us | 20.836ms | 18.56% | 20.836ms | 20.836us | 1000 |
| aten::empty_like | 4.22% | 4.467ms | 11.29% | 11.959ms | 3.986us | 0.000us | 0.00% | 0.000us | 0.000us | 3000 |
| aten::empty | 8.82% | 9.345ms | 8.82% | 9.345ms | 2.336us | 0.000us | 0.00% | 0.000us | 0.000us | 4000 |
| aten::sum | 15.01% | 15.900ms | 17.23% | 18.255ms | 18.255us | 27.518ms | 24.51% | 27.518ms | 27.518us | 1000 |
| aten::as_strided | 1.96% | 2.081ms | 1.96% | 2.081ms | 1.040us | 0.000us | 0.00% | 0.000us | 0.000us | 2000 |
| aten::ones_like | 4.94% | 5.230ms | 14.20% | 15.042ms | 15.042us | 7.787ms | 6.94% | 13.961ms | 13.961us | 1000 |
| aten::empty_strided | 2.15% | 2.283ms | 2.15% | 2.283ms | 2.283us | 0.000us | 0.00% | 0.000us | 0.000us | 1000 |
| aten::fill_ | 5.70% | 6.045ms | 5.70% | 6.045ms | 6.045us | 6.173ms | 5.50% | 6.173ms | 6.173us | 1000 |
| SumBackward0 | 4.35% | 4.611ms | 9.20% | 9.745ms | 9.745us | 9.016ms | 8.03% | 9.016ms | 9.016us | 1000 |
| aten::expand | 3.70% | 3.922ms | 4.84% | 5.134ms | 5.134us | 0.000us | 0.00% | 0.000us | 0.000us | 1000 |
| SoftmaxBackward | 5.65% | 5.989ms | 34.40% | 36.453ms | 36.453us | 5.843ms | 5.21% | 36.254ms | 36.254us | 1000 |
| aten::_softmax_backward_data | 12.96% | 13.734ms | 28.75% | 30.464ms | 30.464us | 16.071ms | 14.32% | 30.411ms | 30.411us | 1000 |
| aten::mul | 9.21% | 9.762ms | 11.71% | 12.413ms | 12.413us | 14.340ms | 12.77% | 14.340ms | 14.340us | 1000 |
Self CPU time total: 105.961ms
Self CUDA time total: 112.255ms
Benchmark code:
import torch
import torch.utils.benchmark as benchmark

x = torch.rand(32, 10, 256, requires_grad=True).cuda()

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for _ in range(1000):
        sx = x.softmax(-1).sum()
        torch.autograd.grad([sx], [x])

print(prof.key_averages().table())
As you can see, the times are very similar to the "backup" implementation at https://github.com/FluxML/NNlibCUDA.jl/blob/06ccd9f5b0fa6d3bfc9c9d52dbf865a78d76a576/src/cudnn/softmax.jl#L33. I was curious to see how they managed to make cudnnSoftmaxBackward run faster, and to my surprise they don't seem to use it at all! Looking back at https://github.com/JuliaGPU/CUDA.jl/issues/599, it seems only the forward pass was thoroughly benchmarked, not the backward function. If we can show that the native Julia version beats the CuDNN one for a reasonably large range of sizes, dims, eltypes, and math modes, I'd vote for using it by default and removing the CuDNN path.
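For reference, the linked "backup" implementation is essentially the textbook softmax pullback written with broadcasting; a sketch of the idea (not the exact NNlib code, and the function name here is made up):

# Softmax pullback: dx = y .* (dy .- sum(dy .* y; dims)), computed in place.
function naive_∇softmax!(dx, dy, y; dims = 1)
    dx .= y .* (dy .- sum(dy .* y; dims = dims))
    return dx
end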
It's so weird that CUDNN is so slow. Anyway, yes, if @jumerckx can compare performance on a few more sizes and the results are the same, we can remove the cudnn path.
We don't need to remove the cudnn paths, but we could make the Julia version the default if it is accurate enough for our use cases.
Looks like it really is CUDNN being slow.
> It's so weird that CUDNN is so slow. Anyway, yes, if @jumerckx can compare performance on a few more sizes and the results are the same, we can remove the cudnn path.
I've run these timings across a bunch of sizes and haven't encountered a single config in which CuDNN beat generic Julia (please do check my benchmarking code to make sure I'm not doing anything wrong, as this is the first time I've done benchmarking like this).
One thing that's really striking is the difference between enlarging the first or the second dimension of the input. When the second dimension is large, CuDNN scales particularly badly (which is also what I ran into in the OP):
[Plot of CuDNN vs. generic timings across sizes; note the logarithmic axes.]
I've updated the GitHub gist with my benchmarking code, but here's the (incomplete) essence:
# Case for growing the second dimension (the growing-first-dimension case is analogous):
using CUDA, NNlib, NNlibCUDA, BenchmarkTools

timings = []
for i in 1:12
    sz = (32, 2^i)
    x = CUDA.rand(sz...)
    dy = CUDA.rand(size(x)...)
    y = softmax(x)
    dx = similar(x)
    @show i

    # generic Julia kernel
    time_generic = @belapsed CUDA.@sync NNlibCUDA._∇softmax!($dx, $dy, $x, $y; dims=1)

    # CUDNN kernel, called directly
    s = NNlibCUDA.softmaxdims(x, 1)
    xDesc = NNlibCUDA.cudnnTensorDescriptor(reshape(x, s))
    alpha, beta = NNlibCUDA.scalingParameter(Float32, 1), NNlibCUDA.scalingParameter(Float32, 0)
    time_cudnn = @belapsed CUDA.@sync NNlibCUDA.cudnnSoftmaxBackward($NNlibCUDA.handle(), $NNlibCUDA.softmaxalgo(), $NNlibCUDA.CUDNN_SOFTMAX_MODE_CHANNEL, $alpha, $xDesc, $y, $xDesc, $dy, $beta, $xDesc, $dx)

    push!(timings, (sz, time_generic, time_cudnn))
end
Those results look pretty conclusive, thanks @jumerckx! Does anyone want to do the honours?
FWIW, CUDA.jl uses CUDNN 8.2 while 8.3 has been released, so we should probably re-evaluate at some point.
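(When re-evaluating, the CUDNN build that CUDA.jl actually loaded can be checked directly; a small sketch, assuming the CUDNN submodule of CUDA.jl v3.x:)

using CUDA
CUDA.versioninfo()     # prints the "CUDNN: ..." line as in the reports above
CUDA.CUDNN.version()   # the loaded CUDNN version as a VersionNumber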
Testing today, the generic version is usually slower than the NNlibCUDA one. In the tests below, the sole exception is ∇softmax when size(x) = (4096, 32). These are all dims=1. Perhaps a bit more timing is justified, but on the whole it seems that what's presented above is no longer a good representation of the status.
using NNlib, CUDA, NNlibCUDA, Test, BenchmarkTools

function simple_softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
    max_ = fast_maximum(x; dims)
    # if all(isfinite, max_)
    #     @fastmath out .= exp.(x .- max_)
    # else
        @fastmath @. out = ifelse(isequal(max_, Inf), ifelse(isequal(x, Inf), 1, 0), exp(x - max_))
    # end
    tmp = dims isa Colon ? sum(out) : sum!(max_, out)
    out ./= tmp
end

fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))

x = CUDA.randn(10, 20);
y = similar(x);
@test simple_softmax!(y, x) ≈ softmax(x)

dy = CUDA.randn(10, 20);
@test NNlib.∇softmax(dy, x, y; dims=1) ≈ NNlib.∇softmax_data(dy, y; dims=1)
# @less NNlib.∇softmax_data(dy, y; dims=1)  # check this really is the simple version
### My tests
julia> for n in (10, 10^4, 10^6)  # Forward
           @show 100, 10, n
           x = CUDA.randn(100, 10, n)
           y = similar(x)
           @btime CUDA.@sync softmax!($y, $x)         # uses CUDNN
           @btime CUDA.@sync simple_softmax!($y, $x)  # simplest Julia version
       end
(100, 10, n) = (100, 10, 10)
14.367 μs (11 allocations: 384 bytes)
88.398 μs (167 allocations: 11.80 KiB)
(100, 10, n) = (100, 10, 10000)
121.820 μs (11 allocations: 384 bytes)
2.082 ms (216 allocations: 14.69 KiB)
(100, 10, n) = (100, 10, 1000000)
11.155 ms (67 allocations: 3.70 KiB)
223.190 ms (226 allocations: 14.88 KiB)
julia> for n in (10, 10^4, 10^5)  # Gradient, stopping 10x smaller
           GC.gc(true); CUDA.reclaim();
           @show 100, 10, n
           x = CUDA.randn(100, 10, n)
           y = softmax(x)
           dy = CUDA.randn(100, 10, n)
           @btime CUDA.@sync NNlib.∇softmax($dy, $x, $y; dims=1)   # uses CUDNN
           @btime CUDA.@sync NNlib.∇softmax_data($dy, $y; dims=1)  # just broadcasting
       end;
(100, 10, n) = (100, 10, 10)
17.308 μs (21 allocations: 704 bytes)
53.043 μs (107 allocations: 5.20 KiB)
(100, 10, n) = (100, 10, 10000)
168.519 μs (21 allocations: 704 bytes)
1.321 ms (159 allocations: 8.28 KiB)
(100, 10, n) = (100, 10, 100000)
1.816 ms (76 allocations: 3.98 KiB)
12.140 ms (162 allocations: 8.33 KiB)
### Same sizes as https://github.com/FluxML/NNlib.jl/issues/513
julia> for size in [(256, 10, 32), (32, 10, 256),      # first post
                    (32, 32), (32, 4096), (4096, 32)]  # some points of the graph
           @show size
           x = CUDA.randn(size...)
           y = similar(x)
           @btime CUDA.@sync softmax!($y, $x)         # uses CUDNN
           @btime CUDA.@sync simple_softmax!($y, $x)  # simplest Julia version
       end
size = (256, 10, 32)
14.484 μs (11 allocations: 384 bytes)
89.187 μs (167 allocations: 11.80 KiB)
size = (32, 10, 256)
14.803 μs (11 allocations: 384 bytes)
92.655 μs (169 allocations: 11.83 KiB)
size = (32, 32)
15.233 μs (11 allocations: 384 bytes)
89.899 μs (167 allocations: 10.67 KiB)
size = (32, 4096)
15.240 μs (11 allocations: 384 bytes)
91.431 μs (169 allocations: 10.70 KiB)
size = (4096, 32)
18.150 μs (11 allocations: 384 bytes)
164.365 μs (309 allocations: 17.38 KiB)
julia> for size in [(256, 10, 32), (32, 10, 256),      # Gradient
                    (32, 32), (32, 4096), (4096, 32)]
           @show size
           x = CUDA.randn(size...)
           y = softmax(x)
           dy = CUDA.randn(size...)
           @btime CUDA.@sync NNlib.∇softmax($dy, $x, $y; dims=1)   # old, uses CUDNN
           @btime CUDA.@sync NNlib.∇softmax_data($dy, $y; dims=1)  # new, just broadcasting
       end;
size = (256, 10, 32)
18.666 μs (21 allocations: 704 bytes)
55.459 μs (107 allocations: 5.20 KiB)
size = (32, 10, 256)
19.019 μs (21 allocations: 704 bytes)
55.353 μs (108 allocations: 5.22 KiB)
size = (32, 32)
18.163 μs (21 allocations: 688 bytes)
52.654 μs (107 allocations: 4.92 KiB)
size = (32, 4096)
19.000 μs (21 allocations: 688 bytes)
56.144 μs (108 allocations: 4.94 KiB)
size = (4096, 32)
210.951 μs (21 allocations: 688 bytes)
79.711 μs (167 allocations: 7.64 KiB)
(@v1.10) pkg> st CUDA
Status `~/.julia/environments/v1.10/Project.toml`
[052768ef] CUDA v3.12.1
julia> CUDA.versioninfo()
CUDA toolkit 11.7, artifact installation
NVIDIA driver 510.47.3, for CUDA 11.6
CUDA driver 11.6
Libraries:
- CUBLAS: 11.10.1
- CURAND: 10.2.10
- CUFFT: 10.7.2
- CUSOLVER: 11.3.5
- CUSPARSE: 11.7.3
- CUPTI: 17.0.0
- NVML: 11.0.0+510.47.3
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)
Toolchain:
- Julia: 1.10.0-DEV.220
- LLVM: 14.0.6