Zygote, a FillArray of structs and broadcasting don't work together

Open mohamed82008 opened this issue 4 years ago • 2 comments

Hi!

Here is an MWE that throws an error when differentiating a function that broadcasts over a Fill array of structs.

using FillArrays, Zygote

struct T
    a
end
f(t::T, x) = t.a + x
Zygote.gradient(rand(2)) do x
    ts = Fill(T(1), 2)
    sum(f.(ts, x))
end

The error is:

ERROR: Need an adjoint for constructor Fill{T,1,Tuple{Base.OneTo{Int64}}}. Gradient is of type Array{NamedTuple{(:a,),Tuple{Float64}},1}
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.Jnew{Fill{T,1,Tuple{Base.OneTo{Int64}}},Nothing,false})(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\lib\lib.jl:306
 [3] (::Zygote.var"#380#back#193"{Zygote.Jnew{Fill{T,1,Tuple{Base.OneTo{Int64}}},Nothing,false}})(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\dev\ZygoteRules\src\adjoint.jl:49
 [4] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:57 [inlined]
 [5] (::typeof(∂(Fill{T,1,Tuple{Base.OneTo{Int64}}})))(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
 [6] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:64 [inlined]
 [7] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:69 [inlined]
 [8] Fill at C:\Users\user\.julia\packages\FillArrays\OhEYG\src\FillArrays.jl:76 [inlined]
 [9] (::typeof(∂(Fill)))(::Array{NamedTuple{(:a,),Tuple{Float64}},1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
 [10] #5 at .\REPL[8]:2 [inlined]
 [11] (::typeof(∂(#5)))(::Float64) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface2.jl:0
 [12] (::Zygote.var"#36#37"{typeof(∂(#5))})(::Float64) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface.jl:36
 [13] gradient(::Function, ::Array{Float64,1}) at C:\Users\user\.julia\packages\Zygote\jLxtV\src\compiler\interface.jl:45
 [14] top-level scope at REPL[8]:1

This is on Zygote 0.4.19.

mohamed82008, May 01 '20 18:05

Did you try defining an adjoint for the Fill constructor?
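For reference, such an adjoint might look like the sketch below. It is untested against the MWE and rests on a few assumptions: the error message says the gradient reaches the constructor as an Array of NamedTuples with field `a`, so the pullback folds that array into a single NamedTuple with `Zygote.accum` (an internal Zygote helper, not public API). The method is restricted to `Fill(x::T, n::Integer)`, the exact call in the MWE, to avoid overriding anything more general.

```julia
using FillArrays, Zygote

struct T
    a
end

# Sketch of an adjoint for the call `Fill(T(1), 2)` from the MWE.
# Assumption: the incoming gradient Δ is an array with one NamedTuple
# per element, as reported in the error message.
Zygote.@adjoint function Fill(x::T, n::Integer)
    # Every element of the Fill is the same value, so the per-element
    # gradients are accumulated into one structural gradient for `x`.
    # The length `n` gets no gradient.
    back(Δ) = (reduce(Zygote.accum, Δ), nothing)
    return Fill(x, n), back
end

# Exercise the pullback of the constructor on its own.
y, back = Zygote.pullback(Fill, T(1), 2)
grads = back([(a = 1.0,), (a = 2.0,)])
```

Whether this is enough for the full `Zygote.gradient` call in the MWE may depend on the Zygote and FillArrays versions in use, so rerunning the MWE after defining it is the real check.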

cossio, Jun 07 '20 21:06

Not fixed by https://github.com/JuliaArrays/FillArrays.jl/pull/153, FWIW.

mcabbott, Jul 22 '22 14:07