
Depend on ChainRulesCore?

Open · mzgubic opened this issue 4 years ago · 3 comments

I recently ran into this error in the wild. Here is an MWE:

```julia-repl
julia> using Zygote

julia> using StatsBase

julia> gradient(v->sum(AnalyticWeights(v)), rand(3))
ERROR: Need an adjoint for constructor AnalyticWeights{Float64, Float64, Vector{Float64}}. Gradient is of type FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{AnalyticWeights{Float64, Float64, Vector{Float64}}, Vector{Any}, false})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/lib/lib.jl:354
  [3] (::Zygote.var"#1812#back#229"{Zygote.Jnew{AnalyticWeights{Float64, Float64, Vector{Float64}}, Vector{Any}, false}})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:13 [inlined]
  [5] (::typeof(∂(AnalyticWeights{Float64, Float64, Vector{Float64}})))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:13 [inlined]
  [7] (::typeof(∂(AnalyticWeights)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:16 [inlined]
  [9] (::typeof(∂(AnalyticWeights)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[4]:1 [inlined]
 [11] (::typeof(∂(#5)))(Δ::Float64)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#50#51"{typeof(∂(#5))})(Δ::Float64)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:76
 [14] top-level scope
    @ REPL[4]:1
```

I've worked around it with type piracy (defining the rule for a StatsBase type outside StatsBase):

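```julia
using ChainRulesCore, StatsBase

# Type piracy: the rule lives in neither StatsBase nor ChainRulesCore.
function ChainRulesCore.rrule(::Type{StatsBase.AnalyticWeights}, values)
    AnalyticWeights_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ)           # array cotangent
    AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), ȳ.values)          # structural cotangent
    AnalyticWeights_pullback(ȳ::AbstractThunk) = (NoTangent(), unthunk(ȳ))  # lazy cotangent
    return AnalyticWeights(values), AnalyticWeights_pullback
end
```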

which solves the problem.

Would a PR adding this be welcome? It would need to add a dependency on ChainRulesCore, which is quite lightweight (order of 0.1s precompile time IIRC).

mzgubic, Sep 06 '21 15:09

> (order of 0.1s precompile time IIRC)

More important is the load time, about 0.05s. Precompile time is cheap.
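A rough way to look at the load-time cost being discussed is to time the first `using` in a fresh session. This is only a sketch: the figure it prints depends on the machine, the Julia version and the package version, and it is not how the 0.05s number above was obtained.

```julia
# Run in a *fresh* Julia session: only the first `using` of a package pays its
# load time. Precompilation (the one-off cost) should already have happened,
# e.g. via `import Pkg; Pkg.precompile()`, so it is not included here.
@time using ChainRulesCore
```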

I am surprised there is not already a transitive dependency on ChainRulesCore. But indeed there isn't.


If https://github.com/JuliaLang/Statistics.jl/issues/4 is done, then on newer versions of Julia we will be able to define this rule in ChainRules.jl, alongside the other rules we have for the Statistics stdlib.


The rule might actually require a little care, since the structural Tangent might in theory contain `values`, `sum`, or both (and if both, they may or may not be consistent), depending on the path taken to get here: `AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), zero(values) .+ ȳ.values .+ ȳ.sum)`
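The combination in that one-liner follows from the chain rule. Writing $x$ for the input vector and $w$ for the resulting weights, and assuming the stored `sum` is computed from the stored `values` (as the one-argument `AnalyticWeights(values)` constructor does), a field missing from the Tangent contributes zero and

$$
w.\mathrm{values} = x, \qquad w.\mathrm{sum} = \sum_j x_j
\quad\Longrightarrow\quad
\bar{x}_i = \bar{w}_{\mathrm{values},\,i} + \bar{w}_{\mathrm{sum}}\,\frac{\partial}{\partial x_i}\sum_j x_j = \bar{w}_{\mathrm{values},\,i} + \bar{w}_{\mathrm{sum}},
$$

so the cotangent of the cached sum is added to every element, which is what the broadcast in the one-liner computes.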

Also, I think the thunk case should either not unthunk at all, or should call `AnalyticWeights_pullback` again on the unthunked value. (The latter is a case where one needs to be careful about calling locally defined functions.)
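Putting both review points together, a more careful version of the rule could look like the sketch below. It is not the rule StatsBase or ChainRules.jl ships; `aw_values_cotangent` is a hypothetical top-level helper (used so that dispatch does not rely on a locally defined function calling itself), and the code assumes that reading a field absent from a `Tangent` returns `ZeroTangent()`, as ChainRulesCore does for NamedTuple-backed tangents. The incoming cotangent is unthunked first and then dispatched on, so a thunk wrapping a structural Tangent is handled like the Tangent itself.

```julia
using ChainRulesCore, StatsBase

# Hypothetical helper (not part of StatsBase or ChainRulesCore): maps a cotangent
# for an `AnalyticWeights` object back onto the `values` vector it was built from.
aw_values_cotangent(values, ȳ::AbstractZero) = ȳ   # nothing flows back
aw_values_cotangent(values, ȳ::AbstractArray) = ȳ  # natural (array) cotangent
function aw_values_cotangent(values, ȳ::Tangent)
    # A structural tangent may carry a cotangent for `values`, for the cached `sum`,
    # or both. Assumption: absent fields read as ZeroTangent().
    Δvalues = ȳ.values
    Δsum = ȳ.sum
    Δ = zeros(float(eltype(values)), size(values))
    Δvalues isa AbstractZero || (Δ .+= Δvalues)
    # sum == sum(values) and ∂sum/∂values[i] == 1, so the sum's cotangent
    # reaches every element.
    Δsum isa AbstractZero || (Δ .+= Δsum)
    return Δ
end

function ChainRulesCore.rrule(::Type{AnalyticWeights}, values::AbstractVector{<:Real})
    # Unthunk first, then dispatch on the concrete cotangent type.
    AnalyticWeights_pullback(ȳ) = (NoTangent(), aw_values_cotangent(values, unthunk(ȳ)))
    return AnalyticWeights(values), AnalyticWeights_pullback
end

# Usage: call the rule directly, without an AD engine.
y, back = ChainRulesCore.rrule(AnalyticWeights, [1.0, 2.0, 3.0])
back(Tangent{typeof(y)}(; values = [1.0, 0.0, 0.0], sum = 1.0))[2]  # [2.0, 1.0, 1.0]
```

Routing every case through one top-level function keeps the thunk handling in a single place instead of repeating the Tangent logic in a separate thunk method.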

oxinabox, Sep 06 '21 21:09

Yes, I'd rather have ChainRules depend on Statistics once the functions are moved there. Otherwise, the dependency on ChainRulesCore will make that move impossible.

nalimilan, Sep 08 '21 20:09

> I am surprised there is not already a transitive dependency on ChainRulesCore. But indeed there isn't.

There is one in StatsBase >= 0.33.11, through LogExpFunctions.
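To check which packages in the active environment list ChainRulesCore as a dependency (and so whether the transitive dependency is present for a particular set of versions), one option is to scan the manifest with `Pkg.dependencies()`, which Pkg documents as experimental. A sketch; `direct_dependents` is a hypothetical helper name, and what it prints depends entirely on the package versions resolved in the environment:

```julia
using Pkg

# Hypothetical helper: names of all packages in the active manifest whose
# dependency list contains `target`.
function direct_dependents(target::AbstractString)
    deps = Pkg.dependencies()  # Dict{UUID, PackageInfo}
    return sort!(String[info.name for info in values(deps)
                        if haskey(info.dependencies, target)])
end

println(direct_dependents("ChainRulesCore"))
```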

devmotion, Nov 26 '21 19:11