Zygote and StructArrays do not play nicely

Open ptiede opened this issue 3 years ago • 0 comments

Hi All,

I noticed during development that StructArrays and Zygote do not work together in some cases. It seems that if a function accesses a property of the StructArray (e.g. `x.U`), Zygote/ChainRules does not preserve the StructArray type in the resulting gradient, and this causes an error during gradient accumulation. An MWE is

using Zygote
using StructArrays

f(p) = p.U^2 + p.V^2
l1(x) = sum(f, x) + sum(x.U)
l2(x) = sum(f.(x)  + x.U)
l3(x) = sum(f.(x) .+ x.U)


x = StructArray{NamedTuple{(:U,:V)}}((U=rand(10), V=rand(10)))

Zygote.gradient(l1, x) 
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
  +(::Array, ::Array...) at arraymath.jl:12
  ...
Stacktrace:
  [1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
  [2] collect_similar
    @ ./array.jl:716 [inlined]
  [3] map
    @ ./abstractarray.jl:2933 [inlined]
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:122 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
  [7] ZBack
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
  [8] Pullback
    @ ~/struct_array_issue.jl:5 [inlined]
  [9] (::typeof(∂(l0)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof(∂(l0))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [11] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [12] top-level scope
    @ REPL[50]:1

#####################################################################
Zygote.gradient(l2, x)
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
  +(::Array, ::Array...) at arraymath.jl:12
  ...
Stacktrace:
 [1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
 [2] Pullback
   @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:43 [inlined]
 [3] Pullback
   @ ~/struct_array_issue.jl:6 [inlined]
 [4] (::typeof(∂(l1)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#60#61"{typeof(∂(l1))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [6] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [7] top-level scope
   @ REPL[51]:1


##########################################################################
Zygote.gradient(l3, x)
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
  +(::Array, ::Array...) at arraymath.jl:12
  ...
Stacktrace:
 [1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
 [2] Pullback
   @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:43 [inlined]
 [3] Pullback
   @ ~/struct_array_issue.jl:7 [inlined]
 [4] (::typeof(∂(l2)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#60#61"{typeof(∂(l2))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [6] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [7] top-level scope
   @ REPL[52]:1

On the other hand,

l0(x) = sum(f, x)

Zygote.gradient(l0, x)

seems to work fine and returns a StructArray.

I have been playing with ChainRulesCore and ProjectTo to see if I could get this to work, but I am not sure of the best way to store everything internally.
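In the meantime, one possible workaround is to keep the StructArray out of the differentiated function and pass the component vectors directly. This is only a sketch: `f2` and `l1_plain` are hypothetical names introduced here, and it sidesteps the problem rather than fixing the underlying issue.

```julia
using Zygote

# Hypothetical scalar version of `f` that takes the two fields separately.
f2(u, v) = u^2 + v^2

# Same loss as `l1`, written against plain vectors instead of a StructArray.
l1_plain(U, V) = sum(f2.(U, V)) + sum(U)

U, V = rand(10), rand(10)
gU, gV = Zygote.gradient(l1_plain, U, V)

# Analytically, the gradient with respect to U is 2U .+ 1
# and the gradient with respect to V is 2V.
```

The gradients come back as plain vectors, so they would still need to be wrapped into a StructArray by hand if that layout is needed downstream.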

Working environment

julia> Pkg.status()
Status /tmp/jl_sssQXD/Project.toml
  [09ab397b] StructArrays v0.6.13 https://github.com/JuliaArrays/StructArrays.jl.git#master
  [e88e6eb3] Zygote v0.6.51

julia> versioninfo()
Julia Version 1.8.3
Commit 0434deb161e (2022-11-14 20:14 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD Ryzen 9 7950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
  Threads: 1 on 32 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 1

ptiede • Dec 08 '22 19:12