ChainRulesCore.jl
Thunks still run for variables not in Flux.params, which leads to unnecessary computation
Hello! Here is a minimal example using Flux and ChainRulesCore:
using Flux, ChainRulesCore
import ChainRulesCore.rrule

function f(a::Float32, b::Float32)
    return a * b
end

function rrule(::typeof(f), a::Float32, b::Float32)
    println("rrule is called")
    y = f(a, b)
    function pullback(Δy)
        da = @thunk(∇a(a, b, Δy))
        db = @thunk(∇b(a, b, Δy))
        return (NoTangent(), da, db)
    end
    return y, pullback
end

function ∇a(a, b, Δy)
    println("∇a is called")
    return b * Δy
end

function ∇b(a, b, Δy)
    println("∇b is called")
    return a * Δy
end

a = 1f0
b = 2f0
ga = gradient(()->f(a,b), Flux.params(a))
This defines my custom function f (a multiplication of two scalars) and its rrule from ChainRulesCore. In the last line, when I compute the gradient with respect to the variable a only, as ga = gradient(()->f(a,b), Flux.params(a)), I expect only ∇a to be called, but the log shows that both ∇a and ∇b are called:
rrule is called
∇a is called
∇b is called
Any idea why? This could be a problem when f is a complicated function: calling ∇b is unnecessary here and may be time-consuming. Thanks for any help!
Flux.jl uses Zygote.jl, and Zygote.jl doesn't yet utilise ChainRulesCore.jl's Thunks:
https://github.com/FluxML/Zygote.jl/blob/9602c6b2038879034c2de14d1f4aa251d99c6ea4/src/compiler/chainrules.jl#L104
There is a WIP PR to make Zygote.jl utilise Thunks here: https://github.com/FluxML/Zygote.jl/pull/966
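For illustration, here is a minimal sketch of the lazy behaviour the thunks are meant to give. It calls the rrule and its pullback directly through ChainRulesCore, without Flux or Zygote. The calls vector is a hypothetical logging helper that stands in for the println calls in the example above; everything else follows the original definitions.

```julia
using ChainRulesCore

# Hypothetical helper: records which gradient functions actually ran.
calls = String[]

f(a::Float32, b::Float32) = a * b

function ∇a(a, b, Δy)
    push!(calls, "∇a")
    return b * Δy
end

function ∇b(a, b, Δy)
    push!(calls, "∇b")
    return a * Δy
end

function ChainRulesCore.rrule(::typeof(f), a::Float32, b::Float32)
    # Each tangent is wrapped in a thunk, so it is only computed on demand.
    pullback(Δy) = (NoTangent(), @thunk(∇a(a, b, Δy)), @thunk(∇b(a, b, Δy)))
    return f(a, b), pullback
end

y, pb = rrule(f, 1f0, 2f0)
_, da, db = pb(1f0)   # da and db are still unevaluated thunks at this point
ga = unthunk(da)      # forces only the computation of ∇a
println(ga)
println(calls)        # ∇b should not appear, because db was never unthunked
```

When the pullback is used this way, ∇b never runs unless db is explicitly unthunked. Both functions run in the Flux example because Zygote currently evaluates every thunk it receives from the pullback.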
Thanks for your quick reply. Looking forward to the PR being merged.