Something wrong with empty (Named-)Tuples and generators
In a project of mine I want to take derivatives of a Neural SDE solution (computed by the custom wrapper msolve) with respect to the Lux NN parameters:
function logvar(prob; ps=prob.p, n=100) # calling this method works
    sum(msolve(prob, ps=ps) for i in 1:n)
end
Zygote.gradient(ps -> logvar(prob, ps=ps, n=n), prob.p)[1] # this doesn't
fails with a
MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
Stacktrace: [...]
[3] accum(x::NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, y::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})
@ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:27
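To make the error message easier to read, here is a minimal sketch of how a field-by-field gradient accumulator can end up calling `+` on an empty Tuple and an empty NamedTuple. `toy_accum` is a hypothetical stand-in written for this illustration; it only loosely mirrors the idea behind `Zygote.accum` and is not Zygote's actual code.

```julia
# Hypothetical toy accumulator, NOT Zygote's implementation.
# Generic fallback: `nothing` acts as zero, everything else is added with `+`.
toy_accum(x, y) = x === nothing ? y : y === nothing ? x : x + y

# Tuples are accumulated element by element.
toy_accum(x::Tuple, y::Tuple) = map(toy_accum, x, y)

# NamedTuples are accumulated field by field (assumes y's keys are a subset of x's).
function toy_accum(x::NamedTuple, y::NamedTuple)
    vals = map(keys(x)) do k
        toy_accum(getfield(x, k), haskey(y, k) ? getfield(y, k) : nothing)
    end
    return NamedTuple{keys(x)}(vals)
end

dx = (data = (), itr = nothing)    # `data` gradient is an empty Tuple
dy = (data = (;), itr = nothing)   # `data` gradient is an empty NamedTuple

# The mixed pair matches neither structural method, so it hits the `+` fallback.
err = try
    toy_accum(dx, dy)
    nothing
catch e
    e
end
```

Under these assumptions `err` is a `MethodError` for `+`, because neither the Tuple method nor the NamedTuple method applies to the mixed pair in the `data` field.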
After following the suggestion of @ToucheSir in https://github.com/FluxML/Zygote.jl/issues/1290 and replacing the generator with sum(_ -> msolve(prob, ps=ps), 1:n) the error changes to
MethodError: no method matching +(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})
I hotfixed this with
import Base.+
+(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}) = (data=(;), itr=nothing)
and the code runs through.
Searching for occurrences of (:data, :itr) I could find only https://github.com/FluxML/Zygote.jl/blob/de078c84ce0a1ee517e9e929f0bb6b97b697e23e/src/lib/base.jl#L155
and the resp. function below.
I have no clue how this all works together, but I thank @mcabbott and @ToucheSir a lot for helping me find the fix. Feel free to correct the issue title and let me know if I can be of any further help fixing this (regarding the Zygote internals I am quite out of my depth though).
Some background:
function msolve(prob; ps=prob.p, dt=0.01, salg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(), noisemixing=true))
    prob = remake(prob, p=ps)
    s = solve(prob, EM(), sensealg=salg, dt=dt)
    s[end][end]
end
essentially takes an SDE problem prob, which has its RHS parametrized by a Lux.Chain with parameters ps=prob.p, and returns a ::Float64 taken from the end of the solution.
I am not sure which role Lux, StochasticDiffEq or SciMLSensitivity play in this problem and will try to reduce it to a MWE when time allows (unless someone spots the problem right away :>)
While you're working on a MWE, can you provide the full message and stacktrace of the latest error along with the code to run it? The gist in the linked issue appears to be out of date.
I think I nailed it down to the remake call. When using solve(..., p=ps) instead of remake everything works out:
- no need for the strange Base.+ hotfix
- can use generators
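For reference, the workaround described above might look like the sketch below. It is the msolve wrapper from earlier in the thread with the remake call dropped and the parameters handed to solve through its p keyword instead. Everything else is copied from the original wrapper; I have not verified this exact form against the linked code.

```julia
using StochasticDiffEq, SciMLSensitivity

# Sketch of the workaround: no `remake`, parameters go straight into `solve`.
function msolve(prob; ps=prob.p, dt=0.01,
                salg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(), noisemixing=true))
    # `p=ps` overrides the parameters stored in `prob` for this solve only.
    s = solve(prob, EM(), p=ps, sensealg=salg, dt=dt)
    # Last component of the last saved state.
    s[end][end]
end
```

Since prob itself is never rebuilt inside the differentiated code, the pullback of remake (and whatever it returns for empty keyword arguments) is not involved at all.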
I believe the above methods (msolve, logvar) should suffice to reproduce the problem with any simple SDE problem with some parameter dependence (or even an ODE with a corresponding solver..?). Unfortunately it's past 1pm and I'm past my sworn bedtime for today, so I'll report more tomorrow.
The current fixed code is here. To reproduce the error uncomment the remake line and run test()
Haven't run this, but sometimes Zygote is confused by re-using the name prob. Does it happen with e.g. prob2 = remake(prob, p=ps)?
remake looks...complicated: https://github.com/SciML/SciMLBase.jl/blob/9b361d6a3ea81a9e24ce14aab5768cea6986cdfe/src/remake.jl#L45. After testing Michael's suggestion, I would also be curious what you get from wrapping that remake call in Zygote.@showgrad.
The problem persists with binding to prob2. @showgrad returns nothing
julia> test()
∂(prob2 = remake(prob, p = ps)) = nothing
∂(prob2 = remake(prob, p = ps)) = nothing
∂(prob2 = remake(prob, p = ps)) = nothing
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
Without solving where these come from, they should both probably be nothing. You could try asking _project to standardise for you:
julia> methods(Zygote._project)
# 2 methods for generic function "_project" from Zygote:
[1] _project(x::AbstractArray, dx::Tuple)
@ ~/.julia/packages/Zygote/xGkZ5/src/compiler/chainrules.jl:188
[2] _project(x, dx)
@ ~/.julia/packages/Zygote/xGkZ5/src/compiler/chainrules.jl:183
julia> Zygote._project(x, dx::NamedTuple{()}) = nothing # shouldn't introduce ambiguities
Or perhaps adding methods to wrap_chainrules_output or something here https://github.com/FluxML/Zygote.jl/blob/master/src/compiler/chainrules.jl
These might be worth doing anyway. (https://github.com/JuliaDiff/ChainRulesCore.jl/pull/565 is something similar.) If _project works, you could perhaps ask it to print x too, for clues as to what object is creating these.
Sorry, I meant to wrap around just the remake(...) call and not the whole prob2 = remake(...) assignment. Wrapping the statement will always get you nothing unless it's used as a nested expression.
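A small sketch of the difference, using a hypothetical helper `square` in place of remake. In the first call the value returned by @showgrad is what gets used downstream, so the hook sees a real gradient. In the second the wrapped expression's value is discarded, which is the same situation as wrapping a whole assignment, so the hook only ever sees nothing.

```julia
using Zygote

square(x) = x^2   # hypothetical stand-in for the call being inspected

# Nested use: the hooked value flows into the result, so a gradient is shown.
g_used = Zygote.gradient(3.0) do x
    y = Zygote.@showgrad(square(x))
    2y
end

# Discarded use: nothing flows back through the hook, so it reports `nothing`.
g_unused = Zygote.gradient(3.0) do x
    Zygote.@showgrad(square(x))
    2 * square(x)
end
```

Both calls return the same overall gradient; only what @showgrad prints differs.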
I think I distilled it into a MWE
using Zygote
using StochasticDiffEq, SciMLSensitivity
import Lux
function mwe()
    x0 = rand(1)
    p0 = rand(1)
    drift(du,u,p,t) = (du .= 1)
    noise(du,u,p,t) = (du .= 1)
    prob = SDEProblem(drift, noise, x0, 1., p0)
    sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
    Zygote.gradient(p0) do p
        sum(Zygote.@showgrad(solve(remake(prob, p=p), EM(), dt=.1, sensealg=sensealg)[end][1]) for i in 1:3)
    end
end
With @showgrad in the correct position this now returns
julia> mwe()
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.09612757465640165], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [-0.06807801678762193], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
∂(remake(prob, p = p)) = (f = nothing, g = nothing, u0 = [0.16420559144402358], tspan = nothing, p = [0.0], noise = nothing, kwargs = nothing, noise_rate_prototype = nothing, seed = nothing)
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
I was surprised to see that the import Lux is necessary for the problem to occur, even though it's not being used.
Without that import there is no error.
That sounds a bit like piracy, which is bad.
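For readers unfamiliar with the term, here is a toy illustration of type piracy and why a bare import can change behaviour. The modules `Owner` and `Pirate` and the function `describe` are hypothetical and have nothing to do with Lux or Zygote.

```julia
module Owner
    # `Owner` defines the function and its generic behaviour.
    describe(x) = "generic"
end

module Pirate
    import ..Owner
    # `Pirate` owns neither `Owner.describe` nor `Int`, so this method is piracy.
    Owner.describe(x::Int) = "pirated"
end

# Code that never mentions `Pirate` is still affected once it is loaded.
result_int = Owner.describe(1)
result_float = Owner.describe(1.0)
```

Merely loading `Pirate` changes what `Owner.describe(1)` returns, which is the same kind of action at a distance as an unused import altering how gradients are computed.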
Does there seem to be a fix involving sending () or (;) to nothing, maybe using Zygote._project as above? That would be OK even if we never find the origin.
Also, what Julia version does this MWE work on? (Failed to install everything on nightly.)
Precompiling project...
✗ Cassette
✗ SciMLSensitivity
13 dependencies successfully precompiled in 248 seconds. 140 already precompiled.
2 dependencies errored. To see a full report either run `import Pkg; Pkg.precompile()` or load the packages
[ Info: Precompiling StochasticDiffEq [789caeaf-c7a9-5a7d-9973-96adeb23e2a0]
[ Info: Precompiling SciMLSensitivity [1ed8b502-d754-442c-8d5d-10ac956f44a1]
Internal error: encountered unexpected error in runtime:
AssertionError(msg="argextype only works on argument-position values")
argextype at ./compiler/optimize.jl:320
I tried
Zygote._project(x, dx::NamedTuple{()}) = nothing
Zygote._project(x, dx::NamedTuple{(), Tuple{}}) = nothing
Zygote._project(x, dx::Tuple{}) = nothing
all without effect. I am running it on Julia 1.8 with the newest versions of the packages.
Edit:
Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}}) = nothing
seems to fix it.
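My rough understanding of why this helps, as a toy sketch: if every empty structural gradient is turned into nothing before it reaches accumulation, the two mismatched empties never meet in a `+` call. `normalise` and `add_grads` are hypothetical names made up for this sketch, not Zygote functions.

```julia
# Hypothetical helpers for illustration only.
# Map empty structural gradients to `nothing`, pass everything else through.
normalise(dx) = dx
normalise(::Tuple{}) = nothing
normalise(::NamedTuple{(), Tuple{}}) = nothing

# Accumulation that treats `nothing` as zero and otherwise uses `+`.
add_grads(x, y) = x === nothing ? y : y === nothing ? x : x + y

a = normalise(())     # empty Tuple becomes nothing
b = normalise((;))    # empty NamedTuple becomes nothing

total = add_grads(a, b)                 # no `+` is ever called
kept = add_grads(normalise([1.0]), b)   # real gradients are left untouched
```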
Looking into Lux.jl I found:
# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
    return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
end
https://github.com/avik-pal/Lux.jl/blob/11ac3e476161eedea23194b31e48e8d128950e00/src/autodiff.jl#L92
This is type piracy and it touches the problematic accum, but I am not using ComponentArray in the MWE.
Does adding that definition into your own code without importing Lux also break things?
On the thing which seems to fix things
Edit:
Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}}) = nothing
seems to fix it.
Do you mind tweaking the definition to this and pasting the stacktrace it generates here?
function Zygote.wrap_chainrules_output(x::NamedTuple{(), Tuple{}})
    display(stacktrace())
    println()
end
Adding the definition without importing Lux does not break it, so I guess that's not the problem.
Here are the stacktraces you asked for. One should probably start at the end since the problem only occurs after the 3rd iteration.
57-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] (::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:62 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] (::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 
gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522
59-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] ZBack at chainrules.jl:206 [inlined] Pullback at namedtuple.jl:280 [inlined] (::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:62 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at 
boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522
57-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] (::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:62 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] (::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 
gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522
59-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] ZBack at chainrules.jl:206 [inlined] Pullback at namedtuple.jl:280 [inlined] (::typeof(∂(merge)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at interface2.jl:0 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:58 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at 
boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522
57-element Vector{Base.StackTraces.StackFrame}: wrap_chainrules_output at REPL[5]:2 [inlined] map at tuple.jl:223 [inlined] wrap_chainrules_output at chainrules.jl:106 [inlined] (::Zygote.ZBack{Lux.var"#merge_pullback#157"{(), (:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed)}})(dy::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :noise_rate_prototype, :seed), Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, Vector{Float64}, ChainRulesCore.ZeroTangent, ChainRulesCore.NoTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent, ChainRulesCore.ZeroTangent}}) at chainrules.jl:206 Pullback at remake.jl:32 [inlined] (::typeof(∂(#remake#503)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at remake.jl:28 [inlined] (::typeof(∂(remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float64}, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Nothing}}) at interface2.jl:0 Pullback at none:0 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:95 [inlined] (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:58 [inlined] (::typeof(∂(_foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:48 [inlined] (::typeof(∂(foldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:44 [inlined] (::typeof(∂(mapfoldl_impl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:162 [inlined] Pullback at reduce.jl:162 [inlined] (::typeof(∂(mapfoldl)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] (::typeof(∂(#mapreduce#262)))(Δ::Float64) at interface2.jl:0 Pullback at reduce.jl:294 [inlined] ⋮ (::typeof(∂(λ)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float64) at interface.jl:45 
gradient(f::Function, args::Vector{Float64}) at interface.jl:97 mwe() at mwe1294.jl:17 top-level scope at mwe1294.jl:21 eval at boot.jl:368 [inlined] include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1428 _include(mapexpr::Function, mod::Module, _path::String) at loading.jl:1488 include(fname::String) at client.jl:476 top-level scope at REPL[6]:1 top-level scope at initialization.jl:52 eval at boot.jl:368 [inlined] eval_user_input(ast::Any, backend::REPL.REPLBackend) at REPL.jl:151 repl_backend_loop(backend::REPL.REPLBackend) at REPL.jl:247 start_repl_backend(backend::REPL.REPLBackend, consumer::Any) at REPL.jl:232 run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool) at REPL.jl:369 run_repl(repl::REPL.AbstractREPL, consumer::Any) at REPL.jl:355 (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module) at client.jl:419 #invokelatest#2 at essentials.jl:729 [inlined] invokelatest at essentials.jl:726 [inlined] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool) at client.jl:404 exec_options(opts::Base.JLOptions) at client.jl:318 _start() at client.jl:522
Thanks! The last stacktrace includes https://github.com/avik-pal/Lux.jl/blob/11ac3e476161eedea23194b31e48e8d128950e00/src/autodiff.jl#L53-L63, which is very much piracy. Is that the last stacktrace printed before the error? If so, can you see if that rrule overload breaks things?
After removing the lines in question the test runs through :saxophone: