ChainRules.jl
Add rules and tests for `kron`
In Julia 1.9, an internal change to `kron` introduced some mutation, which made Zygote unable to differentiate `kron`. Here, we add some rules to restore that ability.
Discovered in https://github.com/JuliaGaussianProcesses/TemporalGPs.jl/pull/115
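For reference, the matrix-matrix reverse rule follows from the fact that each entry of `C = kron(A, B)` is `A[i, j] * B[k, l]`. The sketch below is a hypothetical standalone helper (`kron_pullback` is not the rule added in this PR and does not use ChainRulesCore). It reshapes the cotangent `dC` into a 4-d array so that each gradient entry becomes a `dot` with a slice. It assumes dense matrices and was only checked here on real inputs.

```julia
using LinearAlgebra

# Hypothetical helper, not the rule added in this PR: reverse-mode pullback of
# C = kron(A, B) for matrices A (m×n) and B (p×q), given the cotangent dC of C.
function kron_pullback(A::AbstractMatrix, B::AbstractMatrix, dC::AbstractMatrix)
    m, n = size(A)
    p, q = size(B)
    size(dC) == (m * p, n * q) || throw(DimensionMismatch("dC has the wrong size"))
    # Column-major reshape: dC4[k, i, l, j] == dC[(i - 1) * p + k, (j - 1) * q + l],
    # which is the cotangent of the entry A[i, j] * B[k, l] of C.
    dC4 = reshape(dC, p, m, q, n)
    # dA[i, j] = sum over (k, l) of conj(B[k, l]) * dC4[k, i, l, j]
    dA = [dot(B, view(dC4, :, i, :, j)) for i in 1:m, j in 1:n]
    # dB[k, l] = sum over (i, j) of conj(A[i, j]) * dC4[k, i, l, j]
    dB = [dot(A, view(dC4, k, :, l, :)) for k in 1:p, l in 1:q]
    return dA, dB
end

A = rand(2, 3); B = rand(4, 2); dC = rand(8, 6)
dA, dB = kron_pullback(A, B, dC)
# kron is linear in A, so <dC, kron(A, B)> == <dA, A>
dot(dC, kron(A, B)) ≈ dot(dA, A)  # true
```

The slices `view(dC4, :, i, :, j)` are the ones that `eachslice(dC4; dims = (2, 4))` yields, one per index pair `(i, j)`.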
> made Zygote unable to differentiate `kron`

Did the workaround in https://github.com/FluxML/Zygote.jl/pull/1378 not fix it? As mentioned in #684, it should ideally be fixed in ChainRules nevertheless, but I'm a bit curious.
Thanks for commenting. I think @willtebbutt said that he would have a look at these rules later on. I don't know whether it is related, but the above-mentioned fix predates Julia 1.9 by several months. I observed the breakage when upgrading Julia from 1.8 to 1.9.
Hi all, I rewrote the rules and now all the tests pass. There is probably room to optimize them, so please let me know.
OK, I did not test on Julia 1.6. Apparently this requires special care.
Why don't we see the full stack traces here? Is it due to using JuliaInterpreter?
OK, I made the suggested changes and added tests to check that the projections behave correctly. However, there is a type-inference problem in the matrix-matrix case.
The problem is this:
```julia
julia> x = Diagonal(rand(2)); y = Diagonal(rand(2)); z, pb = rrule(kron, x, y);

julia> @code_warntype unthunk(pb(z)[2])
MethodInstance for ChainRulesCore.unthunk(::Thunk{ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}})
from unthunk(x::Thunk) @ ChainRulesCore ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204
Arguments
#self#::Core.Const(ChainRulesCore.unthunk)
x::Thunk{ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}}
Body::Any
1 ─ nothing
│ %2 = Base.getproperty(x, :f)::ChainRules.var"#2318#2321"{Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}, ProjectTo{Diagonal, NamedTuple{(:diag,), Tuple{ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}}}}
│ %3 = (%2)()::Any
└── return %3
```
Any ideas on how to make the unthunking type-stable here?
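The `@code_warntype` inspection above can also be turned into an automated check with `Test.@inferred`, which throws an error when the concrete return type of a call differs from the type inference produces. A minimal sketch on two toy functions (hypothetical stand-ins, not the `kron` pullback):

```julia
using Test

# Toy stand-ins for a pullback (hypothetical, not the kron rule):
stable_half(x::Int) = x > 0 ? x / 2 : 0.0    # always returns Float64
unstable_half(x::Int) = x > 0 ? x / 2 : 0    # returns Float64 or Int

# @inferred returns the value when the inferred and actual return types agree...
@inferred stable_half(3)
# ...and throws an error when inference only produces a Union (or Any).
@test_throws ErrorException @inferred(unstable_half(3))
```

Wrapping a call such as `unthunk(pb(z)[2])` in `@inferred` inside the rule's tests would make a regression like the `Body::Any` above fail loudly instead of only showing up in `@code_warntype`.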
EDIT:
The core of the problem is that `dot(y, first(eachslice(dz; dims = (2, 4))))` is type-unstable:
```julia
julia> @code_warntype dot(y, first(eachslice(dz; dims = (2, 4))))
MethodInstance for LinearAlgebra.dot(::Diagonal{Float64, Vector{Float64}}, ::SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false})
from dot(D::Diagonal, B::AbstractMatrix) @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/diagonal.jl:806
Arguments
#self#::Core.Const(LinearAlgebra.dot)
D::Diagonal{Float64, Vector{Float64}}
B::SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}
Body::Any
1 ─ %1 = LinearAlgebra.size(D)::Tuple{Int64, Int64}
│ %2 = LinearAlgebra.size(B)::Tuple{Int64, Int64}
│ %3 = (%1 == %2)::Bool
└── goto #3 if not %3
2 ─ goto #4
3 ─ %6 = LinearAlgebra.size(D)::Tuple{Int64, Int64}
│ %7 = LinearAlgebra.size(B)::Tuple{Int64, Int64}
│ %8 = Base.string("Matrix sizes ", %6, " and ", %7, " differ")::String
│ %9 = LinearAlgebra.DimensionMismatch(%8)::Any
└── LinearAlgebra.throw(%9)
4 ┄ %11 = Base.getproperty(D, :diag)::Vector{Float64}
│ %12 = LinearAlgebra.diagind(B)::Core.PartialStruct(StepRange{Int64, Int64}, Any[Core.Const(1), Int64, Int64])
│ %13 = LinearAlgebra.view(B, %12)::Core.PartialStruct(SubArray{Float64, 1, Base.ReshapedArray{Float64, 1, SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{StepRange{Int64, Int64}}, false}, Any[Base.ReshapedArray{Float64, 1, SubArray{Float64, 2, Base.ReshapedArray{Float64, 4, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}, Int64}, false}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Core.PartialStruct(Tuple{StepRange{Int64, Int64}}, Any[Core.PartialStruct(StepRange{Int64, Int64}, Any[Core.Const(1), Int64, Int64])]), Core.Const(0), Core.Const(0)])
│ %14 = LinearAlgebra.dot(%11, %13)::Any
└── return %14
```
I cannot fix that without collecting either `y` or `dz`. Any other ideas?
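One possible direction, sketched under the assumption that `y` is a `Diagonal`: the output above shows that the `Diagonal` method of `dot` builds `view(B, diagind(B))` and the result of `dot` on that view is inferred as `Any`. A hypothetical helper `diag_dot` (not part of LinearAlgebra or this PR) can instead read the diagonal entries `B[k, k]` in a plain loop with a concretely typed accumulator, without collecting either argument. Whether this actually removes the `Any` from the unthunked pullback would still need to be confirmed with `@code_warntype`.

```julia
using LinearAlgebra

# Hypothetical helper, not part of LinearAlgebra or the PR: computes dot(D, B)
# for a Diagonal D by reading only the diagonal entries B[k, k], instead of
# building view(B, diagind(B)). Neither D nor B is collected.
# Assumes B uses 1-based indexing, like D.diag.
function diag_dot(D::Diagonal, B::AbstractMatrix)
    size(D) == size(B) ||
        throw(DimensionMismatch("Matrix sizes $(size(D)) and $(size(B)) differ"))
    d = D.diag
    s = zero(promote_type(eltype(d), eltype(B)))
    for k in eachindex(d)
        s += conj(d[k]) * B[k, k]   # dot conjugates its first argument
    end
    return s
end

# Mirror the slice from the discussion: a 2×2 slice of a reshaped Diagonal.
dz4 = reshape(Diagonal(rand(4)), 2, 2, 2, 2)
S = view(dz4, :, 1, :, 1)
y = Diagonal(rand(2))
diag_dot(y, S) ≈ dot(Matrix(y), collect(S))  # true
```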