
Add mask to `softmax`

Open · mcabbott opened this issue 2 years ago • 4 comments

Does this; it should be essentially free:

julia> mask = Bool[1 1 1 0 0; 0 1 1 1 0; 0 0 1 1 1]
3×5 Matrix{Bool}:
 1  1  1  0  0
 0  1  1  1  0
 0  0  1  1  1

julia> softmax(repeat([1,2,3], 1, 5); mask)
3×5 Matrix{Float64}:
 1.0  0.268941  0.0900306  0.0       0.0
 0.0  0.731059  0.244728   0.268941  0.0
 0.0  0.0       0.665241   0.731059  1.0
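
For reference, a rough sketch of the intended semantics (an illustration using a hypothetical helper `masked_softmax`, not necessarily this PR's implementation): masked-out entries are sent to -Inf before the usual softmax, so they get probability zero and the remaining entries renormalise within each slice.

julia> using NNlib

julia> function masked_softmax(x; mask, dims = 1)
           xf = float(x)
           # masked-out entries become -Inf, so they contribute zero weight
           softmax(ifelse.(mask, xf, typemin(eltype(xf))); dims)
       end;

julia> masked_softmax(repeat([1,2,3], 1, 5); mask)   # reproduces the table above

The PR presumably fuses the mask into softmax! rather than allocating a masked copy, which is what would make it essentially free.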

PR Checklist

  • [x] Tests are added
  • [x] Documentation, if applicable

mcabbott · Jan 09 '23 22:01

Does this work for CuArray? IIUC, NNlibCUDA is using the CUDNN softmax.

chengchingwen · Jan 10 '23 00:01

I thought NNlib stopped using the NNlibCUDA one, as it was slower.

mcabbott · Jan 10 '23 02:01

> I thought NNlib stopped using the NNlibCUDA one, as it was slower.

According to Cthulhu, nope:

julia> @descend softmax(cu(randn(3,3)); dims=1) 
(::NNlib.var"#softmax##kw")(::Any, ::typeof(softmax), x::T) where T<:(CuArray) in NNlibCUDA at /home/peter/.julia/packages/NNlibCUDA/gzTJY/src/cudnn/softmax.jl:11
   ∘ ─ %0 = invoke softmax##kw(::NamedTuple{…},::#softmax,::CuArray{…})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
11 1 ─ %1 = Base.getfield(@_2, :dims)::Int64
   2 ─      nothing::Nothing
   3 ─ %3 = Base.getfield(x, :dims)::Tuple{Int64, Int64}
   │   %4 = invoke CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}(CUDA.undef::UndefInitializer, %3::Tuple{Int64, Int64})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
   └── %5 = NNlibCUDA.softmax!::Core.Const(NNlib.softmax!)
   4 ─      nothing::Nothing
   5 ─ %7 = invoke NNlibCUDA.:(var"#softmax!#50")(%1::Int64, %5::typeof(softmax!), %4::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
   └──      goto #6
   6 ─      goto #7
   7 ─      return %7
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.                                        
   %4 = invoke CuArray(::UndefInitializer,::Tuple{Int64, Int64})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
 • %7 = invoke #softmax!#50(::Int64,::#softmax!,::CuArray{…},::CuArray{…})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
   ↩
var"#softmax!#50"(dims, ::typeof(softmax!), y::T, x::T) where T<:(CuArray) in NNlibCUDA at /home/peter/.julia/packages/NNl
ibCUDA/gzTJY/src/cudnn/softmax.jl:70
   ∘ ─ %0 = invoke #softmax!#50(::Int64,::#softmax!,::CuArray{…},::CuArray{…})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
71 1 ─ %1  = invoke NNlibCUDA.softmaxdims(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, dims::Int64)::Union{Nothing, NTup
le{4, Int64}}
72 │   %2  = (%1 === NNlibCUDA.nothing)::Bool                                       │      
   └──       goto FluxML/NNlibCUDA.jl#6 if not %2                                                      │      
   2 ─ %4  = NNlibCUDA._softmax!::Core.Const(NNlibCUDA._softmax!)                   │      
   3 ─       nothing::Nothing                                                       │      
   4 ─ %6  = invoke NNlibCUDA.:(var"#_softmax!#44")(dims::Int64, %4::typeof(NNlibCUDA._softmax!), y::CuArray{Float32, 2, C
UDA.Mem.DeviceBuffer}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
   └──       goto FluxML/NNlibCUDA.jl#5                                                                ││     
   5 ─       return %6                                                              │      
73 6 ─ %9  = π (%1, NTuple{4, Int64})                                               │      
   │   %10 = invoke NNlibCUDA.reshape(y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, %9::NTuple{4, Int64})::CuArray{Float3
2, 4, CUDA.Mem.DeviceBuffer}
   │   %11 = π (%1, NTuple{4, Int64})                                               │      
   │   %12 = invoke NNlibCUDA.reshape(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, %11::NTuple{4, Int64})::CuArray{Float
32, 4, CUDA.Mem.DeviceBuffer}
   │   %13 = invoke CUDA.task_local_state!()::CUDA.TaskLocalState                   ││╻      math_mode
   │   %14 = Base.getfield(%13, :math_mode)::CUDA.MathMode                          │││╻      getproperty
   │   %15 = CUDA.FAST_MATH::Core.Const(CUDA.FAST_MATH)                             ││     
   │   %16 = (%14 === %15)::Bool                                                    ││     
   └──       goto FluxML/NNlibCUDA.jl#8 if not %16                                                     ││     
   7 ─ %18 = NNlibCUDA.CUDNN_SOFTMAX_FAST::Core.Const(CUDA.CUDNN.CUDNN_SOFTMAX_FAST)││     
   └──       goto FluxML/NNlibCUDA.jl#9                                                                ││     
   8 ─ %20 = NNlibCUDA.CUDNN_SOFTMAX_ACCURATE::Core.Const(CUDA.CUDNN.CUDNN_SOFTMAX_ACCURATE)
   └──       goto FluxML/NNlibCUDA.jl#9                                
   9 ┄ %22 = φ (#7 => %18, FluxML/NNlibCUDA.jl#8 => %20)::CUDA.CUDNN.cudnnSoftmaxAlgorithm_t           │      
   │   %23 = %new(NamedTuple{(:y, :mode, :algo), Tuple{CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CUDNN.cudnnSoftmax
Mode_t, CUDA.CUDNN.cudnnSoftmaxAlgorithm_t}}, %10, $(QuoteNode(CUDA.CUDNN.CUDNN_SOFTMAX_MODE_CHANNEL)), %22)::Core.Partial
Struct(NamedTuple{(:y, :mode, :algo), Tuple{CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CUDNN.cudnnSoftmaxMode_t, CUD
A.CUDNN.cudnnSoftmaxAlgorithm_t}}, Any[CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, Core.Const(CUDA.CUDNN.CUDNN_SOFTMAX_MOD
E_CHANNEL), CUDA.CUDNN.cudnnSoftmaxAlgorithm_t])
   │         invoke CUDA.CUDNN.var"#cudnnSoftmaxForwardWithDefaults##kw"()(%23::NamedTuple{(:y, :mode, :algo), Tuple{CuArr
ay{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CUDNN.cudnnSoftmaxMode_t, CUDA.CUDNN.cudnnSoftmaxAlgorithm_t}}, CUDA.CUDNN.cud
nnSoftmaxForwardWithDefaults::typeof(CUDA.CUDNN.cudnnSoftmaxForwardWithDefaults), %12::CuArray{Float32, 4, CUDA.Mem.Device
Buffer})::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}
74 └──       return y                                                               │      
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
   %1 = invoke softmaxdims(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer},::Int64)::Union{Nothing, NTuple{4, Int64}}
   %6 = invoke #_softmax!#44(::Int64,::#_softmax!,::CuArray{…},::CuArray{…})::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
   %10 = invoke reshape(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer},::NTuple{4, Int64})::…
   %12 = invoke reshape(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer},::NTuple{4, Int64})::…
   %13 = invoke task_local_state!()::CUDA.TaskLocalState
•  %24 = invoke cudnnSoftmaxForwardWithDefaults##kw(::NamedTuple{…},::#cudnnSoftmaxForwardWithDefaults,::CuArray{…})::…
   ↩
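
A quicker sanity check along the same lines, without reading the whole Cthulhu dump (assuming NNlibCUDA is loaded; `@which` is used here purely as an illustration):

julia> using InteractiveUtils, CUDA, NNlib

julia> @which softmax(cu(randn(Float32, 3, 3)); dims = 1)

This should point at the same NNlibCUDA method in src/cudnn/softmax.jl that the dump above starts from.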

I guess the part that was slow and got replaced is the pullback of softmax, not the forward pass.
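
For context, the pullback of softmax is, up to broadcasting details, the standard formula below (a sketch of the math with a hypothetical name, not NNlib's exact code):

# vjp for y = softmax(x; dims): given upstream gradient dy and forward output y
∇softmax_sketch(dy, y; dims = 1) = y .* (dy .- sum(dy .* y; dims))

It is just a broadcast plus one reduction, so the comparison with the CUDNN path presumably comes down to kernel launches and fusion.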

chengchingwen · Jan 10 '23 02:01

Oh right, it was ∇softmax, which seemed slow here: https://github.com/FluxML/NNlib.jl/issues/513. Today, both seem faster on CUDNN.
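
(For anyone who wants to re-check on their own GPU, a rough comparison could look like the sketch below. This is illustrative only, not the benchmark from that issue, and it assumes CUDA, NNlib, Zygote and BenchmarkTools are installed.)

julia> using BenchmarkTools, CUDA, NNlib, Zygote

julia> x = CUDA.randn(Float32, 1024, 1024);

julia> @btime CUDA.@sync softmax($x; dims = 1);                          # forward pass

julia> @btime CUDA.@sync gradient(z -> sum(softmax(z; dims = 1)), $x);   # forward + pullback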

mcabbott · Jan 10 '23 03:01