Zygote.jl

Differentiating through pmap with a closure fails

Open omalled opened this issue 3 years ago • 5 comments

This issue is basically identical to the one described here. I was able to verify that the MWE below works on Julia 1.4 with Zygote 0.5.4, as described by the last post in that thread. However, it fails on Julia 1.6 with Zygote 0.6.12.

Here's the MWE (copy/pasted from the discussion linked above):

using Distributed
addprocs(4, enable_threaded_blas=true) # fails with enable_threaded_blas=false as well

@everywhere using Zygote

@everywhere begin
  function ∇pmap(cx, wp, f, args...)
    ys_and_backs = pmap((args...) -> Zygote._pullback(cx, f, args...), wp, args...)
    if isempty(ys_and_backs)
      ys_and_backs, _ -> nothing
    else
      ys, backs = Zygote.unzip(ys_and_backs)
      ys, function (Δ)
        Δf_and_args_zipped = pmap((f, δ) -> f(δ), wp, backs, Δ)
        Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
        Δf = reduce(Zygote.accum, Δf_and_args[1])
        (Δf, nothing, Δf_and_args[2:end]...)
      end
    end
  end

  Zygote.@adjoint function pmap(f, wp, args::Union{AbstractArray,Tuple}...)
    ∇pmap(__context__, wp, f, args...)
  end

  function test_grad(x)
    return x^2 + 3log(x)
  end
  function test_grad_pmap(x)
    return sum(pmap(test_grad, wp, x))
  end
end

wp = default_worker_pool()
Zygote.gradient(test_grad, 1.0) # works
Zygote.gradient(test_grad_pmap, rand(100)) # breaks
Zygote.gradient(x -> sum(pmap(y -> y^2, wp, x)), rand(100)) # breaks
Zygote.gradient(x -> sum(pmap(sum, wp, x)), rand(100)) # also breaks on julia 1.7 with Zygote 0.6.12
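For readers following the structure of ∇pmap above, here is a minimal sketch of the same forward/reverse pattern without Zygote. It is only an illustration under stated assumptions: `pullback_test_grad` and `pmap_with_pullback` are hypothetical names, the pullback is written by hand rather than generated, and no workers are added, so `pmap` runs on the current process and nothing is serialized.

```julia
using Distributed  # stdlib; with no added workers, pmap runs on the current process

# Hypothetical hand-written pullback for x -> x^2 + 3log(x): returns the value
# and a closure that maps an output cotangent δ to an input cotangent.
pullback_test_grad(x) = (x^2 + 3log(x), δ -> δ * (2x + 3 / x))

# Hypothetical helper mirroring the shape of ∇pmap (not Zygote's implementation).
function pmap_with_pullback(pb, xs)
    ys_and_backs = pmap(pb, xs)               # forward pass: one (y, back) pair per element
    ys = first.(ys_and_backs)
    backs = last.(ys_and_backs)
    back(Δ) = pmap((b, δ) -> b(δ), backs, Δ)  # reverse pass: apply each back to its cotangent
    return ys, back
end

xs = [0.5, 1.0, 2.0]
ys, back = pmap_with_pullback(pullback_test_grad, xs)
grad = back(ones(length(xs)))  # the cotangent of sum(ys) is 1 for every element
```

The reverse pass has to ship the `back` closures to wherever `pmap` runs them, so with real workers those closures need to be serializable and their generated types need to be known on the workers.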

Any help getting this working would be greatly appreciated!

omalled, Jun 16 '21 20:06

Yeah I just ran into the exact same issue today. Julia 1.6.1 and Zygote 0.6.12.

darsnack, Jun 16 '21 21:06

Can you explain why you need to define your own ∇pmap? Is this a stand-in for some other function with a custom gradient you need to define everywhere? If you use the built-in definition, it appears to work... unless there is some other problem with what it's returning?

julia> using Distributed

julia> addprocs(4, enable_threaded_blas=true); # fails with enable_threaded_blas=false as well

julia> @everywhere using Zygote

julia> @everywhere begin  # ONLY the test function
         function test_grad(x)
           return x^2 + 3log(x)
         end
         function test_grad_pmap(x)
           return sum(pmap(test_grad, wp, x))
         end
       end

julia> wp = default_worker_pool()
WorkerPool(Channel{Int64}(9223372036854775807), Set([5, 4, 2, 3]), RemoteChannel{Channel{Any}}(1, 1, 5))

julia> Zygote.gradient(test_grad, 1.0) # works
(5.0,)

julia> Zygote.gradient(test_grad_pmap, rand(100)) # breaks
([5.380331174182015, 40.16711535034479, ...

julia> Zygote.gradient(x -> sum(pmap(y -> y^2, wp, x)), rand(100)) # breaks
([0.4737786213199344, 0.8461407747791889, ...

julia> Zygote.gradient(x -> sum(pmap(sum, wp, x)), rand(100)) # also breaks on julia 1.7 with Zygote 0.6.12
([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],)
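As a sanity check on what these calls return, the expected gradient can be worked out by hand: the derivative of x^2 + 3log(x) is 2x + 3/x, which gives 5.0 at x = 1.0, matching the first result. The sketch below compares that analytic derivative against a central finite difference in plain Julia; `dtest_grad` and `fd` are hypothetical helper names, not part of Zygote.

```julia
test_grad(x) = x^2 + 3log(x)

# Analytic derivative of test_grad, derived by hand.
dtest_grad(x) = 2x + 3 / x

# Central finite difference, used only as an independent numerical check.
fd(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

x = [0.25, 0.5, 1.0, 2.0]
analytic = dtest_grad.(x)     # expected gradient of sum(test_grad.(x)) with respect to x
numeric = fd.(test_grad, x)   # should agree with `analytic` up to finite-difference error
```

By the same reasoning, the gradient of `sum(pmap(y -> y^2, wp, x))` should be `2 .* x`, and the gradient of `sum(pmap(sum, wp, x))` should be all ones, which is consistent with the outputs shown.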

By contrast, pasting the original MWE exactly as posted (including the custom ∇pmap adjoint) does break, with the following error:

julia> Zygote.gradient(test_grad_pmap, rand(100)) # breaks
ERROR: MethodError: _pullback(::Zygote.Context, ::typeof(pmap), ::typeof(test_grad), ::WorkerPool, ::Vector{Float64}) is ambiguous. Candidates:
  _pullback(__context__::ZygoteRules.AContext, var"529"::typeof(pmap), f, wp::AbstractWorkerPool, args...) in Zygote at /Users/me/.julia/dev/ZygoteRules/src/adjoint.jl:59
  _pullback(__context__::ZygoteRules.AContext, var"547"::typeof(pmap), f, wp, args::Union{Tuple, AbstractArray}...) in Main at /Users/me/.julia/dev/ZygoteRules/src/adjoint.jl:59
Possible fix, define
  _pullback(::ZygoteRules.AContext, ::typeof(pmap), ::Any, ::AbstractWorkerPool, ::Vararg{Union{Tuple, AbstractArray}, N} where N)
Stacktrace:
 [1] _pullback
   @ ./REPL[51]:25 [inlined]
 [2] _pullback(ctx::Zygote.Context, f::typeof(test_grad_pmap), args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [3] _pullback(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:34
 [4] pullback(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:40
 [5] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:58
 [6] top-level scope
   @ REPL[54]:1
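The ambiguity arises because one method is more specific in the worker-pool argument while the other is more specific in the trailing arguments, so neither wins. A minimal sketch of the same situation in plain Julia, with hypothetical stand-ins (`AbstractPool`, `Pool`, `rule`) in place of `AbstractWorkerPool` and `_pullback`:

```julia
# Hypothetical stand-ins for Distributed.AbstractWorkerPool and a concrete pool.
abstract type AbstractPool end
struct Pool <: AbstractPool end

# Analogue of the built-in rule: specific in the pool argument, generic varargs.
rule(f, wp::AbstractPool, args...) = :builtin

# Analogue of the user-defined rule: generic pool argument, specific varargs.
rule(f, wp, args::Union{Tuple,AbstractArray}...) = :user

# Both methods apply and neither is more specific, so the call is ambiguous.
is_ambiguous = try
    rule(identity, Pool(), [1.0, 2.0])
    false
catch err
    err isa MethodError
end

# Defining the intersection of the two signatures (the "possible fix") resolves it.
rule(f, wp::AbstractPool, args::Union{Tuple,AbstractArray}...) = :resolved
result = rule(identity, Pool(), [1.0, 2.0])
```

In the MWE the simpler way out is to drop the custom adjoint, since a rule for `pmap` with an `AbstractWorkerPool` is already defined in Zygote.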

mcabbott, Jun 18 '21 15:06

Unlike the original Discourse thread, I'm not defining my own pmap or its gradient. For me it just errors with the released Zygote, giving the same message as in that thread (the generated name being different, of course):

ERROR: LoadError: UndefVarError: ##493#back#178 not defined

In my case the code is more complex. Based on the error message, I had assumed that the same MWE would suffice. I'll check whether the posted MWE errors for me, and if not, come up with a new MWE.

darsnack, Jun 18 '21 16:06

I'm basically in the same boat as @darsnack. I have complex code that doesn't define ∇pmap. I thought the Discourse thread was a quick and easy way to provide an MWE, but apparently it isn't a good one. My bad. I'll also see if I can construct an MWE, but I won't have time for a while.

omalled, Jun 18 '21 18:06

I am having a similar issue. I am trying to use pmap inside a loss function, and I get an error when computing the gradients with Zygote.pullback:

ERROR: LoadError: Compiling Tuple{typeof(lock), Base.var"#556#557"{WeakKeyDict{Distributed.AbstractRemoteRef, Nothing}, Future}, ReentrantLock}: try/catch is not supported. Refer to the Zygote documentation for fixes. https://fluxml.ai/Zygote.jl/dev/limitations.html#Try-catch-statements-1

mfariacastro, Mar 28 '23 14:03