
Gradient of `reshape(::Array{Bool}, ...)` does not handle thunks

Open mcabbott opened this issue 8 months ago • 1 comment

Originally (edit: see the Zygote-only MWE below):

julia> using Flux

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)  # similar error with Array or OneHotArray
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})

Closest candidates are:
  reshape(::ChainRulesCore.AbstractThunk, ::Any...)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
  reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
   @ Base reshapedarray.jl:40
  reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
   @ Base bitarray.jl:479
  ...

Stacktrace:
  [1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
  [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [3] Embedding
    @ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:776 [inlined]
  [4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [5] #197
    @ ./REPL[408]:4 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
  [8] gradient(f::Function, args::Embedding{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
  [9] #gradient#1
    @ ~/.julia/packages/Flux/3711C/src/gradient.jl:44 [inlined]
 [10] gradient(f::Function, args::Embedding{Matrix{Float32}})
    @ Flux ~/.julia/packages/Flux/3711C/src/gradient.jl:31
 [11] top-level scope
    @ REPL[408]:4
Some type information was truncated. Use `show(err)` to see complete types.

(@v1.11) pkg> st Flux Zygote
Status `~/.julia/environments/v1.11/Project.toml`
  [587475ba] Flux v0.16.3
  [e88e6eb3] Zygote v0.7.5

I presume the problem is Zygote 0.7 and thunks, as it works fine on earlier versions:

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
((weight = Float32[6.834647 3.3733022; 5.7237077 0.9229657],),)

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
((weight = Float32[1.961737 -1.5491782; -0.6510874 11.824801],),)

(jl_ZbRV0D) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌃ [587475ba] Flux v0.14.25
⌅ [e88e6eb3] Zygote v0.6.75

Edit: with Dense, the problem occurs only with OneHotArray, not with Array:

julia> let d = Dense(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), d)
       end
((weight = Float32[0.6652966 -3.0755887; 1.8529012 2.833063], bias = Float32[-2.4102921, 4.685964], σ = nothing),)

julia> let d = Dense(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), d)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
The function `reshape` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  reshape(::ChainRulesCore.AbstractThunk, ::Any...)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
  reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
   @ Base reshapedarray.jl:40
  reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
   @ Base bitarray.jl:479
  ...

Stacktrace:
  [1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
  [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [3] Dense
    @ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:204 [inlined]
  [4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [5] #215
    @ ./REPL[419]:4 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0

mcabbott, Apr 02 '25 19:04

MWE with only Zygote:

julia> using Zygote

julia> let x = rand(Bool, 12)
           w = rand(Float32, 4, 3)
           gradient(w -> sum(w * reshape(x, 3, 4)), w)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64})
  ...
Stacktrace:
 [1] (::Zygote.var"#617#621"{Vector{Bool}, Tuple{Int64, Int64}})(Δ::Nothing)
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
 [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
   @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
 [3] #43
   @ ./REPL[30]:3 [inlined]
 [4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
 [6] gradient(f::Function, args::Matrix{Float32})
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
 [7] top-level scope
   @ REPL[30]:3
Some type information was truncated. Use `show(err)` to see complete types.

(@v1.11) pkg> st Zygote
Status `~/.julia/environments/v1.11/Project.toml`
  [e88e6eb3] Zygote v0.7.5

Fails on 0.7.0, but worked before:

julia> let x = rand(Bool, 12)
           w = rand(Float32, 4, 3)
           gradient(w -> sum(w * reshape(x, 3, 4)), w)
       end
(Float32[2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)

(jl_ZbRV0D) pkg> st Zygote
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌅ [e88e6eb3] Zygote v0.6.75
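
For reference, the value a fixed version should return can be derived by hand: for `sum(w * X)` the gradient with respect to `w` is `ones(4, 4) * X'`, so every entry in column `j` equals the sum of row `j` of `X`. The sketch below checks this without Zygote, using a fixed Bool vector in place of `rand(Bool, 12)`. The particular values of `x` and `w` and the helper `finite_difference_gradient` are illustrative assumptions, not part of the report.

```julia
# Fixed Bool input so the result is reproducible (the MWE uses rand(Bool, 12)).
x = Bool[1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1]
X = reshape(x, 3, 4)
w = Float32[1 2 3; 4 5 6; 7 8 9; 10 11 12]

loss(w) = sum(w * X)

# Analytic gradient: d/dw[i, j] of sum(w * X) is sum(X[j, :]), i.e. ones(4, 4) * X'.
expected = ones(Float32, 4, 4) * X'

# Hypothetical helper: central finite differences as an independent check.
# The loss is linear in w, so the step size does not affect the result.
function finite_difference_gradient(f, w; h = 0.5f0)
    g = similar(w)
    for i in eachindex(w)
        wp = copy(w); wp[i] += h
        wm = copy(w); wm[i] -= h
        g[i] = (f(wp) - f(wm)) / (2h)
    end
    return g
end

numeric = finite_difference_gradient(loss, w)
```

Each row of this particular `X` sums to 2, so `expected` is a 4×3 matrix of `2.0f0`, the same shape of answer as the Zygote v0.6.75 output above.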

The stacktrace points to this rule:

https://github.com/FluxML/Zygote.jl/blob/1b914d994aea236bcb6d3d0cd6c099d86cede101/src/lib/array.jl#L106-L107

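PR #966 introduced @_adjoint_keepthunks, after which plain @adjoint is supposed not to keep thunks. The Thunk is indeed being converted to nothing, but perhaps too late to prevent the backward function from being called at all? In the stacktrace, the generated `back` receives a `Thunk` and the rule's own closure then receives `nothing`, which suggests the `back(::Nothing)` shortcut is matched against the still-wrapped thunk and therefore skipped.

The relevant lines from https://github.com/FluxML/ZygoteRules.jl/pull/17 generate: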

back(::Nothing) = nothing
back(Δ) = $gradtuple(_back(unthunk_tangent(Δ)))

instead of something like:

back(::Nothing) = nothing
back(Δ::AbstractThunk) = back(unthunk_tangent(Δ))
back(Δ) = $gradtuple(_back(Δ))
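
The difference between the two shapes can be sketched in plain Julia, without Zygote or ChainRulesCore. Everything here is a toy stand-in and an assumption for illustration: `ToyThunk`, `toy_unthunk`, `_back`, `back_current` and `back_proposed` are hypothetical names, not the real macro output, and `_back` only mimics a rule body that has no method for `nothing`.

```julia
# Toy model of the generated pullback wrappers; not Zygote's real API.
abstract type ToyAbstractThunk end

struct ToyThunk{F} <: ToyAbstractThunk
    f::F
end

# Stand-in for unthunk_tangent: force a thunk, pass anything else through.
toy_unthunk(t::ToyThunk) = t.f()
toy_unthunk(x) = x

# Stand-in for the user-written rule body. Like reshape, it has no method for `nothing`.
_back(Δ::AbstractArray) = reshape(Δ, :)

# Current shape: the Nothing shortcut is chosen by dispatch before the thunk is forced.
back_current(::Nothing) = nothing
back_current(Δ) = (_back(toy_unthunk(Δ)),)

# Proposed shape: force the thunk first, then dispatch again on the result.
back_proposed(::Nothing) = nothing
back_proposed(Δ::ToyAbstractThunk) = back_proposed(toy_unthunk(Δ))
back_proposed(Δ) = (_back(Δ),)

lazy_nothing = ToyThunk(() -> nothing)              # a thunk that evaluates to nothing
lazy_array = ToyThunk(() -> [1.0 2.0; 3.0 4.0])     # a thunk that evaluates to an array

# The current shape calls _back(nothing) and throws a MethodError.
current_fails = try
    back_current(lazy_nothing)
    false
catch err
    err isa MethodError
end

proposed_nothing = back_proposed(lazy_nothing)   # hits the Nothing shortcut after unthunking
proposed_array = back_proposed(lazy_array)       # ordinary arrays still reach _back
```

In this toy, both shapes agree on thunks that hold arrays. They differ only when the thunk evaluates to `nothing`, which is the case the stacktrace shows.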

mcabbott, Apr 03 '25 02:04