ChainRules.jl
make 3-arg dot rrule partially lazy
This addresses #788. To make it work, I had to remove the projection; otherwise I get the following error, because `ProjectTo{AbstractArray}` has no method that accepts the lazy `Broadcasted` object. Projecting the lazy array to a dense array when `A` is dense partially defeats the purpose of this PR, so I am leaving it to the review process to decide what to do here. I can define a projection method if that's preferred.
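For reference, the pullback of the 3-arg `dot(x, A, y) = x' * A * y` has gradients `Δ * A * y` for `x`, `Δ * A' * x` for `y`, and the outer product `Δ * x * y'` for `A`. The snippet below is a minimal sketch, not the code in this PR: the function name `dot3_pullback_sketch` is hypothetical, and it only illustrates how the `A` gradient can be kept as an instantiated `Broadcasted` (the same nested `*` structure that appears in the error message) instead of an allocated matrix.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRules): pullback of dot(x, A, y) = x' * A * y
# with the gradient w.r.t. A kept lazy.
function dot3_pullback_sketch(x::AbstractVector, A::AbstractMatrix, y::AbstractVector, Δ::Real)
    dx = Δ .* (A * y)    # gradient w.r.t. x
    dy = Δ .* (A' * x)   # gradient w.r.t. y
    # Gradient w.r.t. A is the outer product Δ * x * y'. Building it as an
    # instantiated Broadcasted means no matrix is allocated until it is materialized.
    dA = Broadcast.instantiate(Broadcast.broadcasted(*, Broadcast.broadcasted(*, Δ, x), y'))
    return dx, dA, dy
end

x = [1.0, 2.0]
A = [1.0 2.0; 3.0 4.0]
y = [0.5, -1.0]
dx, dA, dy = dot3_pullback_sketch(x, A, y, 2.0)
dA_dense = copy(dA)  # materialize the lazy outer product into a Matrix{Float64}
```

Materializing with `copy` gives the same matrix as the eager `Δ * x * y'`, so the lazy form only changes when (and whether) the allocation happens.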
julia> show(err)
1-element ExceptionStack:
MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float64, Vector{Float64}}}, Adjoint{Float64, Vector{Float64}}}})
Closest candidates are:
(::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/projection.jl:121
(::ChainRulesCore.ProjectTo)(::ChainRulesCore.Thunk)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/projection.jl:124
(::ChainRulesCore.ProjectTo{AbstractArray})(::Number)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/projection.jl:253
...
Stacktrace:
[1] (::ChainRules.var"#1966#1970"{Adjoint{Float64, Vector{Float64}}, Float64, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})()
@ ChainRules ~/.julia/dev/ChainRules/src/rulesets/LinearAlgebra/dense.jl:39
[2] unthunk
@ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:204 [inlined]
[3] wrap_chainrules_output
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:110 [inlined]
[4] map
@ ./tuple.jl:293 [inlined]
[5] map
@ ./tuple.jl:294 [inlined]
[6] wrap_chainrules_output
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:111 [inlined]
[7] ZBack
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
[8] (::Zygote.var"#75#76"{Zygote.ZBack{ChainRules.var"#dot_pullback#1968"{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}}}}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}}}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
[9] top-level scope
@ REPL[6]:1
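The error comes down to dispatch: a `Broadcasted` is not an `AbstractArray`, so a projector that only has array methods cannot accept it. The sketch below reproduces this in plain Julia with a hypothetical `project_dense_sketch` function standing in for `ProjectTo{AbstractArray}` (it does not use or extend ChainRulesCore). It shows one possible fix, a method that materializes the broadcast first, which also illustrates the trade-off mentioned above: the result is correct but no longer lazy.

```julia
using LinearAlgebra

# Hypothetical stand-in for a projector that only knows how to handle arrays.
project_dense_sketch(dx::AbstractArray) = convert(Matrix{Float64}, dx)

# A lazy Broadcasted is not an AbstractArray, so the method above does not match it.
# One option is to add a method that materializes the broadcast before projecting;
# this works but gives up the laziness.
project_dense_sketch(dx::Broadcast.Broadcasted) = project_dense_sketch(copy(dx))

Δ, x, y = 2.0, [1.0, 2.0], [0.5, -1.0]
lazy_dA = Broadcast.instantiate(Broadcast.broadcasted(*, Broadcast.broadcasted(*, Δ, x), y'))
is_array = lazy_dA isa AbstractArray       # false: this is why dispatch fails
dense_dA = project_dense_sketch(lazy_dA)   # materialized Matrix{Float64}
```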
This is related to the discussion in https://github.com/FluxML/Zygote.jl/issues/1507.