
Error from gradient of `vcat(x...)` - appeared in v0.6.45

Open danielalcalde opened this issue 2 years ago • 6 comments

In https://github.com/GTorlai/PastaQ.jl/issues/300#issuecomment-1525720876 a bug was detected that I have found to stem from a problem in the differentiation of vcat. I created a minimal example to reproduce the error:

using Zygote
function loss(theta)
    x1 = vcat([theta], 5)
    x2 = vcat(x1...)
    return x2[1]
end
println(gradient(loss, 1))
ERROR: LoadError: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(::Tuple{Float64})
Closest candidates are:
  (::ChainRulesCore.ProjectTo{AbstractArray})(::Union{LinearAlgebra.Adjoint{T, var"#s886"}, LinearAlgebra.Transpose{T, var"#s886"}} where {T, var"#s886"<:(AbstractVector)}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:247
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{<:ChainRulesCore.AbstractZero}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:244
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{S, M}) where {S, M} at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:219
  ...
Stacktrace:
  [1] (::ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}})()
    @ ChainRules ~/.julia/packages/ChainRules/aKxNz/src/rulesets/Base/array.jl:310
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:204 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}}, ChainRules.var"#1412#1418"{Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:237
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:110 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:111 [inlined]
  [7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}})(dy::Tuple{Float64, Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:211
  [8] Pullback
    @ ~/workprojects/education/julia/pastaq/break.jl:3 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}}})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [11] gradient(::Function, ::Int64, ::Vararg{Int64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97

This code used to work in an earlier Zygote release, but it fails starting with v0.6.45 and still fails on the most recent version I tested.

danielalcalde avatar Apr 28 '23 14:04 danielalcalde

0.6.45 is when we switched over to ChainRules for the cat functions: https://github.com/FluxML/Zygote.jl/pull/1277. TBD whether ChainRules projection isn't being flexible enough or if Zygote is passing invalid inputs to it.

ToucheSir avatar Apr 29 '23 15:04 ToucheSir

I think this is https://github.com/FluxML/Zygote.jl/issues/599: x1... makes a Tuple, but the gradient of x1 ought to be an array. It's been worked around in some cases (e.g. with _project) but not all.

mcabbott avatar Apr 30 '23 14:04 mcabbott
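
To make the tuple-versus-array point concrete, here is a minimal sketch in plain Julia (no Zygote involved) of what the splat does in the forward pass of the reproducer above. The variable name `splatted` is introduced only for illustration; it shows why a cotangent for the splatted arguments naturally has tuple structure while `x1` itself is a `Vector`.

```julia
theta = 1
x1 = vcat([theta], 5)     # a Vector{Int}: [1, 5]
splatted = (x1...,)       # what `x1...` hands to the callee: the Tuple (1, 5)
x2 = vcat(splatted...)    # rebuilt from separate scalar arguments

println(typeof(x1))        # a Vector type
println(typeof(splatted))  # a Tuple type
println(x2 == x1)          # the values are unchanged by the round trip
```

Because `vcat(x1...)` rebuilds the same values here, the splat in the reproducer is a no-op in the forward pass. Whether dropping it avoids the error in a given Zygote version is not verified here.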

I am running into this issue while trying to implement DenseNet. Since vcat is one of the only non-mutating ways to append elements to arrays, this is a blocker for that. Is there a workaround or a fix for this? I confirmed that it was working on 0.6.44, but the error appears on versions higher than that.

theabhirath avatar May 24 '23 09:05 theabhirath

Can you simplify the example, or make other ones? ~~Perhaps the splat isn't the right diagnosis, as things like this seem fine:~~

julia> gradient([2, 3.0]) do x
         vcat(x...)[1]
       end
([1.0, 0.0],)

julia> gradient([2, 3.0]) do x
         vcat(x..., 4)[1]
       end
([1.0, 0.0],)

No, I think those are getting fixed up afterwards: gradient applies a final _project to its answer. pullback skips that step, and here the splat clearly makes the tuple:

julia> pullback([2, 3.0]) do x
         vcat(x...)[1]
       end[2](1.0)
((1.0, 0.0),)

The rrule involved cannot fix this: it only sees, and returns gradients for, the individual arguments:

julia> using ChainRules, ChainRulesCore

julia> rrule(vcat, 3/4, 4/5)[2]([6.6, 7.7])
(NoTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))

julia> unthunk.(ans)
(NoTangent(), 6.6, 7.7)

mcabbott avatar May 24 '23 13:05 mcabbott
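
As a rough illustration of the missing step, the sketch below hand-writes a pullback for `vcat(x...)` in plain Julia. The helper `splat_vcat_with_pullback` is hypothetical and is not part of Zygote or ChainRules; it only shows the idea of collecting per-argument cotangents (a tuple) back into an array shaped like the original input, which is what the final projection has to achieve.

```julia
# Hypothetical helper for illustration only; not Zygote's or ChainRules' implementation.
function splat_vcat_with_pullback(x::AbstractVector{<:Real})
    y = vcat(x...)                      # forward pass: x is splatted into scalars
    function pullback(dy::AbstractVector{<:Real})
        per_argument = Tuple(dy)        # one cotangent per splatted scalar argument
        # Collect the per-argument cotangents back into an array shaped like x.
        return reshape(collect(per_argument), size(x))
    end
    return y, pullback
end

y, back = splat_vcat_with_pullback([2.0, 3.0])
dx = back([1.0, 0.0])   # cotangent of y[1] with respect to y
println(dx)
```

With the seed `[1.0, 0.0]`, `dx` is a `Vector` holding the same values as the tuple `(1.0, 0.0)` shown in the pullback output above.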

Differentiating through this is what causes the error for me:

function (m::DenseBlock)(x)
    input = [x]
    for layer in m.layers
        x = layer(input)
        input = vcat(input, [x])
    end
    return cat_channels(input...)
end

This is the only place vcat is used in my code. The layers are mostly simple Chains with Convs and BatchNorms, in case that is useful information. It does seem to suggest that the splat is not the only issue.

theabhirath avatar May 24 '23 16:05 theabhirath

I think we should be implementing DenseNet differently anyhow (toss up a PR if you want some ideas there), so this shouldn't block Metalhead at least.

ToucheSir avatar May 24 '23 16:05 ToucheSir