Zygote.jl
Error from gradient of `vcat(x...)` - appeared in v0.6.45
In https://github.com/GTorlai/PastaQ.jl/issues/300#issuecomment-1525720876 a bug was reported that I traced to the differentiation of `vcat`. Here is a minimal example that reproduces the error:
using Zygote
function loss(theta)
    x1 = vcat([theta], 5)
    x2 = vcat(x1...)
    return x2[1]
end
println(gradient(loss, 1))
ERROR: LoadError: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(::Tuple{Float64})
Closest candidates are:
(::ChainRulesCore.ProjectTo{AbstractArray})(::Union{LinearAlgebra.Adjoint{T, var"#s886"}, LinearAlgebra.Transpose{T, var"#s886"}} where {T, var"#s886"<:(AbstractVector)}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:247
(::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{<:ChainRulesCore.AbstractZero}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:244
(::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{S, M}) where {S, M} at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:219
...
Stacktrace:
[1] (::ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}})()
@ ChainRules ~/.julia/packages/ChainRules/aKxNz/src/rulesets/Base/array.jl:310
[2] unthunk
@ ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:204 [inlined]
[3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}}, ChainRules.var"#1412#1418"{Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:237
[4] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:110 [inlined]
[5] map
@ ./tuple.jl:223 [inlined]
[6] wrap_chainrules_output
@ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:111 [inlined]
[7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}})(dy::Tuple{Float64, Float64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:211
[8] Pullback
@ ~/workprojects/education/julia/pastaq/break.jl:3 [inlined]
[9] (::Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[10] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}}})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
[11] gradient(::Function, ::Int64, ::Vararg{Int64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
This code used to work in Zygote v0.6.44, but it fails from v0.6.45 onward, up to the most recent version I tested.
0.6.45 is when we switched over to ChainRules for the cat functions: https://github.com/FluxML/Zygote.jl/pull/1277. TBD whether ChainRules projection isn't being flexible enough or if Zygote is passing invalid inputs to it.
I think this is https://github.com/FluxML/Zygote.jl/issues/599: `x1...` makes a Tuple, but the gradient of `x1` ought to be an array. It's been worked around in some cases (e.g. with `_project`) but not all.
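To make the shape mismatch concrete, here is a small sketch in plain Julia (no Zygote involved). The helper `count_and_type` is hypothetical and only reports how a callee sees splatted arguments; the cotangent tuple `dargs` is a made-up value standing in for what a per-argument rule would return. Converting it back with `collect` and `reshape` is one way to picture what a projection step has to do, not a description of Zygote's internals.

```julia
# Hypothetical helper: reports how a callee sees splatted arguments.
count_and_type(args...) = (length(args), typeof(args))

x = [2.0, 3.0]

# Splatting hands the callee two separate Float64 values, packed as a Tuple.
n, T = count_and_type(x...)

# A rule that returns one cotangent per argument therefore yields a tuple
# (illustrative values, not computed by an AD system) ...
dargs = (1.0, 0.0)

# ... which has to be converted to match the container that was splatted.
dx = reshape(collect(dargs), size(x))
```

The point is only that the tuple and the original `Vector` carry the same numbers in different containers, so something between the rule and the caller has to do the conversion.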
I am running into this issue while trying to implement DenseNet. Since vcat is one of the only non-mutating ways to append elements to arrays, this is a blocker for that. Is there a workaround or a fix for this? I confirmed that it was working on 0.6.44, but the error appears on versions higher than that.
Can you simplify the example, or make other ones? ~~Perhaps the splat isn't the right diagnosis, as things like this seem fine:~~
julia> gradient([2, 3.0]) do x
           vcat(x...)[1]
       end
([1.0, 0.0],)
julia> gradient([2, 3.0]) do x
           vcat(x..., 4)[1]
       end
([1.0, 0.0],)
No, I think those are getting fixed up afterwards: `gradient` applies a final `_project` to its answer, which `pullback` skips. Here the splat clearly makes the tuple:
julia> pullback([2, 3.0]) do x
           vcat(x...)[1]
       end[2](1.0)
((1.0, 0.0),)
The rrule involved cannot fix this: it sees only the individual arguments and returns one tangent per argument:
julia> using ChainRules, ChainRulesCore
julia> rrule(vcat, 3/4, 4/5)[2]([6.6, 7.7])
(NoTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))
julia> unthunk.(ans)
(NoTangent(), 6.6, 7.7)
Differentiating through this is what causes the error for me:
function (m::DenseBlock)(x)
    input = [x]
    for layer in m.layers
        x = layer(input)
        input = vcat(input, [x])
    end
    return cat_channels(input...)
end
This is the only place vcat is used in my code. The layers are mostly simple Chains with Convs and BatchNorms, in case that is useful information. It does seem to suggest that the splat is not the only issue.
I think we should be implementing DenseNet differently anyhow (toss up a PR if you want some ideas there), so this shouldn't block Metalhead at least.
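For illustration only, one different way to write such a block is to concatenate along the channel dimension at every step instead of growing a `Vector` of arrays and splatting it at the end. This is a sketch under assumptions, not Metalhead's implementation: `DenseBlockSketch`, `cat_channels_sketch`, and `toy_layer` are hypothetical names, arrays are assumed to be in WHCN layout (channels in dimension 3), and each layer is assumed to take the concatenation of all earlier feature maps as a single array. I have not checked this against Zygote.

```julia
# Hypothetical stand-in for `cat_channels`: concatenate along dimension 3,
# assumed here to be the channel dimension of WHCN arrays.
cat_channels_sketch(xs...) = cat(xs...; dims = 3)

# Hypothetical dense block; `layers` can be any collection of callables.
struct DenseBlockSketch{L}
    layers::L
end

function (m::DenseBlockSketch)(x)
    for layer in m.layers
        # Each layer sees every feature map produced so far, and its output
        # is appended as extra channels. No vector of arrays is built.
        x = cat_channels_sketch(x, layer(x))
    end
    return x
end

# Toy "layer" that emits two channels regardless of how many it receives.
toy_layer(x) = sum(x; dims = 3) .* ones(Float32, 1, 1, 2, 1)

block = DenseBlockSketch((toy_layer, toy_layer))
y = block(ones(Float32, 4, 4, 3, 1))
```

With three input channels and two layers adding two channels each, `y` ends up with seven channels. Whether this form differentiates cleanly on the affected Zygote versions would still need to be confirmed.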