Depend on ChainRulesCore?
I recently ran into this error in the wild; see this MWE:
```julia
julia> using Zygote
julia> using StatsBase
julia> gradient(v->sum(AnalyticWeights(v)), rand(3))
ERROR: Need an adjoint for constructor AnalyticWeights{Float64, Float64, Vector{Float64}}. Gradient is of type FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{AnalyticWeights{Float64, Float64, Vector{Float64}}, Vector{Any}, false})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/lib/lib.jl:354
  [3] (::Zygote.var"#1812#back#229"{Zygote.Jnew{AnalyticWeights{Float64, Float64, Vector{Float64}}, Vector{Any}, false}})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:13 [inlined]
  [5] (::typeof(∂(AnalyticWeights{Float64, Float64, Vector{Float64}})))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:13 [inlined]
  [7] (::typeof(∂(AnalyticWeights)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:16 [inlined]
  [9] (::typeof(∂(AnalyticWeights)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[4]:1 [inlined]
 [11] (::typeof(∂(#5)))(Δ::Float64)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#50#51"{typeof(∂(#5))})(Δ::Float64)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:76
 [14] top-level scope
    @ REPL[4]:1
```
I worked around it by committing type piracy with this rule:
```julia
using ChainRulesCore, StatsBase

function ChainRulesCore.rrule(::Type{StatsBase.AnalyticWeights}, values)
    AnalyticWeights_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ)          # natural tangent, e.g. a Fill
    AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), ȳ.values)         # structural tangent
    AnalyticWeights_pullback(ȳ::AbstractThunk) = (NoTangent(), unthunk(ȳ)) # lazy (thunked) tangent
    return AnalyticWeights(values), AnalyticWeights_pullback
end
```
This solves the problem.
Would a PR adding this be welcome? It would add a dependency on ChainRulesCore, which is quite lightweight (order of 0.1s precompile time IIRC).
Regarding "order of 0.1s precompile time IIRC": more importantly, it adds about 0.05s of load time.
Precompile time is cheap.
I am surprised there is not already a transitive dependency on ChainRulesCore. But indeed there isn't.
If https://github.com/JuliaLang/Statistics.jl/issues/4 is done, then on newer versions of Julia we will be able to define this rule in ChainRules.jl, alongside the other rules we have for the Statistics stdlib.
The rule might actually require a little care, since the structural Tangent could in theory contain a `values` component, a `sum` component, or both (and if both, they may or may not be consistent), depending on the path the gradient took to get here. So the `Tangent` method would need to be something like:
```julia
AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), zero(values) .+ ȳ.values .+ ȳ.sum)
```
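The `ȳ.sum` term is added to every entry because the one-argument constructor computes the `sum` field from the values. Writing $w = \mathrm{AnalyticWeights}(v)$, so that $w_{\text{values}} = v$ and $w_{\text{sum}} = \sum_j v_j$, the chain rule gives, for each entry $i$ (with an absent component contributing zero):

```math
\bar v_i
  = \bar y_{\text{values},i}\,\frac{\partial v_i}{\partial v_i}
  + \bar y_{\text{sum}}\,\frac{\partial}{\partial v_i}\sum_j v_j
  = \bar y_{\text{values},i} + \bar y_{\text{sum}}
```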
Also, I think the thunk case should either not unthunk at all, or should call `AnalyticWeights_pullback` on the unthunked value.
(The latter is a case where one needs to be careful about calling locally defined functions.)
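Putting the two review points together, one possible shape for the rule is sketched below. It is a hypothetical sketch, not code from StatsBase or ChainRules.jl. It assumes StatsBase's field names `values` and `sum`, and ChainRulesCore's behaviour of returning `ZeroTangent()` when a field is missing from a `Tangent`. Instead of a multi-method local function that calls itself after unthunking, it unthunks once and then branches, which sidesteps calling a locally defined function recursively.

```julia
using ChainRulesCore, StatsBase

# Hypothetical sketch, not the rule shipped by StatsBase or ChainRules.jl.
# Assumes AnalyticWeights has fields `values` and `sum`, and that the
# one-argument constructor computes `sum` from `values`.
function ChainRulesCore.rrule(::Type{AnalyticWeights}, values::AbstractVector)
    function AnalyticWeights_pullback(ȳ)
        Δ = unthunk(ȳ)  # unthunk once, then branch; no recursive call needed
        if Δ isa ChainRulesCore.AbstractZero
            return (NoTangent(), Δ)
        elseif Δ isa Tangent
            # Structural tangent: may hold `values`, `sum`, or both.
            # A field that is absent reads back as ZeroTangent().
            dvalues = zero(values)
            dv = unthunk(Δ.values)
            if !(dv isa ChainRulesCore.AbstractZero)
                dvalues = dvalues .+ dv
            end
            ds = unthunk(Δ.sum)
            if !(ds isa ChainRulesCore.AbstractZero)
                # `sum` depends on every entry of `values` with slope 1
                dvalues = dvalues .+ ds
            end
            return (NoTangent(), dvalues)
        else
            # Natural tangent: an AbstractArray, e.g. the FillArrays.Fill from `sum`
            return (NoTangent(), Δ)
        end
    end
    return AnalyticWeights(values), AnalyticWeights_pullback
end
```

Calling the rule directly, as `rrule(AnalyticWeights, v)`, returns the weights together with the pullback, so each tangent shape (array, structural `Tangent`, thunk) can be exercised without going through Zygote.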
Yes, I'd rather have ChainRules depend on Statistics once the functions are moved there. Otherwise, a dependency on ChainRulesCore would make that move impossible.
Regarding "I am surprised there is not already a transitive dependency on ChainRulesCore": there is one in StatsBase >= 0.33.11, through LogExpFunctions.
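To check this in a particular environment, one option is to walk the active manifest with `Pkg.dependencies()` (available since Julia 1.4) and list the packages that depend directly on ChainRulesCore. This is a sketch; the helper name `direct_dependents` is made up for illustration, and the result depends on the package versions resolved in that environment.

```julia
using Pkg

# Hypothetical helper: names of packages in the active environment's manifest
# that list `target` as a direct dependency (empty if none do).
function direct_dependents(target::AbstractString)
    deps = Pkg.dependencies()  # Dict{UUID, PackageInfo} covering the whole manifest
    dependents = String[info.name for info in values(deps)
                        if haskey(info.dependencies, target)]
    return sort(dependents)
end

println("Packages depending on ChainRulesCore: ", direct_dependents("ChainRulesCore"))
```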