ChainRules.jl
Rule for `vect` unthunks many times
This:
using ChainRulesCore
myplus(x,y) = x + y
function ChainRulesCore.rrule(::typeof(myplus), x, y)
    println("myplus rrule forward")
    x+y, dz -> begin
        println("myplus rrule reverse, dz isa ", typeof(dz))
        NoTangent(), @thunk(unthunk(dz)), @thunk @show unthunk(dz)
    end
end
using Yota
grad([1,2,3]) do x
    prod(myplus(x, [4,5,6]))
end
prints this:
myplus rrule forward
myplus rrule reverse, dz isa InplaceableThunk{Thunk{ChainRules.var"#1675#1678"{Float64, Colon, Vector{Float64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Float64}}, ChainRules.var"#1674#1677"{Float64, Colon, Vector{Float64}, Float64}}
unthunk(dz) = [63.0, 45.0, 35.0]
unthunk(dz) = [63.0, 45.0, 35.0] # from @show dz inside the @thunk, runs 3 times
unthunk(dz) = [63.0, 45.0, 35.0]
(315.0, (ZeroTangent(), [63.0, 45.0, 35.0]))
Almost all rules should call `unthunk` exactly once on their input. Perhaps other rules make the same mistake, e.g. `hvcat`?
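For illustration, here is a minimal sketch of the "unthunk exactly once" pattern. It is not the actual ChainRules.jl rule for `vect`: `myvect`, `myvect_pullback` and the `ncalls` counter are hypothetical names made up for this example. The incoming cotangent is materialised once at the top of the pullback, and every per-argument gradient reads from that one array.

```julia
using ChainRulesCore

# Hypothetical stand-in for Base.vect, so the sketch does not clash with existing rules.
myvect(xs::Number...) = [xs...]

function ChainRulesCore.rrule(::typeof(myvect), xs::Number...)
    function myvect_pullback(dy_raw)
        dy = unthunk(dy_raw)                  # materialise the cotangent once
        dxs = ntuple(i -> dy[i], length(xs))  # index the array, not the thunk
        return (NoTangent(), dxs...)
    end
    return myvect(xs...), myvect_pullback
end

# Count how often the incoming thunk is evaluated (a Thunk does not cache its result).
ncalls = Ref(0)
dy = @thunk begin
    ncalls[] += 1
    [1.0, 2.0, 3.0]
end

y, back = rrule(myvect, 1.0, 2.0, 3.0)
grads = back(dy)
@show ncalls[]
```

If the pullback indexed `unthunk(dy_raw)[i]` separately for each argument instead, the counter would grow with the number of arguments, which is the kind of repeated work reported above.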