NNlib.jl

Use/export `LogExpFunctions.jl`?

Open · devmotion opened this issue on Dec 22, 2020 · 14 comments

The implementation of logsumexp in StatsFuns is quite optimized (see, e.g., https://github.com/JuliaStats/StatsFuns.jl/pull/97): it works with GPUs, is numerically more stable than the implementation in NNlib, and uses a one-pass algorithm.
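To illustrate the numerical-stability point, here is a minimal sketch of the usual max-shift formulation (this is not the StatsFuns code; the linked PR additionally streams the computation in a single pass):

```julia
# Illustration only, not the StatsFuns implementation: the max-shift trick keeps
# logsumexp finite where the naive formulation overflows.
naive_logsumexp(x) = log(sum(exp, x))

function shifted_logsumexp(x)
    m = maximum(x)
    return m + log(sum(xi -> exp(xi - m), x))
end

x = [1000.0, 1000.0]
naive_logsumexp(x)    # Inf: exp(1000) overflows Float64
shifted_logsumexp(x)  # ≈ 1000.6931, i.e. 1000 + log(2)
```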

I am wondering if NNlib should remove its own implementation and just reexport StatsFuns.logsumexp?

More generally, maybe it would make sense to unify some of the implementations that are duplicated in both packages, e.g. softmax, softmax!, sigmoid, and softplus?

devmotion · Dec 22 '20

I am wondering if NNlib should remove its own implementation and just reexport StatsFuns.logsumexp?

I guess we should. I'm just wary of adding another dependency; there have already been some complaints about latency (see #224). Any hope StatsFuns could ditch its Rmath dependency?

More generally, maybe it would make sense to unify some of the implementations that are duplicated in both packages, e.g. softmax, softmax!, sigmoid, and softplus?

Are all these GPU- and AD-friendly?

CarloLucibello · Dec 26 '20

Any hope StatsFuns could ditch its Rmath dependency?

I don't know the maintainers' exact plans, but I think the intention is to remove the dependency eventually. There are some open issues regarding Rmath (e.g. JuliaStats/Distributions.jl#1509), and there was a discussion about moving the log/exp functions to a separate package (https://github.com/JuliaStats/StatsFuns.jl/issues/46).

Are all these GPU- and AD-friendly?

IIRC not, which is why I used the term "unify" here. I just checked, and it seems softmax (without the dims argument), softplus, logit, and logistic work with CuArray, but for some of them I get warnings such as:

```
┌ Warning: calls to Base intrinsics might be GPU incompatible
│   exception =
│    You called log(x::Float32) in Base.Math at special/log.jl:289, maybe you intended to call log(x::Float32) in CUDA at /home/davwi492/.julia/packages/CUDA/YeS8q/src/device/intrinsics/math.jl:73 instead?
│    Stacktrace:
│     [1] log at special/log.jl:289
│     [2] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:59
└ @ GPUCompiler /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/irgen.jl:68
```

invsoftplus throws an error, though; the problem can be reduced to:

```
julia> map(log ∘ expm1, CUDA.rand(Float32, 5))
┌ Warning: calls to Base intrinsics might be GPU incompatible
│   exception =
│    You called log(x::Float32) in Base.Math at special/log.jl:289, maybe you intended to call log(x::Float32) in CUDA at /home/davwi492/.julia/packages/CUDA/YeS8q/src/device/intrinsics/math.jl:73 instead?
│    Stacktrace:
│     [1] log at special/log.jl:289
│     [2] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:59
└ @ GPUCompiler /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/irgen.jl:68
ERROR: InvalidIRError: compiling kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Float32,1,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}}, Int64) resulted in invalid LLVM IR
Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
Stacktrace:
 [1] expm1 at math.jl:367
 [2] #62 at operators.jl:875
 [3] _broadcast_getindex_evalf at broadcast.jl:648
 [4] _broadcast_getindex at broadcast.jl:621
 [5] getindex at broadcast.jl:575
 [6] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:62
Stacktrace:
 [1] check_ir(::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget,CUDA.CUDACompilerParams}, ::LLVM.Module) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:123
 [2] macro expansion at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:239 [inlined]
 [3] macro expansion at /home/davwi492/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
 [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
 [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
 [6] compile at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
 [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:310
 [8] cufunction_compile(::GPUCompiler.FunctionSpec) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:305
 [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#12",Tuple{CUDA.CuKernelContext,CuDeviceArray{Float32,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
 [10] broadcast_kernel at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:60 [inlined]
 [11] cached_compilation at /home/davwi492/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
 [12] cufunction(::GPUArrays.var"#broadcast_kernel#12", ::Type{Tuple{CUDA.CuKernelContext,CuDeviceArray{Float32,1,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{Base.Broadcast.Extruded{CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:297
 [13] cufunction at /home/davwi492/.julia/packages/CUDA/YeS8q/src/compiler/execution.jl:294 [inlined]
 [14] #launch_heuristic#853 at /home/davwi492/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:19 [inlined]
 [15] launch_heuristic at /home/davwi492/.julia/packages/CUDA/YeS8q/src/gpuarrays.jl:17 [inlined]
 [16] copyto! at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:66 [inlined]
 [17] copyto! at ./broadcast.jl:886 [inlined]
 [18] copy at ./broadcast.jl:862 [inlined]
 [19] materialize(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,Base.var"#62#63"{typeof(log),typeof(expm1)},Tuple{CuArray{Float32,1}}}) at ./broadcast.jl:837
 [20] map(::Function, ::CuArray{Float32,1}) at /home/davwi492/.julia/packages/GPUArrays/jhRU7/src/host/broadcast.jl:89
 [21] top-level scope at REPL[52]:1
```
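A possible workaround (just a sketch, not verified on the versions above) is to broadcast the two scalar functions directly instead of mapping the composed function, which gives CUDA.jl's broadcast machinery a chance to substitute its device intrinsics for log and expm1, as the warning itself suggests:

```julia
# Untested sketch of a workaround: fuse the scalar operations in a broadcast
# instead of `map`ping the composed function `log ∘ expm1`.
using CUDA

x = CUDA.rand(Float32, 5)
y = log.(expm1.(x))  # per the warning, CUDA's own log/expm1 may be picked up here
```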

devmotion · Dec 26 '20

BTW StatsFuns already depends on ChainRulesCore implicitly via SpecialFunctions, so it seems custom ChainRules-based adjoints could be added to StatsFuns without introducing any additional dependencies.
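For illustration, such an adjoint could be as small as the following (hypothetical code, not an actual rule in StatsFuns; the derivative of log1pexp, i.e. softplus, is the logistic function):

```julia
# Hypothetical sketch of a ChainRules-based adjoint, not actual StatsFuns code.
using ChainRulesCore
using StatsFuns: log1pexp, logistic

function ChainRulesCore.rrule(::typeof(log1pexp), x::Real)
    Ω = log1pexp(x)                                         # softplus, computed stably
    log1pexp_pullback(Δ) = (NoTangent(), Δ * logistic(x))   # d/dx log(1 + eˣ) = σ(x)
    return Ω, log1pexp_pullback
end
```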

devmotion · Dec 26 '20

We should use https://github.com/JuliaStats/LogExpFunctions.jl, which doesn't depend on Rmath. Note that StatsFuns just re-exports the functions from LogExpFunctions. See https://github.com/FluxML/NNlib.jl/issues/331.
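For reference, using it directly is straightforward (a quick sketch using the names LogExpFunctions exports):

```julia
using LogExpFunctions

logsumexp([1.0, 2.0, 3.0])       # stable log(sum(exp, x))
logsumexp(rand(3, 4); dims = 1)  # reduction along a chosen dimension
log1pexp(100.0)                  # softplus, does not overflow
logistic(-2.0)                   # sigmoid
```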

cossio · Jul 09 '21

Yes, this issue was one motivation for moving the functions to LogExpFunctions 🙂

devmotion · Jul 09 '21

LogExpFunctions.jl should define the rrules. We could do it here, but the original repo is the natural place.

Also, if we need this, we'll need to define separate implementations for CuArrays in NNlibCUDA.
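Roughly along these lines, presumably (a hedged sketch only; the actual NNlibCUDA code dispatches differently, e.g. to CUDNN kernels where available):

```julia
# Illustrative sketch, not NNlibCUDA's real implementation: keep everything as
# broadcasts/reductions so the computation stays on the GPU.
using NNlib, CUDA

function NNlib.softmax(x::CuArray; dims = 1)
    m = maximum(x; dims = dims)   # max-shift for numerical stability
    e = exp.(x .- m)
    return e ./ sum(e; dims = dims)
end
```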

CarloLucibello · Jul 11 '21

FYI, I recently added the ChainRules definitions to LogExpFunctions.

devmotion · Sep 22 '21

Great, we can move some of the definitions there.

DhairyaLGandhi · Sep 22 '21

Which definitions? ChainRules? LogExpFunctions already contains derivatives for all of the functions it defines.

devmotion · Sep 22 '21

FYI, I recently added the ChainRules definitions to LogExpFunctions.

Not something that we typically pay much attention to (although we should!), but are the rules themselves differentiable?

CarloLucibello · Sep 23 '21

Nobody has tested it, but they should be, since they only involve basic functions or functions from LogExpFunctions for which rules are defined (https://github.com/JuliaStats/LogExpFunctions.jl/blob/master/src/chainrules.jl). For logsumexp and softmax in particular, though, it might be more efficient to use custom second derivatives instead of differentiating through the rules. In general, I don't know if anyone has ever differentiated through rrules and frules (I would assume someone has tried at least?).

devmotion · Nov 06 '21

There are a few rules which have their own rules, as for sum here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L36

Those look likely to work, to me, although perhaps you could find ways to make the second one more efficient.

Why do they have Ωcopy = copy(Ω) though?

mcabbott · Nov 06 '21

Mutation of the primal result of softmax leads to an incorrect pullback. The copy ensures that the pullback is always correct, regardless of downstream computations.
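A small sketch of the hazard (hypothetical code, not the actual ChainRules rule): the pullback closes over the primal output, so without the copy an in-place update of the returned array downstream would silently corrupt the gradient.

```julia
# Hypothetical vector-only sketch of a softmax rrule with the protective copy.
using ChainRulesCore, LinearAlgebra

function softmax_rrule(x::AbstractVector)
    Ω = exp.(x .- maximum(x))
    Ω ./= sum(Ω)
    Ωcopy = copy(Ω)  # guard the pullback against mutation of the returned Ω
    function softmax_pullback(Δ)
        # ∂x = Ω .* (Δ .- ⟨Ω, Δ⟩), evaluated with the protected copy
        return (NoTangent(), Ωcopy .* (Δ .- dot(Ωcopy, Δ)))
    end
    return Ω, softmax_pullback
end

y, back = softmax_rrule([1.0, 2.0, 3.0])
y .= 0                   # downstream mutation of the primal result...
back([1.0, 0.0, 0.0])    # ...the pullback still returns the correct gradient
```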

devmotion · Nov 06 '21

Sure. I guess what I mean is: did this come up somewhere?

Every rule in https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/arraymath.jl (except for + & -) closes over things without preventative copies. So they rely on you not mutating arguments/results elsewhere. Changing that seems like it would roughly double memory usage.

Looking in the docs quickly, I don't actually see mention of such questions. Maybe @oxinabox has thoughts?

mcabbott · Nov 06 '21