ChainRules.jl

Rule for `vect` unthunks many times

Open · mcabbott opened this issue 3 years ago · 0 comments

This:

```julia
using ChainRulesCore

myplus(x, y) = x + y

function ChainRulesCore.rrule(::typeof(myplus), x, y)
  println("myplus rrule forward")
  x + y, dz -> begin
    println("myplus rrule reverse, dz isa ", typeof(dz))
    # the cotangent for y prints every time it is unthunked
    NoTangent(), @thunk(unthunk(dz)), @thunk @show unthunk(dz)
  end
end

using Yota
grad([1, 2, 3]) do x
  # the literal [4, 5, 6] is built by Base.vect, whose rule receives y's thunk
  prod(myplus(x, [4, 5, 6]))
end
```

prints this:

```
myplus rrule forward
myplus rrule reverse, dz isa InplaceableThunk{Thunk{ChainRules.var"#1675#1678"{Float64, Colon, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Float64}}, ChainRules.var"#1674#1677"{Float64, Colon, Vector{Float64}, Float64}}
unthunk(dz) = [63.0, 45.0, 35.0]
unthunk(dz) = [63.0, 45.0, 35.0]  # from `@show unthunk(dz)` inside the second @thunk: printed 3 times
unthunk(dz) = [63.0, 45.0, 35.0]
(315.0, (ZeroTangent(), [63.0, 45.0, 35.0]))
```
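A minimal sketch of why the `@show` line appears three times, assuming (as in ChainRulesCore with default settings) that a `Thunk` only wraps a closure and does not cache its result, so every `unthunk` call re-runs the body. The per-element `ntuple` below is an illustrative stand-in for a pullback that unthunks once per output element; it is not copied from the actual `vect` rule, and `ncalls` is a hypothetical counter.

```julia
using ChainRulesCore

ncalls = Ref(0)  # hypothetical counter: how many times the thunk body runs

dz = @thunk begin
    ncalls[] += 1
    [63.0, 45.0, 35.0]
end

# Illustrative anti-pattern: unthunk inside the per-element closure,
# so the (possibly expensive) thunk body is evaluated once per element.
per_element = ntuple(i -> unthunk(dz)[i], 3)

println(per_element, " after ", ncalls[], " evaluations")
```

With three elements the body runs three times, matching the three `unthunk(dz) = ...` lines in the output.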

Almost all rules should call `unthunk` exactly once on their input. Other rules may make the same mistake; `hvcat`, perhaps?
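One way to follow the "unthunk exactly once" rule, sketched on a hypothetical `myvect` (a stand-in for `Base.vect`, so ChainRules' own rule is not overwritten): unthunk the incoming cotangent once at the top of the pullback, then index into the materialised value. This is an illustration under those assumptions, not the actual ChainRules rule.

```julia
using ChainRulesCore

# Hypothetical stand-in for Base.vect (avoids redefining ChainRules' own rule).
myvect(xs::Number...) = [xs...]

function ChainRulesCore.rrule(::typeof(myvect), xs::Number...)
    N = length(xs)
    function myvect_pullback(dy)
        dy_val = unthunk(dy)  # materialise the cotangent exactly once
        return (NoTangent(), ntuple(i -> dy_val[i], N)...)
    end
    return myvect(xs...), myvect_pullback
end

# Usage: count how often the incoming thunk is evaluated.
ncalls = Ref(0)
y, back = rrule(myvect, 4, 5, 6)
grads = back(@thunk begin
    ncalls[] += 1
    [63.0, 45.0, 35.0]
end)
println(grads, ", thunk evaluated ", ncalls[], " time(s)")
```

Hoisting `unthunk` out of the per-element closure means the thunk body runs once regardless of how many elements the vector has.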

mcabbott, Sep 19 '22 01:09