Zygote.jl

Something wrong with empty (Named-)Tuples and generators

Open axsk opened this issue 3 years ago • 17 comments

In a project of mine I want to take derivatives of a Neural SDE solution (computed by the custom wrapper msolve) with respect to the Lux NN parameters:

function logvar(prob; ps=prob.p, n=100)  # calling this method works
    sum( msolve(prob, ps=ps) for i in 1:n)
end

Zygote.gradient(ps->logvar(prob, ps=ps, n=n), prob.p)[1] # this doesn't

fails with a

MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
Stacktrace: [...]
[3] accum(x::NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, y::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:27
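For readers who want to see the failing operation in isolation, here is a minimal sketch in plain Julia (no Zygote involved). It only shows that Base defines no `+` for an empty Tuple and an empty NamedTuple; judging from the stack trace, accum ends up attempting exactly this addition on the `data` field, but that reading of the trace is an assumption.

```julia
# Plain Julia, no packages: Base has no `+` method for these two empty containers.
empty_tuple = ()    # Tuple{}
empty_nt = (;)      # NamedTuple{(), Tuple{}}

# Capture the exception instead of letting it abort the script.
err = try
    empty_tuple + empty_nt
catch e
    e
end

println(typeof(err))
```

Both values carry no gradient information, so conceptually they stand for "no gradient" and could be treated like `nothing`.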

After following the suggestion of @ToucheSir in https://github.com/FluxML/Zygote.jl/issues/1290 and replacing the generator with sum(_ -> msolve(prob, ps=ps), 1:n) the error changes to

MethodError: no method matching +(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})

I hotfixed this with

import Base.+
+(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}) = (data=(;), itr=nothing)

and the code runs through.

Searching for occurrences of (:data, :itr), the only one I could find is https://github.com/FluxML/Zygote.jl/blob/de078c84ce0a1ee517e9e929f0bb6b97b697e23e/src/lib/base.jl#L155 and the corresponding function below it.

I have no clue how this all works together, but thanks a lot to @mcabbott and @ToucheSir for helping me find the fix. Feel free to correct the issue title and let me know if I can be of any further help fixing this (regarding the Zygote internals I am quite out of my depth though).

axsk, Aug 24 '22 19:08

Some background:

function msolve(prob; ps=prob.p, dt=0.01, salg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(), noisemixing=true))
    prob = remake(prob, p=ps)
    s = solve(prob, EM(), sensealg=salg, dt=dt)
    s[end][end]
end

essentially takes an SDE problem prob which has its RHS parametrized by a Lux.Chain with parameters ps=prob.p and returns a ::Float64 from the end of the solution. I am not sure which role Lux, StochasticDiffEq or SciMLSensitivity play in this problem and will try to reduce it to a MWE when time allows (unless someone spots the problem right away :>)

axsk, Aug 24 '22 19:08

While you're working on a MWE, can you provide the full message and stacktrace of the latest error along with the code to run it? The gist in the linked issue appears to be out of date.

ToucheSir, Aug 24 '22 23:08

I think I nailed it down to the remake call. When using solve(..., p=ps) instead of remake everything works out:

  • no need for the strange Base.+ hotfix
  • can use generators

I believe the above methods (msolve, logvar) should suffice to reproduce the problem with any simple SDE problem with some parameter dependence (or even an ODE with a corresponding solver..?). Unfortunately it's past 1pm and I'm past my sworn bedtime for today, so I'll report more tomorrow.

axsk, Aug 24 '22 23:08

The current fixed code is here. To reproduce the error uncomment the remake line and run test()

axsk, Aug 24 '22 23:08

Haven't run this, but sometimes Zygote is confused by re-using the name prob. Does it happen with e.g. prob2 = remake(prob, p=ps)?

mcabbott, Aug 24 '22 23:08

remake looks...complicated: https://github.com/SciML/SciMLBase.jl/blob/9b361d6a3ea81a9e24ce14aab5768cea6986cdfe/src/remake.jl#L45. After testing Michael's suggestion, I would also be curious what you get from wrapping that remake call in Zygote.@showgrad.

ToucheSir, Aug 25 '22 04:08

The problem persists with binding to prob2. @showgrad returns nothing

julia> test()
∂(prob2 = remake(prob, p = ps)) = nothing
∂(prob2 = remake(prob, p = ps)) = nothing
∂(prob2 = remake(prob, p = ps)) = nothing
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})

axsk, Aug 25 '22 09:08

Without solving where these come from, they should both probably be nothing. You could try asking _project to standardise for you:

julia> methods(Zygote._project)
# 2 methods for generic function "_project" from Zygote:
 [1] _project(x::AbstractArray, dx::Tuple)
     @ ~/.julia/packages/Zygote/xGkZ5/src/compiler/chainrules.jl:188
 [2] _project(x, dx)
     @ ~/.julia/packages/Zygote/xGkZ5/src/compiler/chainrules.jl:183

julia> Zygote._project(x, dx::NamedTuple{()}) = nothing  # shouldn't introduce ambiguities

Or perhaps adding methods to wrap_chainrules_output or something here https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/chainrules.jl

These might be worth doing anyway. (https://github.com/JuliaDiff/ChainRulesCore.jl/pull/565 is something similar.) If _project works, you could perhaps ask it to print x too, for clues as to what object is creating these.
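The idea of standardising empty gradients to nothing can be sketched without Zygote. The names `normalize_empty`, `toy_accum` and `safe_accum` below are hypothetical helpers made up for illustration; they are not Zygote or ChainRulesCore API, and the toy accumulation rule only approximates how gradient accumulation treats `nothing` as zero.

```julia
# Hypothetical helpers for illustration only; none of these exist in Zygote.

# Map "structurally empty" gradients to `nothing`, pass everything else through.
normalize_empty(dx) = dx
normalize_empty(::Tuple{}) = nothing
normalize_empty(::NamedTuple{(), Tuple{}}) = nothing

# Toy accumulation rule: `nothing` behaves like zero, everything else is added.
toy_accum(x, y) = x + y
toy_accum(::Nothing, y) = y
toy_accum(x, ::Nothing) = x
toy_accum(::Nothing, ::Nothing) = nothing  # resolves the ambiguity of the two methods above

# Normalise first, then accumulate, so () and (;) never reach `+`.
safe_accum(x, y) = toy_accum(normalize_empty(x), normalize_empty(y))

println(safe_accum((), (;)))
println(safe_accum((), [1.0]))
println(safe_accum([1.0], [2.0]))
```

With the normalisation in place, the combination that raised the MethodError in this issue, an empty Tuple and an empty NamedTuple, simply accumulates to `nothing`.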

mcabbott, Aug 26 '22 01:08

> The problem persists with binding to prob2. @showgrad returns nothing
>
> julia> test()
> ∂(prob2 = remake(prob, p = ps)) = nothing
> ∂(prob2 = remake(prob, p = ps)) = nothing
> ∂(prob2 = remake(prob, p = ps)) = nothing
> ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})

Sorry, I meant to wrap around just the remake(...) call and not the whole prob2 = remake(...) assignment. Wrapping the statement will always get you nothing unless it's used as a nested expression.

ToucheSir, Aug 26 '22 02:08

I think I distilled it into a MWE

using Zygote
using StochasticDiffEq, SciMLSensitivity
import Lux

function mwe()
    x0 = rand(1)
    p0 = rand(1)

    drift(du,u,p,t) = (du .= 1)
    noise(du,u,p,t) = (du .= 1)

    prob = SDEProblem(drift, noise, x0, 1., p0)
    sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
    Zygote.gradient(p0) do p
        sum(Zygote.@showgrad(solve(remake(prob, p=p), EM(), dt=.1, sensealg=sensealg)[end][1]) for i in 1:3)
    end
end

With @showgrad in the correct position this now returns

julia> mwe()

∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.09612757465640165], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.06807801678762193], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [0.16420559144402358], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})

I was surprised to see that the import Lux is necessary for the problem to occur, even though it's not being used. Without that import there is no error.

axsk, Aug 29 '22 18:08

That sounds a bit like piracy, which is bad.

Does there seem to be a fix involving sending () or (;) to nothing, maybe using Zygote._project as above? That would be OK even if we never find the origin.

Also, what Julia version does this MWE work on? (Failed to install everything on nightly.)

Precompiling project...
  ✗ Cassette
  ✗ SciMLSensitivity
  13 dependencies successfully precompiled in 248 seconds. 140 already precompiled.
  2 dependencies errored. To see a full report either run `import Pkg; Pkg.precompile()` or load the packages
[ Info: Precompiling StochasticDiffEq [789caeaf-c7a9-5a7d-9973-96adeb23e2a0]
[ Info: Precompiling SciMLSensitivity [1ed8b502-d754-442c-8d5d-10ac956f44a1]
Internal error: encountered unexpected error in runtime:
AssertionError(msg="argextype only works on argument-position values")
argextype at ./compiler/optimize.jl:320

mcabbott, Aug 29 '22 19:08

I tried

Zygote._project(x, dx::NamedTuple{()}) = nothing
Zygote._project(x, dx::NamedTuple{(), Tuple{}}) = nothing
Zygote._project(x, dx::Tuple{}) = nothing

all without effect. I am running it on Julia 1.8 with the newest versions of the packages.

Edit:

Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}}) = nothing

seems to fix it.

axsk, Aug 29 '22 21:08

Looking into Lux.jl I found:

# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
    return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end

https://github.com/avik-pal/Lux.jl/blob/11ac3e476161eedea23194b31e48e8d128950e00/src/autodiff.jl#L92

This is type piracy and it touches the problematic accum, but I am not using ComponentArray in the MWE.

axsk, Aug 29 '22 22:08

Does adding that definition into your own code without importing Lux also break things?

On the thing which seems to fix things:

> Edit:
>
> Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}}) = nothing
>
> seems to fix it.

Do you mind tweaking the definition to this and pasting the stacktrace it generates here?

function Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}})
  display(stacktrace())
  println()
end

ToucheSir, Aug 29 '22 22:08

The definition without Lux does not break it, so I guess that's not the problem.

Here are the stacktraces you asked for. One should probably start at the end since the problem only occurs after the 3rd iteration.

59-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] ZBack at chainrules.jl:206 [inlined] Pullback at namedtuple.jl:280 [inlined] (::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:62 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at 
boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522

57-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] (::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:62 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] (::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 
gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522

59-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] ZBack at chainrules.jl:206 [inlined] Pullback at namedtuple.jl:280 [inlined] (::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:62 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at 
boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522

57-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] (::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:62 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] (::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 
gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522

59-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] ZBack at chainrules.jl:206 [inlined] Pullback at namedtuple.jl:280 [inlined] (::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:58 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at 
boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522

57-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] (::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:58 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] (::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 
gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522

axsk, Aug 29 '22 23:08

Thanks! The last stacktrace includes https://github.com/avik-pal/Lux.jl/blob/11ac3e476161eedea23194b31e48e8d128950e00/src/autodiff.jl#L53-L63, which is very much piracy. Is that the last stacktrace printed before the error? If so, can you see if that rrule overload breaks things?
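As a rough illustration of why piracy can break code that never uses the pirating package, here is a toy sketch in plain Julia. The modules `Owner`, `Pirate` and `Bystander` and the function `describe` are invented for this example; they only mimic the situation in this thread, where merely importing Lux changed how the gradient of remake was computed.

```julia
module Owner
    # Owner defines the generic function and its fallback behaviour.
    describe(x) = "generic"
end

module Pirate
    import ..Owner
    # Type piracy: Pirate owns neither `Owner.describe` nor the type `Tuple{}`.
    Owner.describe(::Tuple{}) = "pirated"
end

module Bystander
    import ..Owner
    # Bystander never mentions Pirate, yet its result depends on Pirate being loaded.
    check() = Owner.describe(())
end

println(Bystander.check())
```

Because the pirated method is more specific than the fallback, dispatch selects it for every caller, including code that has no knowledge of `Pirate`. Deleting the `Pirate` module restores the generic behaviour.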

ToucheSir, Aug 29 '22 23:08

After removing the lines in question the test runs through 🎷

axsk, Aug 31 '22 13:08