Gradient of `reshape(::Array{Bool}, ...)` does not handle thunks
Originally (edit: see the Zygote-only MWE below):
julia> using Flux
julia> let e = Embedding(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
# x = Array(x) # similar error with Array or OneHotArray
Flux.gradient(m -> sum(abs2, m(x)), e)
end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
Closest candidates are:
reshape(::ChainRulesCore.AbstractThunk, ::Any...)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
@ Base reshapedarray.jl:40
reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
@ Base bitarray.jl:479
...
Stacktrace:
[1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
[2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[3] Embedding
@ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:776 [inlined]
[4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[5] #197
@ ./REPL[408]:4 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[7] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
[8] gradient(f::Function, args::Embedding{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
[9] #gradient#1
@ ~/.julia/packages/Flux/3711C/src/gradient.jl:44 [inlined]
[10] gradient(f::Function, args::Embedding{Matrix{Float32}})
@ Flux ~/.julia/packages/Flux/3711C/src/gradient.jl:31
[11] top-level scope
@ REPL[408]:4
Some type information was truncated. Use `show(err)` to see complete types.
(@v1.11) pkg> st Flux Zygote
Status `~/.julia/environments/v1.11/Project.toml`
[587475ba] Flux v0.16.3
[e88e6eb3] Zygote v0.7.5
I presume the problem is Zygote 0.7 and thunks, as it works fine on earlier versions:
julia> let e = Embedding(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
# x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), e)
end
((weight = Float32[6.834647 3.3733022; 5.7237077 0.9229657],),)
julia> let e = Embedding(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), e)
end
((weight = Float32[1.961737 -1.5491782; -0.6510874 11.824801],),)
(jl_ZbRV0D) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌃ [587475ba] Flux v0.14.25
⌅ [e88e6eb3] Zygote v0.6.75
Edit: with Dense, the problem occurs only with OneHotArray, not with Array:
julia> let d = Dense(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), d)
end
((weight = Float32[0.6652966 -3.0755887; 1.8529012 2.833063], bias = Float32[-2.4102921, 4.685964], σ = nothing),)
julia> let d = Dense(2=>2)
x = Flux.onehotbatch([1 2; 2 1], 1:2)
# x = Array(x)
Flux.gradient(m -> sum(abs2, m(x)), d)
end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
The function `reshape` exists, but no method is defined for this combination of argument types.
Closest candidates are:
reshape(::ChainRulesCore.AbstractThunk, ::Any...)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
@ Base reshapedarray.jl:40
reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
@ Base bitarray.jl:479
...
Stacktrace:
[1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
[2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[3] Dense
@ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:204 [inlined]
[4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[5] #215
@ ./REPL[419]:4 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
MWE with only Zygote:
julia> using Zygote
julia> let x = rand(Bool, 12)
w = rand(Float32, 4, 3)
gradient(w -> sum(w * reshape(x, 3, 4)), w)
end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64})
...
Stacktrace:
[1] (::Zygote.var"#617#621"{Vector{Bool}, Tuple{Int64, Int64}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
[2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[3] #43
@ ./REPL[30]:3 [inlined]
[4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[5] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
[6] gradient(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
[7] top-level scope
@ REPL[30]:3
Some type information was truncated. Use `show(err)` to see complete types.
(@v1.11) pkg> st Zygote
Status `~/.julia/environments/v1.11/Project.toml`
[e88e6eb3] Zygote v0.7.5
This also fails on Zygote 0.7.0, but worked before:
julia> let x = rand(Bool, 12)
w = rand(Float32, 4, 3)
gradient(w -> sum(w * reshape(x, 3, 4)), w)
end
(Float32[2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)
(jl_ZbRV0D) pkg> st Zygote
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌅ [e88e6eb3] Zygote v0.6.75
The stacktrace points to this rule:
https://github.com/FluxML/Zygote.jl/blob/1b914d994aea236bcb6d3d0cd6c099d86cede101/src/lib/array.jl#L106-L107
PR #966 introduced `@_adjoint_keepthunks`, after which plain `@adjoint` is not supposed to keep thunks. The Thunk is indeed being converted to `nothing`, but perhaps too late to prevent the backward function from being called at all?
The relevant code from https://github.com/FluxML/ZygoteRules.jl/pull/17 is these lines, which produce:
back(::Nothing) = nothing
back(Δ) = $gradtuple(_back(unthunk_tangent(Δ)))
instead of something like:
back(::Nothing) = nothing
back(Δ::AbstractThunk) = back(unthunk_tangent(Δ))
back(Δ) = $gradtuple(_back(Δ))
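To illustrate the difference between the two patterns, here is a rough, package-free sketch. Everything in it is a toy stand-in and not the real ChainRulesCore or ZygoteRules code: `ToyAbstractThunk`, `ToyThunk`, `toy_unthunk`, `back_current` and `back_proposed` are hypothetical names, and `_back` only mimics a pullback that, like the `reshape` rule, has no method for `nothing`. It assumes the thunk evaluates to `nothing`, which is what the stack traces above suggest.

```julia
# Toy stand-ins (hypothetical), NOT the real ChainRulesCore types.
abstract type ToyAbstractThunk end
struct ToyThunk{F} <: ToyAbstractThunk
    f::F
end
toy_unthunk(t::ToyThunk) = t.f()
toy_unthunk(x) = x

# A pullback that, like the reshape rule, cannot accept `nothing`.
_back(Δ::AbstractArray) = reshape(Δ, 12)

# Pattern 1: unthunk inside the generic method. A thunk that evaluates
# to `nothing` never reaches the `::Nothing` method.
back_current(::Nothing) = nothing
back_current(Δ) = (_back(toy_unthunk(Δ)),)

# Pattern 2: unthunk first, then dispatch again on the result.
back_proposed(::Nothing) = nothing
back_proposed(Δ::ToyAbstractThunk) = back_proposed(toy_unthunk(Δ))
back_proposed(Δ) = (_back(Δ),)

lazy_nothing = ToyThunk(() -> nothing)
lazy_array = ToyThunk(() -> ones(Float32, 3, 4))

# Pattern 1 ends up calling _back(nothing), which has no method.
err = try
    back_current(lazy_nothing)
catch e
    e
end
println(err isa MethodError)
println(back_proposed(lazy_nothing) === nothing)
println(size(back_proposed(lazy_array)[1]))
```

In this toy version both patterns behave the same for a thunk that wraps an array; they differ only when the thunk turns out to hold `nothing`, which is the case that fails here.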