ChainRulesCore.jl

Thunks are still evaluated for variables not in Flux.params, which leads to unnecessary computation

Open · ziyiyin97 opened this issue 2 years ago · 2 comments

Hello! Here is a minimal example involving Flux and ChainRulesCore:

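```julia
using Flux, ChainRulesCore
import ChainRulesCore.rrule

# Custom scalar function: the product of two Float32 values.
function f(a::Float32, b::Float32)
    return a * b
end

# Custom reverse rule. Both input tangents are wrapped in @thunk, so each
# should only be computed if something later calls unthunk on it.
function rrule(::typeof(f), a::Float32, b::Float32)
    println("rrule is called")
    y = f(a,b)
    function pullback(Δy)
        da = @thunk(∇a(a,b,Δy))
        db = @thunk(∇b(a,b,Δy))
        return (NoTangent(), da, db)
    end
    return y, pullback
end

# Tangent with respect to a; prints a message whenever it is evaluated.
function ∇a(a,b,Δy)
    println("∇a is called")
    return b * Δy
end

# Tangent with respect to b; prints a message whenever it is evaluated.
function ∇b(a,b,Δy)
    println("∇b is called")
    return a * Δy
end

a = 1f0
b = 2f0

# Request the gradient with respect to a only.
ga = gradient(()->f(a,b), Flux.params(a))
```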

This defines my custom function f (the product of two scalars) and its rrule from ChainRulesCore. In the last line I compute the gradient with respect to a only, via ga = gradient(()->f(a,b), Flux.params(a)), so I expected only ∇a to be called. Instead, the log shows that both ∇a and ∇b are called:

```
rrule is called
∇a is called
∇b is called
```

Any idea why? This could be a problem when f is a complicated function: calling ∇b is unnecessary here, and it may be time-consuming. Thanks for any help!

ziyiyin97 · Jun 21 '22 22:06

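Flux.jl uses Zygote.jl, and Zygote.jl doesn't yet utilise ChainRulesCore.jl's Thunks: it unthunks every tangent returned by an rrule pullback straight away, whether or not that gradient was requested, which is why ∇b runs as well. See https://github.com/FluxML/Zygote.jl/blob/9602c6b2038879034c2de14d1f4aa251d99c6ea4/src/compiler/chainrules.jl#L104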

There is a WIP PR to make Zygote.jl utilise Thunks here: https://github.com/FluxML/Zygote.jl/pull/966
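To see that the extra ∇b call comes from eager unthunking rather than from ChainRulesCore itself, here is a minimal sketch that calls the rrule directly, without Flux or Zygote. It assumes only ChainRulesCore's exported rrule, @thunk, unthunk and NoTangent. The call counter ncalls and the name f_pullback are illustrative additions (the ∇ functions count calls instead of printing), and map(unthunk, ...) is a simplified stand-in for Zygote's eager behaviour, not Zygote's actual code.

```julia
using ChainRulesCore

# Illustrative call counters, so we can see which tangent functions actually run.
const ncalls = Dict(:a => 0, :b => 0)

f(a::Float32, b::Float32) = a * b

∇a(a, b, Δy) = (ncalls[:a] += 1; b * Δy)
∇b(a, b, Δy) = (ncalls[:b] += 1; a * Δy)

function ChainRulesCore.rrule(::typeof(f), a::Float32, b::Float32)
    function f_pullback(Δy)
        # @thunk defers each computation until someone calls unthunk on it.
        da = @thunk(∇a(a, b, Δy))
        db = @thunk(∇b(a, b, Δy))
        return (NoTangent(), da, db)
    end
    return f(a, b), f_pullback
end

y, back = rrule(f, 1f0, 2f0)

# Lazy: only the tangent for a is materialised, so ∇b never runs.
_, da, db = back(1f0)
ga = unthunk(da)
lazy_calls = copy(ncalls)   # now ncalls[:a] == 1 and ncalls[:b] == 0

# Eager (simplified stand-in for Zygote's current behaviour): every tangent
# is unthunked, so ∇b runs even though nobody asked for the gradient w.r.t. b.
eager = map(unthunk, back(1f0))
eager_calls = copy(ncalls)  # now ncalls[:a] == 2 and ncalls[:b] == 1
```

Calling pullback alone never triggers either ∇ function; the cost is only paid when a thunk is unthunked, so a consumer that skips unthunk for unrequested inputs avoids the ∇b call entirely.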

nickrobinson251 · Jun 21 '22 22:06

Thanks for your quick reply. Looking forward to the PR being merged.

ziyiyin97 · Jun 21 '22 23:06