Zygote.jl
gradient() fails on array mutation for `mean(f, x; dims)`
When given both an element-wise function `f` and a `dims` keyword, `mean()` apparently mutates an array internally, which Zygote cannot differentiate through:
julia> using Zygote, Statistics
x = randn(3, 3)
Zygote.gradient(Params([x])) do
    sum(mean(abs2, x, dims=1))
end
ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float64}, _...)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#441#442"{Matrix{Float64}})(#unused#::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/array.jl:74
[3] (::Zygote.var"#2330#back#443"{Zygote.var"#441#442"{Matrix{Float64}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ./broadcast.jl:894 [inlined]
[5] Pullback
@ ./broadcast.jl:891 [inlined]
[6] Pullback
@ ./broadcast.jl:887 [inlined]
[7] (::typeof(∂(materialize!)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[8] Pullback
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:181 [inlined]
[9] (::typeof(∂(_mean)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[10] Pullback
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
[11] (::typeof(∂(#mean#1)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[12] Pullback
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
[13] (::typeof(∂(mean##kw)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[14] Pullback
@ ./REPL[14]:4 [inlined]
[15] (::typeof(∂(#17)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[16] (::Zygote.var"#89#90"{Params, typeof(∂(#17)), Zygote.Context})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:356
[17] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
[18] top-level scope
@ REPL[14]:3
Looking through the adjoints for mean() defined in lib/array.jl, my guess is that passing abs2 as f causes Zygote's own implementation to be skipped altogether, and the dims kwarg then sends us down a code path in Statistics that mutates an array. I was going to submit a PR adding a new @adjoint definition for the method that takes f, but I don't know how to get the adjoint of a user-provided function.
> but I don't know how to get the adjoint of a user-provided function
You could use Zygote.pullback to AD through it, which will pick up an adjoint if one exists. A PR to ChainRules would be well received; see https://github.com/JuliaDiff/ChainRules.jl/issues/85
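For illustration, here is a minimal sketch of what calling Zygote.pullback on a user-provided function looks like. The choice of `abs2` and the input `3.0` are just stand-ins; this is not the adjoint for `mean(f, x; dims)` itself, only the building block such an adjoint could call.

```julia
using Zygote

f = abs2  # stand-in for any user-provided function

# pullback returns the forward value and a closure that maps an
# output cotangent to a tuple of input cotangents (one per argument).
y, back = Zygote.pullback(f, 3.0)

# Seed the pullback with 1.0 to get df/dx at x = 3.0.
df = back(1.0)[1]
```

Here `y` is `f(3.0)` and `df` is the derivative of `abs2` at `3.0`, that is `2 * 3.0`.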
The low-tech way to implement this is to turn it into broadcasting, as is currently done for sum(::Function, ::CuArray) here:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L280-L283
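As a user-side workaround in the same spirit, the function can be broadcast by hand before reducing, so that Zygote only sees `mean(::AbstractArray; dims)`. This is a sketch, assuming the extra temporary array from `abs2.(x)` is acceptable. It passes `x` as an explicit argument rather than through `Params`, purely to keep the example short.

```julia
using Zygote, Statistics

x = randn(3, 3)

# Broadcast f over x first, then take the mean along dims=1.
# This avoids the mean(f, x; dims) method that hits the mutation error.
loss(x) = sum(mean(abs2.(x); dims=1))

g = Zygote.gradient(loss, x)[1]

# Analytic gradient for comparison: d/dx of sum(x.^2) / 3 is 2x / 3.
expected = 2 .* x ./ size(x, 1)
```

The computed `g` should agree with `expected` up to floating-point error.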
Reopened, as I had to revert the fix.