ChainRulesCore.jl

Error in accumulate when I have one argument as a tuple

Open pevnak opened this issue 1 year ago • 2 comments

Hello,

For educational purposes, I have been implementing an RNN by hand and wanted to be fancy and use accumulate instead of recursion or a for loop. But I ran into an error when one of the operands in accumulate is a tuple. I have carved out an MWE, which looks like this:

using Zygote

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)


function f(α, h, x)
	o = accumulate(x, init = h) do h, x
		α * h + x
	end
end

function g(α, h, x)
	o = accumulate(x, init = (h, x[1])) do (h,_),x
		(α * h + x, x)
	end
	first.(o)
end

gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
gradient(α -> sum(sum(f(α, h, x))), 1f0)[1]
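For reference, the recurrence that g computes can be spelled out as an explicit loop. This is only a sketch of the forward pass: g_loop is a hypothetical helper name, and the loop mutates `out` with push!, which Zygote generally does not support, so it illustrates what the tuple-carrying accumulate is meant to produce rather than offering a differentiable replacement.

```julia
# Same definition as g in the MWE above, repeated so the block is self-contained.
function g(α, h, x)
    o = accumulate(x, init = (h, x[1])) do (h, _), x
        (α * h + x, x)
    end
    first.(o)
end

# Hypothetical loop version: carries the same (hidden state, last input) tuple.
function g_loop(α, h, x)
    state = (h, x[1])                      # same init as in g
    out = Vector{typeof(h)}()              # collects the hidden states
    for xi in x
        state = (α * state[1] + xi, xi)    # same update as the do-block in g
        push!(out, state[1])               # keep only the first tuple element
    end
    return out
end

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)

# The forward values of both versions agree.
g_loop(1f0, h, x) == g(1f0, h, x)
```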

Computing the gradient of f succeeds, but computing the gradient of g crashes with:

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})

Closest candidates are:
  construct(::Type{T}, ::T) where T<:Tuple
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:251
  construct(::Type{T}, ::NamedTuple{L}) where {T, L}
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:235

Stacktrace:
  [1] +(a::ChainRulesCore.Tangent{Tuple{…}, Tuple{…}}, d::ChainRulesCore.Tangent{Any, Tuple{…}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_arithmetic.jl:142
  [2] (::ChainRules.var"#1699#1702")(::Tuple{…}, ::Tuple{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:541
  [3] iterate(itr::Base.Iterators.Accumulate)
    @ Base.Iterators ./iterators.jl:589 [inlined]
  [4] collect_to!
    @ ./array.jl:892 [inlined]
  [5] collect_to_with_first!
    @ ./array.jl:870 [inlined]
  [6] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
    @ Base ./array.jl:864 [inlined]
  [7] collect(itr::Base.Generator)
    @ Base ./array.jl:759 [inlined]
  [8] #accumulate#893
    @ ./accumulate.jl:281 [inlined]
  [9] accumulate
    @ ./accumulate.jl:278 [inlined]
 [10] (::ChainRules.var"#decumulate#1701"{…})(dy::Vector{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:540
 [11] ZBack
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [12] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#decumulate#1701"{…}})(dy::Vector{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:237
 [13] g
    @ ./REPL[43]:2 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{FillArrays.Fill{…}, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [15] #53
    @ ./REPL[44]:1 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [18] gradient(f::Function, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [19] top-level scope
    @ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.

Julia and environment

julia> versioninfo()
Julia Version 1.10.0-rc2
Commit dbb9c46795b (2023-12-03 15:25 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 8 × Intel(R) Core(TM) i5-8279U CPU @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores

(tmp) pkg> st
Status `/private/tmp/Project.toml`
  [082447d4] ChainRules v1.63.0
  [d360d2e6] ChainRulesCore v1.21.1
  [26cc04aa] FiniteDifferences v0.12.31
  [587475ba] Flux v0.14.11
  [3bd65402] Optimisers v0.3.2
  [eeda0dda] SafeTensors v1.0.0
  [2913bbd2] StatsBase v0.34.2
  [e88e6eb3] Zygote v0.6.69

Thanks for any help.

pevnak, Feb 14 '24 20:02

Zygote constructs the tangents that enter the decumulate pullback via wrap_chainrules_input. In this case it's hitting the method for Union{Tuple,NamedTuple}, which is interesting, because I think it should be using the method for Tuple.

I think this could be fixed by making sure wrap_chainrules_input returns a StructuralTangent... or at least, if I do this in Zygote:

@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
  xp = map(wrap_chainrules_input, dxs)
  # This produces Tangent{Any} since it does not get to see the primal, `x`.
  # ChainRulesCore.Tangent{Any, typeof(xp)}(xp) -- comment this out and replace by line below
  ChainRulesCore.StructuralTangent{typeof(xp)}(xp)
end

things seem to work out

nmheim, Feb 15 '24 08:02

Same error with https://github.com/JuliaDiff/ChainRules.jl/pull/569, FWIW.

Not certain this is relevant, but notice the similarity to this:

julia> accumulate(=>, (1,2,3))
(1, 1 => 2, (1 => 2) => 3)

julia> accumulate(=>, [1,2,3])
ERROR: MethodError: Cannot `convert` an object of type Int64 to an object of type Pair{Int64, Int64}

and that this gradient works with x::Tuple:

julia> gradient(α -> sum(sum(g(α, h, Tuple(x)))), 1f0)[1]
15.059713f0

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]  # with x::Vector as above
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})

mcabbott, Feb 20 '24 15:02
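Going by the observation above that the gradient works with x::Tuple, one possible stopgap is to convert the inputs to a Tuple before calling g. The sketch below assumes a hypothetical wrapper named g_tuple (not part of any package) and only checks the forward pass with Base; the gradient call is left as a comment because it needs Zygote. Long sequences as tuples may be slow to compile, so this is probably only practical for short inputs.

```julia
# Same definition as g in the MWE, repeated so the block is self-contained.
function g(α, h, x)
    o = accumulate(x, init = (h, x[1])) do (h, _), x
        (α * h + x, x)
    end
    first.(o)
end

# Hypothetical wrapper: pass the inputs as a Tuple, so accumulate returns a Tuple.
g_tuple(α, h, x) = g(α, h, Tuple(x))

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)

# The forward values match the Vector version; only the container type differs.
collect(g_tuple(1f0, h, x)) == g(1f0, h, x)

# With Zygote loaded, the call reported to work in this thread would be:
# gradient(α -> sum(sum(g_tuple(α, h, x))), 1f0)[1]
```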