Zygote.jl
`mapreduce` with accumulation inside is broken
Hi, the following use case of `mapreduce` doesn't work:
gradient(randn(10)) do x
    y₀ = Float64[]
    ∑x = 0.0
    ys = mapreduce(vcat, x, 1:length(x); init = y₀) do xᵢ, r
        yᵢ = xᵢ.^2
        ∑x += xᵢ
        [yᵢ]
    end
    sum(ys) + ∑x
end
It seems the `∑x += xᵢ` part is at fault here, because it fails with or without `init`:
(vcat, [[0.7008872619503351], [0.057800842475147274], [0.4508806424034738], [6.360461041381114], [8.229642138382558e-5], [0.43781177206525196], [1.7425577575168238], [0.8947064561514089], [0.678655434187004], [0.10421486484899199]])
(init = Float64[],)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen})
@ Base range.jl:880
iterate(::Union{LinRange, StepRangeLen}, ::Integer)
@ Base range.jl:880
iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
@ Base dict.jl:698
...
Stacktrace:
[1] indexed_iterate(I::Nothing, i::Int64)
@ Base ./tuple.jl:91
[2] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::Function, ::Function, ::Vararg{Any})
@ Zygote ./REPL[7]:5
[3] macro expansion
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
[4] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(reduce), ::typeof(vcat), ::Vector{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101
[5] _pullback
@ ./reducedim.jl:359 [inlined]
[6] _pullback(::Zygote.Context{false}, ::Base.var"##mapreduce#801", ::Base.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Vector{Float64}}}}, ::typeof(mapreduce), ::var"#24#26", ::typeof(vcat), ::Vector{Float64}, ::UnitRange{Int64})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[7] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[8] adjoint
@ ~/.julia/packages/Zygote/4rucm/src/lib/lib.jl:203 [inlined]
[9] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[10] _pullback
@ ./reducedim.jl:359 [inlined]
[11] _pullback
@ ./REPL[8]:4 [inlined]
[12] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[13] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
[14] pullback
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
[15] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
[16] top-level scope
@ REPL[8]:1
This used to work and broke at some point. Is this an `rrule` problem? It currently works without problems on ReverseDiff and ForwardDiff.
Related issues in Bijectors.jl and Turing.jl
That code path has failed in the past because there are ambiguities in which rrules might apply for a given call. In this case I'm not sure that is the culprit, however. I believe the problem is that ChainRules does not have an `rrule` for `reduce(vcat, ...; init=...)`, yet somehow the `has_chain_rrule` detection logic is reporting that it does.
`mapreduce(f, vcat, x, 1:length(x); init = y₀)` could probably be plumbed to `reduce(vcat, foldl(f, x, 1:length(x); init = y₀))`. Perhaps that would be one way to work around this.
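One possible shape for such a workaround, as a sketch only: move the accumulation out of the closure and concatenate with `reduce(vcat, xs)` without `init`. The helper names `square_entry`, `original_form`, and `restructured_form` are hypothetical, and the snippet only checks that the forward values agree; whether a given Zygote version differentiates the restructured form is not verified here.

```julia
# Forward pass only; no AD package is loaded or tested here.
# All function names below are hypothetical, chosen for illustration.

# Same per-element work as the closure in the report, minus the mutation of ∑x.
square_entry(xᵢ, i) = [xᵢ^2]

# The form from the report: the closure mutates the captured variable ∑x.
function original_form(x)
    ∑x = 0.0
    ys = mapreduce(vcat, x, 1:length(x); init = Float64[]) do xᵢ, r
        ∑x += xᵢ
        [xᵢ^2]
    end
    return sum(ys) + ∑x
end

# Restructured: no `init` keyword and no captured mutable state.
# Note: an empty `x` is not handled in this sketch.
function restructured_form(x)
    ys = reduce(vcat, map(square_entry, x, 1:length(x)))
    return sum(ys) + sum(x)
end

x = [1.0, 2.0, 3.0]
@assert original_form(x) ≈ restructured_form(x)
```

The design choice is simply to keep the closure free of side effects, so the sum is computed as `sum(x)` after the fact instead of being accumulated during the reduction.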
Note also that `reduce(vcat, xs; init)` and `mapreduce(f, vcat, xs)` are always pairwise; they never hit the magic fast path of `reduce(vcat, xs)`.
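To make the three call forms concrete, here is a small plain-Julia comparison. It only shows that they return the same values; it says nothing about speed or about which AD rule applies. The variable names are illustrative, and the comments describe dispatch in Base as I understand it.

```julia
xs = [[1.0, 2.0], [3.0], [4.0, 5.0]]

# Positional call with no keywords: Base has a specialised method for
# reduce(vcat, ::AbstractVector of arrays) that concatenates in one go.
fast = reduce(vcat, xs)

# Passing `init` goes through the generic keyword-accepting method instead.
with_init = reduce(vcat, xs; init = Float64[])

# mapreduce with vcat is also the generic reduction, even with `identity`.
mapped = mapreduce(identity, vcat, xs)

@assert fast == with_init == mapped
```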
@torfjelde Is there a reason we compute the first element first and then use that to initialize `mapreduce` in `Stacked`?
> Is there a reason we compute the first element first and then use that to initialize `mapreduce` in `Stacked`?
Type-stability issues, in particular when combined with AD. Very often we'd run into instabilities without `init`, and so I believe this was a way to work around that (type stability is quite crucial here, in particular with Zygote).
About type stability, note that any call to a method with kwargs (whether they're provided in the call or not) will be type unstable unless there's an `rrule` defined for that particular method. In this case there is not.
@torfjelde This issue still persists; any suggestions on how we should deal with it? Maybe just change the `Stacked` bijector implementation so that we don't hit this edge-case at all? I'm thinking of computing the Jacobian and the transformation through two separate calls to `mapreduce`. Probably less efficient, but I don't see any other way unless this gets fixed. Also, we could expect the `mapreduce(vcat, ...)` fast path to kick in?
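A rough sketch of the two-pass idea, under loud assumptions: the per-block transforms `exp_block` and `scale_block`, and the `blocks`/`ranges` layout, are hypothetical stand-ins and not the actual `Stacked` implementation. It also uses `map` followed by `reduce(vcat, ...)` and `sum` in place of `mapreduce`, and is checked on forward values only.

```julia
# Hypothetical per-block transforms: each returns
# (transformed block, log-abs-det-Jacobian contribution).
exp_block(x) = (exp.(x), sum(x))
scale_block(x) = (2 .* x, length(x) * log(2.0))

blocks = [exp_block, scale_block]   # illustrative stand-in for the stacked bijectors
ranges = [1:2, 3:3]                 # which entries of x each block acts on
x = [0.0, 1.0, 2.0]

# Pass 1: transformed values only, concatenated without `init`.
y = reduce(vcat, map((b, r) -> first(b(x[r])), blocks, ranges))

# Pass 2: log-Jacobian terms only, summed.
logjac = sum(map((b, r) -> last(b(x[r])), blocks, ranges))
```

Each block is evaluated twice, which is the efficiency cost mentioned above; in exchange, neither pass mutates captured state or passes `init`.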
> Maybe just change the `Stacked` bijector implementation so that we don't hit this edge-case at all?
Yep, that's what we should do imo.
Is `stack` applicable here?
> Also, we could expect the `mapreduce(vcat, ...)` fast path to kick in?
Does such a fast-path exist?
> Does such a fast-path exist?
Oh sorry, I meant `reduce(vcat)`. For this, I'm quoting @mcabbott's reply:
> note also that `reduce(vcat, xs; init)` and `mapreduce(f, vcat, xs)` are always pairwise, they never hit the magic fast path of `reduce(vcat, xs)`.
But yeah, I'd recommend that we just work around it by implementing something more specialized :+1: