Zygote.jl

`mapreduce` with accumulation inside is broken

Open Red-Portal opened this issue 1 year ago • 10 comments

Hi, the following use case of mapreduce doesn't work:

gradient(randn(10)) do x
    y₀ = Float64[]
    ∑x = 0.0
    ys = mapreduce(vcat, x, 1:length(x); init = y₀) do xᵢ, r
        yᵢ = xᵢ.^2
        ∑x += xᵢ
        [yᵢ]
    end
    sum(ys) + ∑x
end

It seems the ∑x += xᵢ part is at fault here because with or without init it doesn't work:

(vcat, [[0.7008872619503351], [0.057800842475147274], [0.4508806424034738], [6.360461041381114], [8.229642138382558e-5], [0.43781177206525196], [1.7425577575168238], [0.8947064561514089], [0.678655434187004], [0.10421486484899199]])
(init = Float64[],)
ERROR: MethodError: no method matching iterate(::Nothing)

Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:880
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:880
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...

Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::Function, ::Function, ::Vararg{Any})
    @ Zygote ./REPL[7]:5
  [3] macro expansion
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(reduce), ::typeof(vcat), ::Vector{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101
  [5] _pullback
    @ ./reducedim.jl:359 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::Base.var"##mapreduce#801", ::Base.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Vector{Float64}}}}, ::typeof(mapreduce), ::var"#24#26", ::typeof(vcat), ::Vector{Float64}, ::UnitRange{Int64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
  [7] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [8] adjoint
    @ ~/.julia/packages/Zygote/4rucm/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [10] _pullback
    @ ./reducedim.jl:359 [inlined]
 [11] _pullback
    @ ./REPL[8]:4 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [13] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
 [14] pullback
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
 [16] top-level scope
    @ REPL[8]:1

This used to work and got broken at some point. Is this an rrule problem? This currently works without problem on ReverseDiff and ForwardDiff.

Red-Portal, Aug 17 '23 23:08

Related issues in Bijectors.jl and Turing.jl

Red-Portal, Aug 17 '23 23:08

That code path has failed in the past because there are ambiguities in which rrules might apply for a given call. In this case I'm not sure if that is the culprit, however. I believe the problem is that ChainRules does not have an rrule for reduce(vcat, ...; init=...), yet somehow the has_chain_rrule detection logic is reporting that it does.

ToucheSir, Aug 18 '23 17:08

mapreduce(f, vcat, x, 1:length(x); init = y₀) could probably be plumbed to reduce(vcat, foldl(f, x, 1:length(x); init = y₀)). Perhaps that would be one way to work around this.

Note also that reduce(vcat, xs; init) and mapreduce(f, vcat, xs) are always pairwise, they never hit the magic fast path of reduce(vcat, xs).
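As a rough illustration of the two call shapes being contrasted here, the sketch below compares the `init` form from the report with a map-then-`reduce(vcat, xs)` form that passes no keyword arguments. The names `square_block`, `with_init`, and `without_init` are hypothetical and only Base functions are used; this is one possible reading of the suggested workaround, not necessarily the exact rewrite the comment has in mind.

```julia
# Hypothetical helper standing in for the closure in the original report
# (without the captured accumulator).
square_block(xᵢ, i) = [xᵢ^2]

# Shape from the report: with `init`, the call ends up in reduce(vcat, xs; init = ...).
with_init(x) = mapreduce(square_block, vcat, x, 1:length(x); init = Float64[])

# Alternative shape: map first, then reduce(vcat, xs) with no keyword arguments.
without_init(x) = reduce(vcat, map(square_block, x, 1:length(x)))

x = [1.0, 2.0, 3.0]
with_init(x) == without_init(x)  # both give [1.0, 4.0, 9.0]
```

The two forms agree in value for non-empty input; the form without `init` does not handle an empty `x`. Whether the second form differentiates cleanly under `Zygote.gradient` depends on the installed Zygote and ChainRules versions and is not verified here.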

mcabbott, Aug 19 '23 15:08

@torfjelde Is there a reason we compute the first element first and then use that to initialize mapreduce in Stacked?

Red-Portal, Aug 19 '23 21:08

Is there a reason we compute the first element first and then use that to initialize mapreduce in Stacked?

Type-stability issues, in particular when combined with AD. Very often we'd run into instabilities without init, and so I believe this was a way to work around this (type-stability is quite crucial here, in particular with Zygote).

torfjelde, Aug 20 '23 09:08

About type stability, note that any call to a method with kwargs (whether they're provided in the call or not) will be type unstable unless there's an rrule defined for that particular method. In this case there is not.

ToucheSir, Aug 20 '23 14:08

@torfjelde This issue is still persisting; any suggestions on how we should deal with this? Maybe just change the Stacked bijector implementation so that we don't hit this edge-case at all? I'm thinking of computing the Jacobian and the transformation through two separate calls to mapreduce. Probably less efficient, but I don't see any other way unless this gets fixed. Also, we could expect the mapreduce(vcat, ...) fast path to kick in?
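A toy sketch of the two-pass idea, using the example from the top of the issue rather than the actual Stacked bijector code: the transformed values and the accumulated scalar are computed in separate passes, so no closure mutates a captured variable. The function name `two_pass` is hypothetical, and `sum(x)` merely stands in for whatever quantity would be accumulated.

```julia
# Two passes instead of one closure that both builds `ys` and mutates `∑x`.
function two_pass(x)
    # Pass 1: transformed values, concatenated without an `init` keyword.
    ys = reduce(vcat, map(xᵢ -> [xᵢ^2], x))
    # Pass 2: the accumulated scalar, computed on its own.
    ∑x = sum(x)
    return sum(ys) + ∑x
end

two_pass([1.0, 2.0, 3.0])  # (1 + 4 + 9) + (1 + 2 + 3) = 20.0
```

This traverses the input twice, which matches the "probably less efficient" caveat above, but it avoids both the `init` keyword path and the captured accumulator.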

Red-Portal, Jun 04 '24 00:06

Maybe just change the Stacked bijector implementation so that we don't hit this edge-case at all?

Yep, that's what we should do imo.

Is stack applicable here?

Also, we could expect the mapreduce(vcat, ...) fast path to kick in?

Does such a fast-path exist?

torfjelde, Jun 04 '24 19:06

Does such a fast-path exist?

Oh sorry, I meant reduce(vcat). For this, I'm quoting @mcabbott's reply:

note also that reduce(vcat, xs; init) and mapreduce(f, vcat, xs) are always pairwise, they never hit the magic fast path of reduce(vcat, xs).

Red-Portal, Jun 04 '24 20:06

But yeah, I'd recommend that we just work around it by implementing something more specialized :+1:

torfjelde, Jun 04 '24 20:06