ChainRules.jl

`rrule` for `cov`

Open • mzgubic opened this issue 3 years ago • 1 comment

We are missing rules for `cov`, which makes Zygote sad: `cov` mutates arrays internally, and Zygote cannot differentiate through array mutation.

julia> using Zygote, Statistics

julia> y = rand(2, 50)
2×50 Matrix{Float64}:
 0.400221  0.162725  0.16816  0.989187  0.893416  0.314517  0.103545  0.53503   0.936069  …  0.0113788  0.921907  0.186928  0.227917  0.402403  0.731638  0.610708  0.0180243  0.7481
 0.590605  0.585147  0.64363  0.714443  0.979011  0.338951  0.776283  0.603568  0.741731     0.715334   0.110166  0.513687  0.251841  0.215048  0.161864  0.11849   0.242418   0.341733

julia> gradient(m -> sum(cov(m)), y)
ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#441#442"{Matrix{Float64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/array.jl:74
  [3] (::Zygote.var"#2347#back#443"{Zygote.var"#441#442"{Matrix{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./broadcast.jl:871 [inlined]
  [5] Pullback
    @ ./broadcast.jl:868 [inlined]
  [6] Pullback
    @ ./broadcast.jl:864 [inlined]
  [7] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Statistics/src/Statistics.jl:542 [inlined]
  [8] (::typeof(∂(#covzm#24)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [9] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Statistics/src/Statistics.jl:538 [inlined]
 [10] (::typeof(∂(covzm##kw)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [11] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Statistics/src/Statistics.jl:561 [inlined]
 [12] (::typeof(∂(#covm#30)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [13] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Statistics/src/Statistics.jl:561 [inlined]
 [14] (::typeof(∂(covm##kw)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [15] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Statistics/src/Statistics.jl:584 [inlined]
 [16] (::typeof(∂(#cov#38)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [17] Pullback
    @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/Statistics/src/Statistics.jl:584 [inlined]
 [18] Pullback
    @ ./REPL[70]:1 [inlined]
 [19] (::Zygote.var"#52#53"{typeof(∂(#43))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
 [20] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:76
 [21] top-level scope
    @ REPL[70]:1
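For reference, here is a sketch of the pullback a `cov` rule would need to compute. This is a standard matrix-calculus derivation and is not taken from this thread. It assumes the defaults `dims=1` (observations in rows) and `corrected=true`, a real matrix X of size n×p, and the inner product ⟨A, B⟩ = tr(AᵀB). The centering matrix P is symmetric and idempotent, which is what lets it be absorbed into X_c in the last step.

```latex
\begin{aligned}
P &= I_n - \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top, \qquad X_c = P X, \qquad
C = \frac{X_c^\top X_c}{n-1} = \frac{X^\top P X}{n-1} \\
dC &= \frac{dX^\top P X + X^\top P \, dX}{n-1} \\
\langle \bar C, dC \rangle &= \operatorname{tr}\!\left(\bar C^\top dC\right)
  = \left\langle \frac{X_c\,(\bar C + \bar C^\top)}{n-1},\; dX \right\rangle \\
\bar X &= \frac{X_c\,(\bar C + \bar C^\top)}{n-1}
\end{aligned}
```

With `corrected=false` the factor n - 1 becomes n; with `dims=2` the same formula applies to the transposed data.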

mzgubic, Jun 20 '22 09:06

Here's a start on implementing this: https://gist.github.com/mcabbott/8e0f1073271176291d16e9d18166a5e0

mcabbott, Sep 16 '22 18:09
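A minimal sketch of such a rule, under stated assumptions. It is not the contents of the linked gist. It uses ChainRulesCore.jl and covers only the one-argument method `cov(X::AbstractMatrix{<:Real})` with the defaults `dims=1` and `corrected=true`. The pullback is X̄ = X_c(C̄ + C̄ᵀ)/(n - 1), where X_c is the column-centred data and n is the number of rows. The name `cov_pullback` is hypothetical.

```julia
using ChainRulesCore, Statistics

# Hypothetical sketch: only cov(X) with the defaults dims=1, corrected=true.
# Keyword arguments, cov(x, y) and vector inputs are NOT handled here.
function ChainRulesCore.rrule(::typeof(cov), X::AbstractMatrix{<:Real})
    n = size(X, 1)                 # number of observations (rows)
    Xc = X .- mean(X; dims=1)      # centred data, built without in-place mutation
    C = Xc' * Xc / (n - 1)         # same value as cov(X), up to rounding
    function cov_pullback(ΔC)
        ΔCm = unthunk(ΔC)          # materialise a possibly lazy cotangent
        # The cotangent is symmetrised because C depends on X through both Xc' and Xc.
        ΔX = Xc * (ΔCm + ΔCm') / (n - 1)
        return NoTangent(), ΔX     # no tangent for the function `cov` itself
    end
    return C, cov_pullback
end
```

If this is defined before the first `gradient` call in a session, the reproduction above should run, since Zygote uses `ChainRulesCore.rrule` methods where they exist. A complete rule would still need to cover `dims`, `corrected`, the two-argument `cov(x, y)`, and vector inputs.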