ChainRules.jl

Add rules and tests for `kron`

Open simsurace opened this issue 2 years ago • 7 comments

In Julia 1.9, an internal change to `kron` introduced mutation, which made Zygote unable to differentiate `kron`. Here, we add rules for `kron` to restore that ability.
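For reference, here is a minimal sketch of what a matrix ⊗ matrix pullback has to compute. It is not the code in this PR: it ignores thunks and `ProjectTo`, and the helper name `kron_pullback_sketch` is hypothetical. With `x` of size `m×n` and `y` of size `p×q`, column-major layout gives `kron(x, y)[(i-1)*p + k, (j-1)*q + l] == x[i, j] * y[k, l]`, so reshaping the cotangent `dz` to `p×m×q×n` turns both cotangents into slice-wise `dot` products. The sketch uses `view` in a comprehension instead of multi-dimensional `eachslice`, which, as far as I know, only accepts a tuple of `dims` from Julia 1.9 on.

```julia
using LinearAlgebra

# Hypothetical helper: cotangents of z = kron(x, y) for matrices x (m×n) and y (p×q).
# Column-major layout means z[(i-1)*p + k, (j-1)*q + l] == x[i, j] * y[k, l],
# so reshaping dz to (p, m, q, n) exposes the indices as (k, i, l, j).
function kron_pullback_sketch(x::AbstractMatrix, y::AbstractMatrix, dz::AbstractMatrix)
    m, n = size(x)
    p, q = size(y)
    dz4 = reshape(dz, p, m, q, n)
    # dx[i, j] = sum over (k, l) of conj(y[k, l]) * dz4[k, i, l, j]
    dx = [dot(y, view(dz4, :, i, :, j)) for i in 1:m, j in 1:n]
    # dy[k, l] = sum over (i, j) of conj(x[i, j]) * dz4[k, i, l, j]
    dy = [dot(x, view(dz4, k, :, l, :)) for k in 1:p, l in 1:q]
    return dx, dy
end
```

`dot` conjugates its first argument, which matches the ChainRules convention that the cotangent of `a` in `c = a * b` is `dc * conj(b)`.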

Discovered in https://github.com/JuliaGaussianProcesses/TemporalGPs.jl/pull/115

simsurace, Sep 25 '23 13:09

> made Zygote unable to differentiate kron

Did the workaround in https://github.com/FluxML/Zygote.jl/pull/1378 not fix it? As mentioned in #684, should ideally be fixed in ChainRules nevertheless, but I'm a bit curious.

devmotion, Sep 25 '23 13:09

> > made Zygote unable to differentiate kron
>
> Did the workaround in FluxML/Zygote.jl#1378 not fix it? As mentioned in #684, should ideally be fixed in ChainRules nevertheless, but I'm a bit curious.

Thanks for commenting. I think @willtebbutt said he would have a look at these rules later on. I don't know whether it is related, but the above-mentioned fix predates Julia 1.9 by several months, and I observed the breakage when upgrading Julia from 1.8 to 1.9.

simsurace, Sep 25 '23 14:09

Hi all, I rewrote the rules, and now all the tests pass. There is probably room for optimization, so please let me know.
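On optimization, one observation (a hypothetical sketch, not the PR's code): in the vector ⊗ vector case the contractions collapse to two matrix-vector products, which can dispatch to BLAS. For vectors, `kron(x, y)[(i-1)*p + k] == x[i] * y[k]` with `p = length(y)`, so reshaping `dz` to a `p×m` matrix `D` gives `D[k, i] == dz[(i-1)*p + k]`. The function name is illustrative only.

```julia
using LinearAlgebra

# Hypothetical sketch for vectors x (length m) and y (length p):
#     dx[i] = sum_k conj(y[k]) * D[k, i]   =>  dx = transpose(D) * conj(y)
#     dy[k] = sum_i conj(x[i]) * D[k, i]   =>  dy = D * conj(x)
# where D = reshape(dz, p, m).
function kron_vec_pullback_sketch(x::AbstractVector, y::AbstractVector, dz::AbstractVector)
    D = reshape(dz, length(y), length(x))
    dx = transpose(D) * conj(y)
    dy = D * conj(x)
    return dx, dy
end
```

The `conj` calls keep the sketch correct for complex inputs under the ChainRules convention; for real inputs they are no-ops.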

simsurace, Sep 26 '23 19:09

OK, I did not test on Julia 1.6, which apparently requires special care.

simsurace, Sep 27 '23 08:09

Why don't we see the full stack traces here? Is it due to using JuliaInterpreter?

simsurace, Sep 27 '23 10:09

OK, I made the suggested changes and added tests that check the projections behave correctly. However, there is a type-inference problem in the matrix-matrix case.
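To illustrate what the projection amounts to for structured inputs: when `x` is a square `Diagonal`, I would expect `ProjectTo(x)` to keep only the diagonal of the dense cotangent, so only `m` of the `m^2` contractions actually matter. Below is a hypothetical sketch of computing just those entries directly; it does not use ChainRulesCore, and the function name is an assumption for illustration.

```julia
using LinearAlgebra

# Hypothetical sketch: cotangent of a square Diagonal x in z = kron(x, y).
# Projecting the dense cotangent onto Diagonal keeps only dx[i, i], so we
# compute those m contractions directly instead of all m^2 entries.
function kron_dx_diagonal_sketch(x::Diagonal, y::AbstractMatrix, dz::AbstractMatrix)
    m = size(x, 1)
    p, q = size(y)
    dz4 = reshape(dz, p, m, q, m)   # indices (k, i, l, j)
    return Diagonal([dot(y, view(dz4, :, i, :, i)) for i in 1:m])
end
```

Whether skipping the off-diagonal work pays off in practice depends on sizes and would need benchmarking.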

simsurace, Sep 28 '23 10:09

The problem is this:

```julia
julia> x = Diagonal(rand(2)); y = Diagonal(rand(2)); z, pb = rrule(kron, x, y);

julia> @code_warntype unthunk(pb(z)[2])
MethodInstance for ChainRulesCore.unthunk(::Thunk{ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}})
  from unthunk(x::Thunk) @ ChainRulesCore ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204
Arguments
  #self#::Core.Const(ChainRulesCore.unthunk)
  x::Thunk{ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}}
Body::Any
1 ─      nothing
│   %2 = Base.getproperty(x, :f)::ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}
│   %3 = (%2)()::Any
└──      return %3
```
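To turn a `@code_warntype` observation like this into a regression test, the `Test` standard library's `@inferred` can be wrapped around the call, e.g. around `unthunk(pb(z)[2])`. The toy functions below are hypothetical and only show the mechanism: `@inferred` returns the call's value when the inferred return type matches the actual one and throws an `ErrorException` otherwise. ChainRulesTestUtils' `test_rrule` also has, as far as I know, a `check_inferred` keyword that performs a similar check on rules.

```julia
using Test

# A function whose return type is concretely inferable from the argument type...
stable(x::Int) = x + 1
# ...and one whose return type depends on a runtime value (Union{Int64, Float64}).
unstable(x::Int) = x > 0 ? 1 : 1.0

# @inferred passes through the value when inference is concrete and matches,
# and throws an ErrorException when the actual type differs from the inferred one.
@test @inferred(stable(1)) == 2
@test_throws ErrorException @inferred(unstable(1))
```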

Any ideas on how to make the unthunking type-stable here?

EDIT: The core of the problem is that `dot(y, first(eachslice(dz; dims = (2, 4))))` is type-unstable:

```julia
julia> @code_warntype dot(y, first(eachslice(dz; dims = (2, 4))))
MethodInstance for LinearAlgebra.dot(::Diagonal{Float64, Vector{Float64}}, ::SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false})
  from dot(D::Diagonal, B::AbstractMatrix) @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/diagonal.jl:806
Arguments
  #self#::Core.Const(LinearAlgebra.dot)
  D::Diagonal{Float64, Vector{Float64}}
  B::SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}
Body::Any
1 ─ %1  = LinearAlgebra.size(D)::Tuple{Int64, Int64}
│   %2  = LinearAlgebra.size(B)::Tuple{Int64, Int64}
│   %3  = (%1 == %2)::Bool
└──       goto #3 if not %3
2 ─       goto #4
3 ─ %6  = LinearAlgebra.size(D)::Tuple{Int64, Int64}
│   %7  = LinearAlgebra.size(B)::Tuple{Int64, Int64}
│   %8  = Base.string("Matrix sizes ", %6, " and ", %7, " differ")::String
│   %9  = LinearAlgebra.DimensionMismatch(%8)::Any
└──       LinearAlgebra.throw(%9)
4 ┄ %11 = Base.getproperty(D, :diag)::Vector{Float64}
│   %12 = LinearAlgebra.diagind(B)::Core.PartialStruct(StepRange{Int64, Int64}, Any[Core.Const(1), Int64, Int64])
│   %13 = LinearAlgebra.view(B, %12)::Core.PartialStruct(SubArray{Float64, 1, Base.ReshapedArray{Float64, 1, SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{StepRange{Int64, Int64}}, false}, Any[Base.ReshapedArray{Float64, 1, SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Core.PartialStruct(Tuple{StepRange{Int64, Int64}}, Any[Core.PartialStruct(StepRange{Int64, Int64}, Any[Core.Const(1), Int64, Int64])]), Core.Const(0), Core.Const(0)])
│   %14 = LinearAlgebra.dot(%11, %13)::Any
└──       return %14
```

The inferred type is lost at the final `dot` call, and I cannot fix that without `collect`ing either `y` or `dz`. Any other ideas?
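One possible workaround, untested against this PR: avoid the `view` of a reshaped `view` of a `ReshapedArray` entirely and write the contraction as plain loops over the 4-D reshape. Scalar `getindex` on a `ReshapedArray` should infer to the element type, so the result no longer depends on `dot` inferring through several wrapper layers. The function name is hypothetical, and whether the loops are fast enough for structured inputs such as `Diagonal` (where most of `dz` is zero) would need benchmarking; checking it with `@code_warntype` in the actual rule is still advisable.

```julia
using LinearAlgebra

# Hypothetical loop-based version of
#     dx[i, j] = sum over (k, l) of conj(y[k, l]) * dz4[k, i, l, j],
# with dz4 = reshape(dz, p, m, q, n). Only scalar indexing is used, so the
# compiler never has to infer `dot` on nested SubArray/ReshapedArray wrappers.
function kron_dx_loops(y::AbstractMatrix, dz::AbstractMatrix, m::Int, n::Int)
    p, q = size(y)
    dz4 = reshape(dz, p, m, q, n)
    T = promote_type(eltype(y), eltype(dz))
    dx = zeros(T, m, n)                 # dense output; project afterwards if needed
    for j in 1:n, i in 1:m
        acc = zero(T)
        for l in 1:q, k in 1:p
            acc += conj(y[k, l]) * dz4[k, i, l, j]
        end
        dx[i, j] = acc
    end
    return dx
end
```

Usage mirrors the example above: with `x = Diagonal(rand(2)); y = Diagonal(rand(2)); dz = kron(x, y)`, calling `kron_dx_loops(y, dz, 2, 2)` returns a `Matrix{Float64}` that a projection onto `Diagonal` can then consume.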

simsurace, Sep 28 '23 12:09