ChainRules.jl

stack overflow issue

Open jakubMitura14 opened this issue 1 year ago • 0 comments

Hello, I have a memory-constrained problem with a Lux.jl model that uses Zygote for most of the backpropagation.

I tried to approach this from the ChainRules side: I need to checkpoint each Lux.jl layer in the neural network, so that intermediate results are recomputed during the backward pass instead of being stored. I tried to achieve it like this:

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    y = Lux.apply(l, x, ps, st)

    function pullback_checkpointed(Δy)
        y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        return NoTangent(), pb(Δy)
    end

    y, pullback_checkpointed
end

The rule gets invoked during backpropagation. However, for some reason it recursively tries to backpropagate through the first line:

 y = Lux.apply(l, x, ps, st)

so I get a stack overflow error. How can I correct it?

I have also posted this question on the Julia Discourse: https://discourse.julialang.org/t/avoid-storing-intermediate-results-from-the-forward-pass-by-default/119694/4?u=jakub_mitura
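A possible explanation, offered as an assumption rather than a confirmed diagnosis: the pullback calls `Zygote.pullback(Lux.apply, ...)`, and Zygote looks up the very same `rrule` for `Lux.apply`, so the rule may keep re-entering itself. One way to sidestep this is to checkpoint through a separate wrapper function rather than overloading the rule for the function being differentiated. Zygote ships such a wrapper, `Zygote.checkpointed(f, xs...)`. The sketch below uses a toy function `layer` (a hypothetical stand-in, not Lux code) to show that the checkpointed call yields the same gradients as the plain one.

```julia
using Zygote

# Toy stand-in for a layer's forward pass (hypothetical; not a Lux layer).
layer(x, w) = sum(tanh.(w .* x))

x = [1.0, 2.0, 3.0]
w = [0.5, -0.3, 0.1]

# Plain reverse pass: Zygote keeps what it needs from the forward pass of `layer`.
g_plain = Zygote.gradient((x, w) -> layer(x, w), x, w)

# Checkpointed reverse pass: `layer` is re-run when the pullback is called,
# so its intermediates do not have to be stored in between.
g_ckpt = Zygote.gradient((x, w) -> Zygote.checkpointed(layer, x, w), x, w)

# Both calls return a tuple (gradient w.r.t. x, gradient w.r.t. w).
same = all(g_plain .≈ g_ckpt)
```

With Lux, the analogous call would presumably be `Zygote.checkpointed(Lux.apply, l, x, ps, st)` at the call site, without defining a custom `rrule` for `Lux.apply`; that part is untested here.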

jakubMitura14, Sep 26 '24 18:09