Zygote.jl
Differentiating through pmap with a closure fails
This issue is essentially identical to the one described in this Discourse thread. I was able to verify that the MWE below works on Julia 1.4 with Zygote 0.5.4, as described in the last post of that thread. However, it fails on Julia 1.6 with Zygote 0.6.12.
Here's the MWE (copy/pasted from the discussion linked above):
using Distributed
addprocs(4, enable_threaded_blas=true) # fails with enable_threaded_blas=false as well
@everywhere using Zygote
@everywhere begin
    function ∇pmap(cx, wp, f, args...)
        ys_and_backs = pmap((args...) -> Zygote._pullback(cx, f, args...), wp, args...)
        if isempty(ys_and_backs)
            ys_and_backs, _ -> nothing
        else
            ys, backs = Zygote.unzip(ys_and_backs)
            ys, function (Δ)
                Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
                Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
                Δf = reduce(Zygote.accum, Δf_and_args[1])
                (Δf, nothing, Δf_and_args[2:end]...)
            end
        end
    end
    Zygote.@adjoint function pmap(f, wp, args::Union{AbstractArray,Tuple}...)
        ∇pmap(__context__, wp, f, args...)
    end
    function test_grad(x)
        return x^2 + 3log(x)
    end
    function test_grad_pmap(x)
        return sum(pmap(test_grad, wp, x))
    end
end
wp = default_worker_pool()
Zygote.gradient(test_grad, 1.0) # works
Zygote.gradient(test_grad_pmap, rand(100)) # breaks
Zygote.gradient(x -> sum(pmap(y -> y^2, wp, x)), rand(100)) # breaks
Zygote.gradient(x -> sum(pmap(sum, wp, x)), rand(100)) # also breaks on julia 1.7 with Zygote 0.6.12
Any help getting this working would be greatly appreciated!
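The custom adjoint above follows a common pattern for differentiating a map. In the forward pass, it collects each element's (value, pullback) pair. In the backward pass, it applies each pullback to its cotangent and accumulates the gradient with respect to f. The sketch below shows that pattern serially and without Zygote. It assumes a hand-written pullback for test_grad and a simplified stand-in for Zygote.accum; the names test_grad_pullback, accum and map_pullback are hypothetical. It is only meant to make the structure of ∇pmap easier to follow.

```julia
# Hypothetical hand-written pullback for test_grad(x) = x^2 + 3log(x).
# Returns the value and a closure mapping a cotangent Δ to (Δf, Δx);
# the function captures no state, so its own gradient Δf is `nothing`.
function test_grad_pullback(x)
    y = x^2 + 3log(x)
    back(Δ) = (nothing, Δ * (2x + 3 / x))
    return y, back
end

# Simplified stand-in for Zygote.accum: `nothing` acts as a zero gradient.
accum(::Nothing, ::Nothing) = nothing
accum(a, ::Nothing) = a
accum(::Nothing, b) = b
accum(a, b) = a + b

# Serial analogue of ∇pmap: run every element's forward pass and keep the
# pullbacks, then apply each pullback to the matching cotangent.
function map_pullback(pullback, xs)
    ys_and_backs = map(pullback, xs)
    ys = first.(ys_and_backs)
    backs = last.(ys_and_backs)
    function map_back(Δ)
        Δf_and_args = map((b, δ) -> b(δ), backs, Δ)
        Δf = reduce(accum, first.(Δf_and_args))  # gradient w.r.t. f, summed over all calls
        Δxs = last.(Δf_and_args)                 # per-element gradients w.r.t. the inputs
        return (Δf, Δxs)
    end
    return ys, map_back
end

# Usage: gradient of sum(map(test_grad, xs)); the cotangent of sum is all ones.
ys, back = map_pullback(test_grad_pullback, [1.0, 2.0])
Δf, Δxs = back(ones(2))   # Δxs == [5.0, 5.5], Δf === nothing
```

The distributed version replaces both map calls with pmap over the worker pool. Because of that, each pullback closure must be callable on a worker, which is why the MWE wraps everything in @everywhere.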
Yeah I just ran into the exact same issue today. Julia 1.6.1 and Zygote 0.6.12.
Can you explain why you need to define your own ∇pmap? Is this a stand-in for some other function with a custom gradient that you need to define everywhere? If you use the built-in definition, it appears to work... unless there is some other problem with what it's returning?
julia> using Distributed
julia> addprocs(4, enable_threaded_blas=true); # fails with enable_threaded_blas=false as well
julia> @everywhere using Zygote
julia> @everywhere begin # ONLY the test function
           function test_grad(x)
               return x^2 + 3log(x)
           end
           function test_grad_pmap(x)
               return sum(pmap(test_grad, wp, x))
           end
       end
julia> wp = default_worker_pool()
WorkerPool(Channel{Int64}(9223372036854775807), Set([5, 4, 2, 3]), RemoteChannel{Channel{Any}}(1, 1, 5))
julia> Zygote.gradient(test_grad, 1.0) # works
(5.0,)
julia> Zygote.gradient(test_grad_pmap, rand(100)) # breaks
([5.380331174182015, 40.16711535034479, ...
julia> Zygote.gradient(x -> sum(pmap(y -> y^2, wp, x)), rand(100)) # breaks
([0.4737786213199344, 0.8461407747791889, ...
julia> Zygote.gradient(x -> sum(pmap(sum, wp, x)), rand(100)) # also breaks on julia 1.7 with Zygote 0.6.12
([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],)
Whereas pasting the original MWE verbatim, custom adjoint included, does break with the following error:
julia> Zygote.gradient(test_grad_pmap, rand(100)) # breaks
ERROR: MethodError: _pullback(::Zygote.Context, ::typeof(pmap), ::typeof(test_grad), ::WorkerPool, ::Vector{Float64}) is ambiguous. Candidates:
  _pullback(__context__::ZygoteRules.AContext, var"529"::typeof(pmap), f, wp::AbstractWorkerPool, args...) in Zygote at /Users/me/.julia/dev/ZygoteRules/src/adjoint.jl:59
  _pullback(__context__::ZygoteRules.AContext, var"547"::typeof(pmap), f, wp, args::Union{Tuple, AbstractArray}...) in Main at /Users/me/.julia/dev/ZygoteRules/src/adjoint.jl:59
Possible fix, define
  _pullback(::ZygoteRules.AContext, ::typeof(pmap), ::Any, ::AbstractWorkerPool, ::Vararg{Union{Tuple, AbstractArray}, N} where N)
Stacktrace:
 [1] _pullback
   @ ./REPL[51]:25 [inlined]
 [2] _pullback(ctx::Zygote.Context, f::typeof(test_grad_pmap), args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [3] _pullback(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:34
 [4] pullback(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:40
 [5] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:58
 [6] top-level scope
   @ REPL[54]:1
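This error is a Julia method ambiguity. Zygote's built-in rule (the first candidate) is specific in the worker-pool argument. The MWE's rule (the second candidate) is specific in the trailing arguments. For this call, neither method is more specific than the other. The plain-Julia sketch below reproduces that shape and shows why the "possible fix" in the message works: a method covering the intersection of the two signatures is more specific than both. It uses hypothetical names (AbstractPool, Pool, adj) rather than Zygote internals. For the MWE, the analogous change would presumably be to annotate the custom adjoint as pmap(f, wp::AbstractWorkerPool, args::Union{AbstractArray,Tuple}...). That is an untested assumption, not a confirmed fix.

```julia
# Hypothetical stand-ins for Distributed's AbstractWorkerPool / WorkerPool
abstract type AbstractPool end
struct Pool <: AbstractPool end

# Like the built-in rule: specific in the pool argument only
adj(f, wp::AbstractPool, args...) = :builtin
# Like the MWE's rule: specific in the trailing arguments only
adj(f, wp, args::Union{AbstractArray,Tuple}...) = :custom

# Returns true if the call throws a MethodError (ambiguity errors are MethodErrors)
function is_ambiguous_call()
    try
        adj(sin, Pool(), [1.0, 2.0])
        return false
    catch err
        return err isa MethodError
    end
end

ambiguous_before = is_ambiguous_call()   # true: neither method wins

# The fix the error message proposes: a method for the intersection of both signatures
adj(f, wp::AbstractPool, args::Union{AbstractArray,Tuple}...) = :custom

result_after = adj(sin, Pool(), [1.0, 2.0])   # dispatches to the new, most specific method
```

Defining the intersection method leaves both original methods in place. It only resolves the calls where they overlap, so calls that match just one of them dispatch as before.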
Unlike in the original Discourse thread, I'm not defining my own pmap or its gradient. For me, it simply errors with the released Zygote, giving the same message as in the Discourse thread (with a different generated name, of course):
ERROR: LoadError: UndefVarError: ##493#back#178 not defined
In my case, the code is more complex. Based on the error message, I had assumed the same MWE would reproduce it. I'll check whether the posted MWE errors for me and, if not, come up with a new MWE.
I'm basically in the same boat as @darsnack. I have complex code that doesn't define ∇pmap. I thought that Discourse thread was a quick and easy way to provide an MWE, but apparently it isn't a good one. My bad. I'll also see if I can construct an MWE, but I won't have time for a while.
I am having a similar issue. I am trying to use pmap inside a loss function and get an error when trying to obtain the gradients using Zygote.pullback:
ERROR: LoadError: Compiling Tuple{typeof(lock), Base.var"#556#557"{WeakKeyDict{Distributed.AbstractRemoteRef, Nothing}, Future}, ReentrantLock}: try/catch is not supported. Refer to the Zygote documentation for fixes. https://fluxml.ai/Zygote.jl/dev/limitations.html#Try-catch-statements-1