Zygote.jl
Zygote.jl copied to clipboard
missing rules for `repeat`
Zygote is missing some `repeat` adjoints:
# this is OK
julia> gradient(x -> sum(repeat(x, outer=(2,2,2))), reshape(1:8, 2,2,2))
([8 8; 8 8]
[8 8; 8 8],)
# missing rule
julia> gradient(x -> sum(repeat(x, 2, 2, 2)), reshape(1:8, 2,2,2))
ERROR: Mutating arrays is not supported
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#372#373")(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/KpME9/src/lib/array.jl:58
[3] (::Zygote.var"#2249#back#374"{Zygote.var"#372#373"})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] Pullback
@ ./abstractarraymath.jl:365 [inlined]
[5] (::typeof(∂(repeat_outer)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[6] Pullback
@ ./abstractarraymath.jl:327 [inlined]
[7] Pullback
@ ./abstractarraymath.jl:269 [inlined]
[8] (::typeof(∂(#repeat#1)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[9] Pullback
@ ./abstractarraymath.jl:267 [inlined]
[10] (::typeof(∂(repeat##kw)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[11] Pullback
@ ./abstractarraymath.jl:224 [inlined]
[12] (::typeof(∂(repeat)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[13] (::Zygote.var"#151#152"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing, Nothing}}, typeof(∂(repeat))})(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/KpME9/src/lib/lib.jl:191
[14] #1682#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[15] Pullback
@ ./REPL[13]:1 [inlined]
[16] (::Zygote.var"#41#42"{typeof(∂(#27))})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:40
[17] gradient(f::Function, args::Base.ReshapedArray{Int64, 3, UnitRange{Int64}, Tuple{}})
@ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:49
[18] top-level scope
@ REPL[13]:1
Good catch. @CarloLucibello would you mind moving this over to ChainRules?
The `repeat` adjoints are still in Zygote (https://github.com/FluxML/Zygote.jl/blob/956cbcf3c572c0eb09c146189bb38b1b434634ff/src/lib/array.jl#L130). Not sure why they weren't ported to ChainRules; probably just lack of time.