ComponentArrays.jl
Construction of ComponentArray inside of AD/Zygote
I want to compute the gradient of a loss function with respect to a ComponentArray. In the loss function, I need to reconstruct a ComponentArray.
Based on @jonniedie's reply (https://github.com/jonniedie/ComponentArrays.jl/issues/126#issuecomment-1141580528), I tried:
using Zygote
using ComponentArrays
using UnPack: @unpack

function my_sum(v)
    ax = getaxes(v)
    @unpack x, y = v
    ca = ComponentArray([x..., y...], ax)
    return sum(ca.x + ca.y)
end

Zygote.gradient(my_sum, ComponentArray(x=[0.0], y=[0.0]))
which fails with
ERROR: ArgumentError: indexed assignment with a single value to possibly many locations is not supported; perhaps use broadcasting `.=` instead?
Stacktrace:
[1] setindex_shape_check(::ChainRulesCore.Tangent{Any, Tuple{Float64}}, ::Int64)
@ Base ./indices.jl:261
[2] _unsafe_setindex!(#unused#::IndexLinear, A::Vector{Float64}, x::ChainRulesCore.Tangent{Any, Tuple{Float64}}, I::UnitRange{Int64})
@ Base ./multidimensional.jl:939
[3] _setindex!
@ ./multidimensional.jl:930 [inlined]
[4] setindex!
@ ./abstractarray.jl:1344 [inlined]
[5] macro expansion
@ ~/.julia/packages/ComponentArrays/EjZNJ/src/array_interface.jl:0 [inlined]
[6] _setindex!(x::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:1, y = 2:2)}}}, v::ChainRulesCore.Tangent{Any, Tuple{Float64}}, idx::Val{:y})
@ ComponentArrays ~/.julia/packages/ComponentArrays/EjZNJ/src/array_interface.jl:129
[7] setproperty!
@ ~/.julia/packages/ComponentArrays/EjZNJ/src/namedtuple_interface.jl:17 [inlined]
[8] (::ComponentArrays.var"#getproperty_adjoint#87"{ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:1, y = 2:2)}}}, Symbol})(Δ::ChainRulesCore.Tangent{Any, Tuple{Float64}})
@ ComponentArrays ~/.julia/packages/ComponentArrays/EjZNJ/src/compat/chainrulescore.jl:4
[9] ZBack
@ ~/.julia/packages/Zygote/PD12J/src/compiler/chainrules.jl:206 [inlined]
[10] Pullback
@ ~/.julia/packages/UnPack/EkESO/src/UnPack.jl:34 [inlined]
[11] (::typeof(∂(unpack)))(Δ::Tuple{Float64})
@ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
[12] macro expansion
@ ~/.julia/packages/UnPack/EkESO/src/UnPack.jl:101 [inlined]
[13] Pullback
pointing to the @unpack call. @avik-pal noted that the error also occurs without @unpack:
function my_sum(v)
    ax = getaxes(v)
    ca = ComponentArray([v.x..., v.y...], ax)
    return sum(ca.x + ca.y)
end

Zygote.gradient(my_sum, ComponentArray(x=[0.0], y=[0.0]))
but it is resolved by using vcat instead of splatting:
function my_sum(v)
    ax = getaxes(v)
    @unpack x, y = v
    ca = ComponentArray(vcat(x, y), ax)
    return sum(ca.x + ca.y)
end

Zygote.gradient(my_sum, ComponentArray(x=[0.0], y=[0.0]))
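The issue seems to be that, with splatting, Δ arrives as a ChainRulesCore.Tangent{Any, Tuple{Float64}} (i.e. a tuple) in the getproperty adjoint at https://github.com/jonniedie/ComponentArrays.jl/blob/cbb24ef7156d18f1576ea48d7ae42023cc5bfa70/src/compat/chainrulescore.jl#L4, and setproperty! cannot assign that tuple to the array slice.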
Constructing a ComponentArray inside a Zygote-differentiated function fails in general, even in this minimal example:
using Zygote
using ComponentArrays
Zygote.gradient(x -> ComponentArray(a = [5])[1], [0.])
gives
ERROR: Mutating arrays is not supported -- called push!(Vector{Any}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Vector{Any})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:86
[3] (::Zygote.var"#397#398"{Vector{Any}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/array.jl:105
[4] (::Zygote.var"#2508#back#399"{Zygote.var"#397#398"{Vector{Any}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[5] Pullback
@ ./namedtuple.jl:309 [inlined]
[6] (::typeof(∂(merge)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[7] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:161 [inlined]
[8] (::typeof(∂(make_idx)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:147 [inlined]
[10] (::typeof(∂(make_carray_args)))(Δ::Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:139 [inlined]
[12] (::typeof(∂(make_carray_args)))(Δ::Tuple{Vector{Float64}, Nothing})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:63 [inlined]
[14] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
[15] (::typeof(∂(#ComponentArray#21)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[16] Pullback
@ ~/.julia/packages/ComponentArrays/YyD7i/src/componentarray.jl:66 [inlined]
[17] (::typeof(∂(Type##kw)))(Δ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1:1,)}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[18] Pullback
@ ./REPL[4]:1 [inlined]
[19] (::typeof(∂(#3)))(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[20] (::Zygote.var"#60#61"{typeof(∂(#3))})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
[21] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
[22] top-level scope
@ REPL[4]:1
It seems that, for ComponentArrays to work with Zygote, construction must avoid mutating arrays entirely, including appending to them (here, a push! reached from make_idx).
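One workaround consistent with the vcat example above is to keep the keyword constructor out of the differentiated code: build the axes once, outside Zygote.gradient, and inside the loss only wrap an already-assembled plain vector with ComponentArray(data, axes). This is a sketch that assumes the ComponentArray(data, axes) path is differentiable in the installed ComponentArrays version (the working vcat example suggests it is); AX and loss are illustrative names, not part of the library.

```julia
using Zygote
using ComponentArrays

# Assumption: computing the axes outside AD avoids the push!-based keyword
# constructor entirely; only ComponentArray(data, axes) runs under Zygote.
const AX = getaxes(ComponentArray(a = [0.0], b = [0.0, 0.0]))

# Hypothetical loss: rebuild the ComponentArray from a flat vector using the
# precomputed axes, then access components by name.
function loss(p::AbstractVector)
    ca = ComponentArray(p, AX)
    return sum(abs2, ca.a) + sum(ca.b)
end

g, = Zygote.gradient(loss, [1.0, 2.0, 3.0])
```

If this works in your setup, g should hold 2a for the a entry and ones for the b entries; collect(g) is used below in case the gradient comes back as a ComponentVector rather than a Vector.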
Is there any chance to resolve the above error by using Zygote.Buffer, as illustrated here?
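For reference, the Zygote.Buffer pattern looks roughly like this on a plain Vector: elements are written into a Buffer, and copy turns it into an ordinary array that Zygote can differentiate through. This only illustrates the mechanism; whether it can replace the push! calls inside the ComponentArrays keyword constructor is exactly the open question. scaled_sum is a hypothetical name.

```julia
using Zygote

# Hypothetical example: fill a Buffer element by element, then reduce it.
function scaled_sum(x::AbstractVector)
    # Zygote.Buffer(x) allocates storage like similar(x) that Zygote
    # permits to be written with setindex! during differentiation.
    buf = Zygote.Buffer(x)
    for i in eachindex(x)
        buf[i] = 2 * x[i]
    end
    # copy(buf) freezes the buffer into a regular array; writing to buf
    # after this point is an error.
    return sum(copy(buf))
end

g, = Zygote.gradient(scaled_sum, [1.0, 2.0, 3.0])
```

Note that Buffer supports setindex! but not push!, so a fix along these lines would likely require preallocating the index storage to its final size rather than appending to it.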