Zygote.jl
Zygote, a FillArray of structs and broadcasting don't work together
Hi!
Here is an MWE that throws an error when differentiating a function that broadcasts over a Fill array of structs.
using FillArrays, Zygote

struct T
    a
end

f(t::T, x) = t.a + x

Zygote.gradient(rand(2)) do x
    ts = Fill(T(1), 2)
    sum(f.(ts, x))
end
The error is:
ERROR: Need an adjoint for constructor Fill{T,1,Tuple{Base.OneTo{Int64}}}. Gradient is of type Array{NamedTuple{(:a,),Tuple{Float64}},1}
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] (::Zygote.Jnew{Fill{T,1,Tuple{Base.OneTo{Int64}}},Nothing,false})(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\lib\lib.jl:306
[3] (::Zygote.var"#380#back#193"{Zygote.Jnew{Fill{T,1,Tuple{Base.OneTo{Int64}}},Nothing,false}})(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\dev\ZygoteRules\src\adjoint.jl:49
[4] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:57 [inlined]
[5] (::typeof(∂(Fill{T,1,Tuple{Base.OneTo{Int64}}})))(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
[6] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:64 [inlined]
[7] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:69 [inlined]
[8] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:76 [inlined]
[9] (::typeof(∂(Fill)))(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
[10] #5 at .\REPL[8]:2 [inlined]
[11] (::typeof(∂(#5)))(::Float64) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
[12] (::Zygote.var"#36#37"{typeof(∂(#5))})(::Float64) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface.jl:36
[13] gradient(::Function, ::Array{Float64,1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface.jl:45
[14] top-level scope at REPL[8]:1
This is on Zygote 0.4.19.
Did you try defining an adjoint for the Fill constructor?
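For what it's worth, here is a rough sketch of what such an adjoint could look like. It is an assumption on my part, not something taken from Zygote or FillArrays. It only covers the two-argument form Fill(x, n::Integer), it only handles an array cotangent like the one in the error message, and it relies on Zygote.accum, which is an internal helper rather than public API.

```julia
using FillArrays, Zygote

struct T
    a
end

# Hypothetical adjoint for the two-argument Fill constructor.
# Every entry of the Fill is the same value, so the gradient with respect
# to that value is the accumulation of the per-element gradients.
Zygote.@adjoint function Fill(x, n::Integer)
    # Assumes Δ is an array of per-element gradients (here NamedTuples).
    # Zygote.accum is internal to Zygote; it adds NamedTuples field by field.
    back(Δ) = (reduce(Zygote.accum, Δ), nothing)
    return Fill(x, n), back
end

# Exercise the constructor pullback on its own, feeding it a cotangent of
# the same type that appears in the error message.
y, back = Zygote.pullback(a -> Fill(T(a), 2), 1.0)
Δ = [(a = 1.0,), (a = 1.0,)]
grads = back(Δ)   # one-element tuple holding the gradient with respect to a
```

This only exercises the constructor's pullback in isolation. Whether it is enough to make the broadcast in the MWE differentiable may depend on the Zygote and FillArrays versions involved.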
Not fixed by https://github.com/JuliaArrays/FillArrays.jl/pull/153, FWIW.