ChainRules.jl

make 3-arg dot rrule partially lazy

Open mohamed82008 opened this issue 1 year ago • 8 comments

This addresses #788. I had to remove the projection to make it work; otherwise I get the following error due to a missing projection method. Projecting the lazy array to a dense array when A is dense partially defeats the purpose of this PR, so I am leaving it up to the review process to decide what to do here. I can define a projection method if that's preferred.

julia> show(err)
1-element ExceptionStack:
MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float64, Vector{Float64}}}, Adjoint{Float64, Vector{Float64}}}})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/projection.jl:121
  (::ChainRulesCore.ProjectTo)(::ChainRulesCore.Thunk)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/projection.jl:124
  (::ChainRulesCore.ProjectTo{AbstractArray})(::Number)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/zgT0R/src/projection.jl:253
  ...

Stacktrace:
 [1] (::ChainRules.var"#1966#1970"{Adjoint{Float64, Vector{Float64}}, Float64, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})()
   @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/LinearAlgebra/dense.jl:39
 [2] unthunk
   @ ~/.julia/packages/ChainRulesCore/zgT0R/src/tangent_types/thunks.jl:204 [inlined]
 [3] wrap_chainrules_output
   @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:110 [inlined]
 [4] map
   @ ./tuple.jl:293 [inlined]
 [5] map
   @ ./tuple.jl:294 [inlined]
 [6] wrap_chainrules_output
   @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:111 [inlined]
 [7] ZBack
   @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [8] (::Zygote.var"#75#76"{Zygote.ZBack{ChainRules.var"#dot_pullback#1968"{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}}}}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}}}}}}})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [9] top-level scope
   @ REPL[6]:1

This is related to the discussion in https://github.com/FluxML/Zygote.jl/issues/1507.
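
To make the mismatch in the error above concrete, here is a minimal sketch using only Base and LinearAlgebra. It is not the code in this PR: the values and the names `dA_dense` and `dA_lazy` are made up for illustration, and it assumes real inputs, for which the cotangent of `dot(x, A, y)` with respect to `A` is the outer product `Δ * x * y'`. The lazy form has the same nested `Broadcasted` shape that appears in the `MethodError`.

```julia
using LinearAlgebra

# Illustrative inputs (hypothetical values, real-valued only).
x, y = [1.0, 2.0], [3.0, 4.0, 5.0]
A = ones(2, 3)
Δ = 2.0

# Dense cotangent for A: allocates a full 2x3 matrix.
dA_dense = Δ .* x .* y'

# Lazy cotangent for A: a nested Broadcasted object, nothing allocated yet.
dA_lazy = Broadcast.instantiate(
    Broadcast.broadcasted(*, Broadcast.broadcasted(*, Δ, x), y'),
)

# The lazy object is not an AbstractArray, which is why an array-oriented
# projection has no method that accepts it.
@assert dA_lazy isa Broadcast.Broadcasted
@assert !(dA_lazy isa AbstractArray)

# It still carries the right axes and materializes to the dense result.
@assert axes(dA_lazy) == axes(A)
@assert Broadcast.materialize(dA_lazy) == dA_dense
```

A projection method for this type would have to either materialize the broadcast, which gives back the dense matrix, or pass the lazy object through unchanged. The sketch does not settle which is preferable.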

mohamed82008, May 25 '24 21:05