Zygote.jl icon indicating copy to clipboard operation
Zygote.jl copied to clipboard

Taking nested gradient for implementing Wasserstein GAN with gradient penalty (WGAN-GP) on GPU

Open bhatiaabhinav opened this issue 3 years ago • 9 comments

I am trying to implement WGAN-GP using Flux and Zygote. My implementation works fine on CPU but fails on GPU with the error LoadError: this intrinsic must be compiled to be called. I have read that taking nested gradients in Zygote is a mess right now and that one needs to use ForwardDiff for that. I have spent hours searching through relevant issues, but I can't figure out how to do that in my case.

Here is my brief implementation, which is almost line-by-line identical to the pseudocode in the original paper.

using Flux
using Flux: update!, params
using Zygote
using StatsBase

"""
WGAN with gradient penalty. See algorithm 1 in https://proceedings.neurips.cc/paper/2017/file/892c3b1c6dccd52936e27cbd0ff683d6-Paper.pdf. The following code is almost line by line identical.
"""
function train_WGAN_GP(𝐺, 𝐷, 𝐗::Array{Float32, N}, latent_size, num_iters, device_fn; m=32, Ξ»=10f0, ncritic=5, Ξ±=0.0001, β₁=0, Ξ²β‚‚=0.9) where N
    n = size(𝐗)[end]    # length of dataset
    𝐺, 𝐷 = device_fn(deepcopy(𝐺)), device_fn(deepcopy(𝐷))
    ΞΈ, 𝑀 = params(𝐺), params(𝐷)
    adamΞΈ, adam𝑀 = ADAM(Ξ±, (β₁, Ξ²β‚‚)), ADAM(Ξ±, (β₁, Ξ²β‚‚)) 

    for iter in 1:num_iters
        for t in 1:ncritic
            𝐱, 𝐳, π›œ = 𝐗[repeat([:], N-1)..., rand(1:n, m)], randn(Float32, latent_size..., m), rand(Float32, repeat([1], N-1)..., m) # Sample a minibatch of real data x, latent variables z, random numbers Ο΅ ∼ U[0, 1].
            𝐱, 𝐳, π›œ = device_fn(𝐱), device_fn(𝐳), device_fn(π›œ)
            𝐱̃ = 𝐺(𝐳)
            𝐱̂ = π›œ .* 𝐱 + (1f0 .- π›œ) .* 𝐱̃
            βˆ‡π‘€L = gradient(𝑀) do
                βˆ‡π±Μ‚π·, = gradient(𝐱̂ ->  sum(𝐷(𝐱̂)), 𝐱̂)
                L = mean(𝐷(𝐱̃)) - mean(𝐷(𝐱)) + Ξ» * mean((sqrt.(sum(βˆ‡π±Μ‚π·.^2, dims=1) .+ 1f-12) .- 1f0).^2)
            end
            update!(adam𝑀, 𝑀, βˆ‡π‘€L)
        end

        𝐳 = device_fn(randn(Float32, latent_size..., m))
        βˆ‡ΞΈπ· = gradient(ΞΈ) do
            -mean(𝐷(𝐺(𝐳)))
        end
        update!(adamΞΈ, ΞΈ, βˆ‡ΞΈπ·)
    end

    return 𝐺, 𝐷
end

𝐗 = rand(Float32, 50, 10000)  # dummy dataset: 50 features × 10000 samples (last dim = samples)
z = 16                        # latent size
𝐺 = Chain(Dense(z, 32, leakyrelu), Dense(32, 50))   # Generator
𝐷 = Chain(Dense(50, 32, leakyrelu), Dense(32, 1))   # Critic

# Same models, one training iteration on each device; only the device function differs.
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z, ), 1, cpu) # works
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z, ), 1, gpu) # fails

This fails on GPU at the line ∇𝐱̂𝐷, = gradient(𝐱̂ -> sum(𝐷(𝐱̂)), 𝐱̂) with the error: LoadError: this intrinsic must be compiled to be called.

Can anyone help me?

Here is a code snippet that isolates the problem:

using Flux
using Statistics  # [edit: does not need StatsBase]

"""
    run_isolated_code_on(device_fn)

Minimal reproduction of the WGAN-GP gradient penalty. It takes the gradient, with
respect to a small critic's parameters, of a penalty built from the critic's gradient
with respect to its input (a `gradient` call nested inside another `gradient`).
`device_fn` moves the model and data, e.g. `cpu` or `gpu`. Returns the result of the
outer `gradient(w)` call.
"""
function run_isolated_code_on(device_fn)
    D = Chain(Dense(5, 3, leakyrelu), Dense(3, 1)) |> device_fn  # Critic [edit: size was 50 => 32]
    w = Flux.params(D)  # [edit]
    x = rand(Float32, 5, 3) |> device_fn                   # Dummy minibatch
    ∇wL = gradient(w) do
        ∇xD, = gradient(x -> sum(D(x)), x)                 # The problematic line
        L = mean((sqrt.(sum(∇xD .^ 2, dims=1) .+ 1f-12) .- 1f0) .^ 2)   # gradient penalty
    end
    return ∇wL
end

cpu |> run_isolated_code_on  # works: nested gradient on CPU arrays
gpu |> run_isolated_code_on  # fails: same code with CuArrays

bhatiaabhinav avatar Jul 13 '22 05:07 bhatiaabhinav

Can you provide the full error and stacktrace? The example could probably also be distilled down a bit more, but we can deal with that later.

ToucheSir avatar Jul 13 '22 05:07 ToucheSir

Super basic code: image Stack trace:

ERROR: LoadError: this intrinsic must be compiled to be called
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context, ::Core.IntrinsicFunction, ::String, ::Type{Int64}, ::Type{Tuple{Ptr{Int64}}}, ::Ptr{Int64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:9
  [3] _pullback
    @ ./atomics.jl:358 [inlined]
  [4] _pullback(ctx::Zygote.Context, f::typeof(getindex), args::Base.Threads.Atomic{Int64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [5] _pullback (repeats 2 times)
    @ ~/.julia/packages/CUDA/tTK8Y/lib/utils/threading.jl:25 [inlined]
  [6] _pullback
    @ ~/.julia/packages/CUDA/tTK8Y/src/compiler/gpucompiler.jl:7 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(CUDA.device_properties), args::CUDA.CuDevice)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/CUDA/tTK8Y/src/compiler/gpucompiler.jl:51 [inlined]
  [9] _pullback(::Zygote.Context, ::CUDA.var"##CUDACompilerTarget#206", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(CUDA.CUDACompilerTarget), ::CUDA.CuDevice)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/CUDA/tTK8Y/src/compiler/gpucompiler.jl:51 [inlined]
 [11] _pullback
    @ ~/.julia/packages/CUDA/tTK8Y/src/compiler/execution.jl:296 [inlined]
 [12] _pullback(::Zygote.Context, ::CUDA.var"##cufunction#221", ::Nothing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(CUDA.cufunction), ::GPUArrays.var"#broadcast_kernel#15", ::Type{Tuple{CUDA.CuKernelContext, CUDA.CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/packages/CUDA/tTK8Y/src/compiler/execution.jl:293 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(CUDA.cufunction), ::GPUArrays.var"#broadcast_kernel#15", ::Type{Tuple{CUDA.CuKernelContext, CUDA.CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [15] macro expansion
    @ ~/.julia/packages/CUDA/tTK8Y/src/compiler/execution.jl:102 [inlined]
 [16] _pullback
    @ ~/.julia/packages/CUDA/tTK8Y/src/gpuarrays.jl:17 [inlined]
 [17] _pullback(::Zygote.Context, ::CUDA.var"##launch_heuristic#248", ::Int64, ::Int64, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#15", ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [18] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [19] adjoint
    @ ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:204 [inlined]
 [20] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{Int64, Int64, typeof(GPUArrays.launch_heuristic), CUDA.CuArrayBackend, GPUArrays.var"#broadcast_kernel#15"}, ::Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64})
    @ Zygote ./none:0
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [22] _pullback
    @ ~/.julia/packages/CUDA/tTK8Y/src/gpuarrays.jl:17 [inlined]
 [23] _pullback(::Zygote.Context, ::GPUArrays.var"#launch_heuristic##kw", ::NamedTuple{(:elements, :elements_per_thread), Tuple{Int64, Int64}}, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#15", ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [24] _pullback
    @ ~/.julia/packages/GPUArrays/EVTem/src/host/broadcast.jl:73 [inlined]
 [25] _pullback(::Zygote.Context, ::typeof(GPUArrays._copyto!), ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [26] _pullback
    @ ~/.julia/packages/GPUArrays/EVTem/src/host/broadcast.jl:51 [inlined]
 [27] _pullback
    @ ./broadcast.jl:868 [inlined]
 [28] _pullback(::Zygote.Context, ::typeof(Base.Broadcast.materialize!), ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [29] _pullback
    @ ./broadcast.jl:864 [inlined]
 [30] _pullback
    @ ~/.julia/packages/Zygote/DkIUK/src/lib/broadcast.jl:280 [inlined]
 [31] _pullback(ctx::Zygote.Context, f::Zygote.var"#1225#1232"{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [32] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [33] _pullback(ctx::Zygote.Context, f::Zygote.var"#4043#back#1233"{Zygote.var"#1225#1232"{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [34] _pullback
    @ ~/workspace/wgan-gp.jl:11 [inlined]
 [35] _pullback(ctx::Zygote.Context, f::typeof(∂(λ)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [36] _pullback
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41 [inlined]
 [37] _pullback(ctx::Zygote.Context, f::Zygote.var"#52#53"{typeof(βˆ‚(Ξ»))}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [38] _pullback
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:76 [inlined]
 [39] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#30#32"{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [40] _pullback
    @ ~/workspace/wgan-gp.jl:11 [inlined]
 [41] _pullback(::Zygote.Context, ::var"#29#31"{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [42] pullback(f::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:352
 [43] gradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
 [44] run_isolated_code_on(device_fn::Function)
    @ Main ~/workspace/wgan-gp.jl:10
 [45] top-level scope
    @ ~/workspace/wgan-gp.jl:16
 [46] include(fname::String)
    @ Base.MainInclude ./client.jl:451
 [47] top-level scope
    @ REPL[1]:1
 [48] top-level scope
    @ ~/.julia/packages/CUDA/tTK8Y/src/initialization.jl:52
in expression starting at /home/abhinav/workspace/wgan-gp.jl:16

bhatiaabhinav avatar Jul 13 '22 06:07 bhatiaabhinav

Thanks. Getting rid of the non-twice-differentiable code at https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L278-L281 gets us a step closer:

ERROR: LoadError: MethodError: objects of type Matrix{Float32} are not callable
Use square brackets [] for indexing an Array.
Stacktrace:
  [1] macro expansion
    @ ~/.julia/dev/Zygote/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context, f::Matrix{Float32}, args::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:9
  [3] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:183 [inlined]
  [4] _pullback
    @ ~/.julia/dev/Zygote/src/lib/broadcast.jl:51 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(Zygote.unbroadcast), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [6] _pullback
    @ ~/.julia/dev/Zygote/src/lib/broadcast.jl:75 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::Zygote.var"#917#919"{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#550#554"{Zygote.Context, Zygote.var"#917#919"{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})(args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:195
  [9] map
    @ ./tuple.jl:222 [inlined]
 [10] ∇map(cx::Zygote.Context, f::Zygote.var"#917#919"{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}, args::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:195
 [11] adjoint
    @ ~/.julia/dev/Zygote/src/lib/array.jl:221 [inlined]
 [12] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [13] _pullback
    @ ~/.julia/dev/Zygote/src/lib/broadcast.jl:75 [inlined]
 [14] _pullback(ctx::Zygote.Context, f::Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, args::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [15] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::Zygote.var"#3688#back#920"{Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, args::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [17] _pullback
    @ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:172 [inlined]
 [18] _pullback(ctx::Zygote.Context, f::Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#3688#back#920"{Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3784#back#946"{Zygote.var"#944#945"}}}, args::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [19] _pullback
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:6 [inlined]
 [20] _pullback(ctx::Zygote.Context, f::Zygote.Pullback{Tuple{var"#2#4", CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1766#back#157"{Zygote.var"#155#156"{Zygote.Context, GlobalRef, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#3688#back#920"{Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3784#back#946"{Zygote.var"#944#945"}}}, Zygote.var"#2865#back#621"{Zygote.var"#617#619"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, args::Float32)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [21] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:41 [inlined]
 [22] _pullback(ctx::Zygote.Context, f::Zygote.var"#60#61"{Zygote.Pullback{Tuple{var"#2#4", CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1766#back#157"{Zygote.var"#155#156"{Zygote.Context, GlobalRef, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#3688#back#920"{Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3784#back#946"{Zygote.var"#944#945"}}}, Zygote.var"#2865#back#621"{Zygote.var"#617#619"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}}, args::Float32)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [23] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:76 [inlined]
 [24] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#2#4", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [25] _pullback
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:6 [inlined]
 [26] _pullback(::Zygote.Context, ::var"#1#3")
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [27] pullback(f::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:352
 [28] gradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:75
 [29] top-level scope
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:5

Barebones MWE:

using Flux, CUDA, Zygote

D = Dense(5, 1) |> gpu
x = CUDA.rand(5, 2)
# Outer gradient w.r.t. the layer's parameters of a quantity built from the inner
# gradient w.r.t. the input: nested reverse-mode AD on GPU arrays.
gradient(Flux.params(D)) do
    ∇xD, = gradient(x -> sum(D(x)), x)  # [edit, unpack as suggested below]
    sum(∇xD)
end

The next challenge is that the broadcast pullback here isn't differentiable. IIUC we probably want to intercept this at unbroadcast and not let it try to differentiate all the way through to ProjectTo here?

ToucheSir avatar Jul 14 '22 03:07 ToucheSir

Good digging.

The sum pullback is awful. The one for ordinary arrays makes a Fill, which IIRC caused some weird interactions with CuArrays... we should consider simply deleting that completely. (Edit: surprises like #1269 are another reason to delete it.) I think the ChainRules one is 2nd differentiable.

For broadcasting & projection, one idea would be to just add a method like this:

"""
    _project(x::DenseArray{T}, dx::AbstractArray{T}) where {T<:Number}

Return the cotangent `dx` itself, unchanged, when the primal `x` is a dense array and
`dx` has the same numeric element type `T`. Intended as an extra method of Zygote's
internal `_project`, so that this case skips projection entirely (see discussion above).
"""
function _project(x::DenseArray{T}, dx::AbstractArray{T}) where {T<:Number}
    # Element types already agree: hand the cotangent straight back.
    return dx
end

mcabbott avatar Jul 14 '22 13:07 mcabbott

@ToucheSir, there is a minor bug in your MWE code. ∇xD = gradient(x -> sum(D(x)), x) should be ∇xD, = gradient(x -> sum(D(x)), x). Notice the comma after ∇xD. Alternatively, you could write ∇xD = gradient(x -> sum(D(x)), x)[1].

bhatiaabhinav avatar Jul 14 '22 13:07 bhatiaabhinav

Fixed, thanks (GH doesn't scale up linked images well for larger screens, which is why we recommend code blocks/gists for MWEs).

With the _project method addition, we can get all the way back to the * in Dense:

ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/rSIl2/src/GPUArraysCore.jl:78
  [3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/gok9K/src/host/indexing.jl:9
  [4] _generic_matmatmul!(C::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, tA::Char, tB::Char, A::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, B::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.7.3+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:830
  [5] generic_matmatmul!(C::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, tA::Char, tB::Char, A::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, B::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.7.3+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:798
  [6] mul!
    @ ~/.julia/juliaup/julia-1.7.3+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:478 [inlined]
  [7] mul!
    @ ~/.julia/juliaup/julia-1.7.3+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:275 [inlined]
  [8] *
    @ ~/.julia/juliaup/julia-1.7.3+0~x64/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:153 [inlined]
  [9] rrule
    @ ~/.julia/packages/ChainRules/oot29/src/rulesets/Base/arraymath.jl:40 [inlined]
 [10] rrule
    @ ~/.julia/packages/ChainRulesCore/16PWJ/src/rules.jl:134 [inlined]
 [11] chain_rrule
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:219 [inlined]
 [12] macro expansion
    @ ~/.julia/dev/Zygote/src/compiler/interface2.jl:0 [inlined]
 [13] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface2.jl:9 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ChainRules/oot29/src/rulesets/Base/arraymath.jl:56 [inlined]
 [15] _pullback
    @ ~/.julia/packages/ChainRulesCore/16PWJ/src/tangent_types/thunks.jl:195 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ChainRulesCore/16PWJ/src/tangent_types/thunks.jl:222 [inlined]
 [17] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:104 [inlined]
 [18] #550
    @ ~/.julia/dev/Zygote/src/lib/array.jl:195 [inlined]
 [19] map
    @ ./tuple.jl:223 [inlined]
 [20] βˆ‡map(cx::Zygote.Context, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1398#1403"{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ChainRules.var"#1397#1402"{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1400#1405"{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ChainRules.var"#1399#1404"{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:195
 [21] adjoint
    @ ~/.julia/dev/Zygote/src/lib/array.jl:221 [inlined]
 [22] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [23] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:105 [inlined]
 [24] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:207 [inlined]
 [25] _pullback
    @ ~/.julia/packages/Flux/KkC79/src/layers/basic.jl:172 [inlined]
 [26] _pullback(ctx::Zygote.Context, f::Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#3784#back#946"{Zygote.var"#944#945"}, Zygote.var"#3688#back#920"{Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}}}, args::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [27] _pullback
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:6 [inlined]
 [28] _pullback(ctx::Zygote.Context, f::Zygote.Pullback{Tuple{var"#2#4", CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#2865#back#621"{Zygote.var"#617#619"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#3784#back#946"{Zygote.var"#944#945"}, Zygote.var"#3688#back#920"{Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}}}, Zygote.var"#1766#back#157"{Zygote.var"#155#156"{Zygote.Context, GlobalRef, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}, args::Float32)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [29] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:41 [inlined]
 [30] _pullback(ctx::Zygote.Context, f::Zygote.var"#60#61"{Zygote.Pullback{Tuple{var"#2#4", CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#2865#back#621"{Zygote.var"#617#619"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#3784#back#946"{Zygote.var"#944#945"}, Zygote.var"#3688#back#920"{Zygote.var"#916#918"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}}}, Zygote.var"#1766#back#157"{Zygote.var"#155#156"{Zygote.Context, GlobalRef, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}}, args::Float32)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [31] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:76 [inlined]
 [32] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#2#4", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [33] _pullback
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:6 [inlined]
 [34] _pullback(::Zygote.Context, ::var"#1#3")
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [35] pullback(f::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:352
 [36] gradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:75
 [37] top-level scope
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:5

I'm not sure why it hits the generic fallback matmul instead of https://github.com/JuliaArrays/FillArrays.jl/blob/master/src/fillalgebra.jl#L117, since the eltypes look the same and CuMatrix <: StridedMatrix.

Removing the array sum adjoint entirely gives us this:

ERROR: LoadError: MethodError: objects of type Matrix{Float32} are not callable
Use square brackets [] for indexing an Array.
Stacktrace:
  [1] macro expansion
    @ ~/.julia/dev/Zygote/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context, f::Matrix{Float32}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:9
  [3] _pullback
    @ ~/.julia/packages/ChainRules/oot29/src/rulesets/Base/mapreduce.jl:25 [inlined]
  [4] _pullback(::Zygote.Context, ::ChainRules.var"#1548#1551"{Float32, Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [5] _pullback
    @ ~/.julia/packages/ChainRulesCore/16PWJ/src/tangent_types/thunks.jl:195 [inlined]
  [6] _pullback
    @ ~/.julia/packages/ChainRulesCore/16PWJ/src/tangent_types/thunks.jl:222 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(ChainRulesCore.unthunk), args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1548#1551"{Float32, Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}}}, ChainRules.var"#1547#1550"{Float32, Colon}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:104 [inlined]
  [9] (::Zygote.var"#550#554"{Zygote.Context, typeof(Zygote.wrap_chainrules_output)})(args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1548#1551"{Float32, Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}}}, ChainRules.var"#1547#1550"{Float32, Colon}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:195
 [10] map
    @ ./tuple.jl:222 [inlined]
 [11] ∇map(cx::Zygote.Context, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1548#1551"{Float32, Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}}}, ChainRules.var"#1547#1550"{Float32, Colon}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:195
 [12] adjoint
    @ ~/.julia/dev/Zygote/src/lib/array.jl:221 [inlined]
 [13] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [14] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:105 [inlined]
 [15] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:207 [inlined]
 [16] _pullback
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:8 [inlined]
 [17] _pullback(ctx::Zygote.Context, f::Zygote.Pullback{Tuple{var"#2#4"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#3772#back#939"{Zygote.var"#937#938"}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3676#back#913"{Zygote.var"#909#911"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:D, Zygote.Context, var"#2#4"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, 
Zygote.ZBack{ChainRules.var"#sum_pullback#1549"{Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}}}}}, args::Float32)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [18] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:41 [inlined]
 [19] _pullback(ctx::Zygote.Context, f::Zygote.var"#60#61"{Zygote.Pullback{Tuple{var"#2#4"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#3772#back#939"{Zygote.var"#937#938"}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:bias, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1401"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:weight, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3676#back#913"{Zygote.var"#909#911"{Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Zygote.Pullback{Tuple{typeof(NNlib.fast_act), typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:Οƒ, Zygote.Context, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, typeof(identity)}}}}, Zygote.var"#1920#back#226"{Zygote.var"#back#225"{:D, Zygote.Context, var"#2#4"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, 
Zygote.ZBack{ChainRules.var"#sum_pullback#1549"{Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}}}}}}, args::Float32)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [20] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:76 [inlined]
 [21] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#2#4"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [22] _pullback
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:8 [inlined]
 [23] _pullback(::Zygote.Context, ::var"#1#3"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [24] pullback(f::Function, ps::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:352
 [25] gradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:75
 [26] top-level scope
    @ ~/projects/juliamwes/fluxstuff/gradient-penalty.jl:7

This time the error comes from https://github.com/JuliaDiff/ChainRules.jl/blob/v1.37.0/src/rulesets/Base/mapreduce.jl#L25. Would it make sense to add rules for certain projector types? I'm still not sure why an array is being treated as a function here.

ToucheSir avatar Jul 16 '22 02:07 ToucheSir

Have you guys figured how to fix this?

PS: I managed to implement WGAN-GP using Knet, but that library is a pain.

bhatiaabhinav avatar Jul 17 '22 16:07 bhatiaabhinav

We are differentiating through CUDA.jl code in the broadcasting. That sounds like a missing adjoint. I would break the broadcast down into separate kernels and bisect the issue that way.

On Sun, Jul 17, 2022, 21:31 Abhinav Bhatia @.***> wrote:

Have you guys figured how to fix this?

β€” Reply to this email directly, view it on GitHub https://github.com/FluxML/Zygote.jl/issues/1262#issuecomment-1186552613, or unsubscribe https://github.com/notifications/unsubscribe-auth/AJOZVVIQZ6QOFFPELA6CX2TVUQU47ANCNFSM53NLJAPQ . You are receiving this because you are subscribed to this thread.Message ID: @.***>

DhairyaLGandhi avatar Jul 17 '22 16:07 DhairyaLGandhi

Removing the array sum adjoint entirely gives us this,

ERROR: LoadError: MethodError: objects of type Matrix{Float32} are not callable
Use square brackets [] for indexing an Array.
Stacktrace:
...
 [11] ∇map(cx::Zygote.Context, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1548#1551"{Float32, Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}}}, ChainRules.var"#1547#1550"{Float32, Colon}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:195

Notice in the above error that ∇map is applying wrap_chainrules_output, so Zygote is differentiating the code that converts between the Zygote and ChainRules representations. The two conversion functions are inverses of each other, so their rules should simply call each other, something like https://github.com/FluxML/Zygote.jl/commit/931b7466e64e60907eb6647ff6bd035c5b58e168

Using the ChainRules (CR) rule for sum avoids the Zygote rule's mutation and its use of Fill. But that rule contains thunks... Maybe someone can write a rule to remove those. This is wrong, but something like this may work:

# Sketch only. Its author calls it "wrong but something like this may work".
# It is a Zygote adjoint for the `ChainRulesCore.Thunk` constructor. The intent is
# to let Zygote step over thunk construction when differentiating at second order.
# `@show f` prints `f` for debugging and passes it through unchanged.
# NOTE(review): the forward value returned here is `unthunk(f)`, not a `Thunk`.
# `f` is presumably the thunk's closure, so confirm what downstream code expects.
# NOTE(review): `bk(dy)` from `pullback` is already a tuple of per-argument
# gradients, so `(bk(dy),)` nests it one level deeper. Verify this is intended.
@adjoint function ChainRulesCore.Thunk(f)
  y, bk = pullback(unthunk, @show f)
  return y, dy -> (bk(dy),)
end
# Sketch only. Its author calls it "wrong but something like this may work".
# It is a Zygote adjoint for the `ChainRulesCore.InplaceableThunk(f!, g)` constructor.
# Only the second argument `g` is passed to `unthunk` and differentiated.
# The first argument `f!` (presumably the in-place accumulator) gets a `nothing`
# cotangent. `@show g` prints `g` for debugging and passes it through unchanged.
# NOTE(review): `bk(dy)` from `pullback` is already a tuple of per-argument
# gradients, so the cotangent for `g` is a nested tuple. Verify this is intended.
@adjoint function ChainRulesCore.InplaceableThunk(f!, g::Function)
  y, bk = pullback(unthunk, @show g)
  return y, dy -> (nothing, bk(dy))
end

Alternatively, all thunks could be removed at second order: https://github.com/JuliaDiff/ChainRulesCore.jl/commit/753f5d542a09984365831314ad9f5b4443c3c566

My changes are in these branches:

]
add https://github.com/mcabbott/Zygote.jl#for1262
add https://github.com/mcabbott/ChainRulesCore.jl#nothunk

This time from https://github.com/JuliaDiff/ChainRules.jl/blob/v1.37.0/src/rulesets/Base/mapreduce.jl#L25.

The projection there still seems to cause problems. Maybe something like rrule(project::ProjectTo, dx) = project(dx), ddx -> (NoTangent(), ddx) could fix them? But with the projection removed, at least one second-derivative example runs.

The example from https://github.com/FluxML/Zygote.jl/issues/1262#issuecomment-1183950487 above still does not run, I think due to some extra issue with global variables, Params, etc.

julia> using Flux, CUDA

julia> D = Dense(5, 1) |> gpu;

julia> D.weight
1×5 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 0.887838  0.863505  0.0667794  -0.808859  0.200366

julia> x = CUDA.rand(5, 2);

# My simplified example, without Params and with nothing global:

julia> gradient(D, x) do Din, xin
         βˆ‡xD = gradient(x -> sum(Din(x)), xin)[1]
         sum(abs2, βˆ‡xD)
       end
dx2 = 1.0f0
ERROR: MethodError: objects of type Matrix{Float32} are not callable
Use square brackets [] for indexing an Array.
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/5GQsG/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context, f::Matrix{Float32}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface2.jl:9
  [3] _pullback
    @ ~/.julia/packages/ChainRules/o1vND/src/rulesets/Base/mapreduce.jl:33 [inlined]
  [4] _pullback(ctx::Zygote.Context, f::ChainRules.var"#sum_pullback#1574"{Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Matrix{Float32}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface2.jl:0
  [5] _pullback
    @ ~/.julia/packages/Zygote/5GQsG/src/compiler/chainrules.jl:208 [inlined]
  [6] _pullback
    @ ./REPL[37]:2 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(∂(λ)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:41 [inlined]
  [9] _pullback(ctx::Zygote.Context, f::Zygote.var"#70#71"{typeof(∂(λ))}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:76 [inlined]
 [11] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#30#32"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface2.jl:0
 [12] _pullback
    @ ./REPL[37]:2 [inlined]
 [13] _pullback(::Zygote.Context, ::var"#29#31", ::Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface2.jl:0
 [14] _pullback(::Function, ::Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:34
 [15] pullback(::Function, ::Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:40
 [16] gradient(::Function, ::Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:75
 [17] top-level scope
    @ REPL[37]:1
 [18] top-level scope
    @ ~/.julia/packages/CUDA/tTK8Y/src/initialization.jl:52

# Here the projection has been removed. For this rule it could be moved into _unsum, but not in general.

julia> @eval ChainRules function rrule(::typeof(sum), x::AbstractArray; dims=:)
           project = ProjectTo(x)
           y = sum(x; dims=dims)
           function sum_pullback(dy_raw)
               dy = unthunk(dy_raw)
               x_thunk = InplaceableThunk(
                   dx -> dx .+= (dims isa Colon ? Ref(dy) : dy),
                   @thunk _unsum(x, dy, dims)
               )
               return (NoTangent(), x_thunk)
           end
           return y, sum_pullback
       end
rrule (generic function with 1025 methods)

# Then at least one example works:

julia> gradient(D, x) do Din, xin
         βˆ‡xD = gradient(x -> sum(Din(x)), xin)[1]
         sum(abs2, βˆ‡xD)
       end
dx2 = 1.0f0
dx3 = (weight = nothing, bias = Float32[2.0;;], σ = nothing)
...
  2 dx = Float32[1.0 1.0] => ddx = Float32[4.4655113 4.4655113]
  2 dx = 1.0 => ddx = 8.931023
# (debugging printout!)
((weight = Float32[3.5513535 3.454018 … -3.2354355 0.8014631], bias = nothing, σ = nothing), nothing)

# This more complicated variant works too:

julia> CUDA.allowscalar(false)

julia> C = Chain(Dense(5, 2, tanh), Dense(2,1)) |> gpu;

julia> gradient(C, x) do Cin, xin
         βˆ‡xD = gradient(x -> sum(abs2, Cin(x .+ x)), xin)[1]
         sum(abs2, βˆ‡xD)
       end
((layers = ((weight = Float32[-2.6709735 -1.9108372 … 0.11256764 -3.5999453; -0.59868777 -0.43379757 … 0.02838762 -0.8036242], bias = Float32[-1.6558683; -0.36147818;;], σ = nothing), (weight = Float32[-2.3192203 2.4345593], bias = Float32[2.885524;;], σ = nothing)),), Float32[1.2411578 -3.0868912; -0.5612318 1.3269155; … ; 0.15480563 -0.47065437; -1.3156695 3.3637261])

# This is @ToucheSir's example from https://github.com/FluxML/Zygote.jl/issues/1262#issuecomment-1183950487

julia> gradient(Flux.params(D)) do
         βˆ‡xD, = gradient(x -> sum(D(x)), x)
         sum(βˆ‡xD)
       end
dx2 = 1.0f0
dx3 = (weight = nothing, bias = Float32[2.0;;], σ = nothing)
primal = Dense(5 => 1)
dx1 = Tangent{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Flo...}(weight = NoTangent(), bias = Float32[2.0;;], σ = NoTangent())
dx2 = Float32[1.0 1.0]
dx3 = (weight = Float32[1.2584376 0.5326019 0.5429492 0.08476076 1.4062281], bias = nothing, σ = nothing)
primal = Dense(5 => 1)
dx1 = Tangent{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Flo...}(weight = Float32[1.2584376 0.5326019 0.5429492 0.08476076 1.4062281], bias = NoTangent(), σ = NoTangent())
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/packages/Zygote/5GQsG/src/lib/lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/Zygote/5GQsG/src/lib/lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] map
    @ ./tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:36 [inlined]
 [13] #1841#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote ./compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote ./compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote ./compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[51]:2 [inlined]
 [20] (::typeof(∂(#45)))(Δ::Float32)
    @ Zygote ./compiler/interface2.jl:0
 [21] (::Zygote.var"#107#108"{Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(#45)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:357
 [22] gradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/5GQsG/src/compiler/interface.jl:76
 [23] top-level scope


# I thought the example from https://github.com/FluxML/Zygote.jl/issues/1262#issue-1302908406 failed,
# but right now it runs...

julia> using Statistics  # [edit: does not need StatsBase]

julia> function run_isolated_code_on(device_fn)
           D = Chain(Dense(5, 3, leakyrelu), Dense(3, 1)) |> device_fn  # Critic [edit: size was 50 => 32]
           w = Flux.params(D)  # [edit]
           x = rand(Float32, 5, 3) |> device_fn                   # Dummy minibatch
           βˆ‡wL = gradient(w) do
               βˆ‡xD, = gradient(x ->  sum(D(x)), x)                # The problematic line
               L = mean((sqrt.(sum(βˆ‡xD.^2, dims=1) .+ 1f-12) .- 1f0).^2)   # gradient penalty
           end
       end
run_isolated_code_on (generic function with 1 method)

julia> run_isolated_code_on(cpu).grads 
dx2 = 1.0f0
dx3 = (weight = nothing, bias = Float32[3.0;;], σ = nothing)
primal = Dense(3 => 1)
dx1 = Tangent{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(weight = NoTangent(), bias = Float32[3.0;;], σ = NoTangent())
dx2 = Float32[1.0 1.0 1.0]
dx3 = (weight = Float32[0.6786697 0.19200592 1.4292961], bias = nothing, σ = nothing)
primal = Dense(3 => 1)
dx1 = Tangent{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(weight = Float32[0.6786697 0.19200592 1.4292961], bias = NoTangent(), σ = NoTangent())
dx2 = Float32[-0.12526971 -0.12526971 -0.12526971; 0.92476416 0.92476416 0.92476416; 0.79502344 0.79502344 0.79502344]
dx3 = (weight = nothing, bias = Float32[-0.37580913; 0.9432594; 2.3850703;;], σ = nothing)
primal = Dense(5 => 3, leakyrelu)
dx1 = Tangent{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}(weight = NoTangent(), bias = Float32[-0.37580913; 0.9432594; 2.3850703;;], σ = NoTangent())
dx2 = Float32[-0.12526971 -0.12526971 -0.12526971; 0.009247641 0.009247641 0.92476416; 0.79502344 0.79502344 0.79502344]
dx3 = (weight = Float32[-0.20781982 -0.12835869 -0.25800195 -0.15170442 -0.18545659; 0.15666027 0.119531386 0.5309863 0.13393076 0.75968456; 1.3189272 0.81462765 1.6374078 0.96279114 1.1769991], bias = nothing, σ = nothing)
primal = Dense(5 => 3, leakyrelu)
dx1 = Tangent{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}(weight = Float32[-0.20781982 -0.12835869 -0.25800195 -0.15170442 -0.18545659; 0.15666027 0.119531386 0.5309863 0.13393076 0.75968456; 1.3189272 0.81462765 1.6374078 0.96279114 1.1769991], bias = NoTangent(), σ = NoTangent())
dx3 = (layers = ((weight = Float32[-0.20781982 -0.12835869 -0.25800195 -0.15170442 -0.18545659; 0.15666027 0.119531386 0.5309863 0.13393076 0.75968456; 1.3189272 0.81462765 1.6374078 0.96279114 1.1769991], bias = Float32[-0.37580913; 0.9432594; 2.3850703;;], σ = nothing), (weight = Float32[0.6786697 0.19200592 1.4292961], bias = Float32[3.0;;], σ = nothing)),)
primal = Chain(Dense(5 => 3, leakyrelu), Dense(3 => 1))
dx1 = Tangent{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{t...}(layers = Tangent{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(...}(Tangent{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}(weight = Float32[-0.20781982 -0.12835869 -0.25800195 -0.15170442 -0.18545659; 0.15666027 0.119531386 0.5309863 0.13393076 0.75968456; 1.3189272 0.81462765 1.6374078 0.96279114 1.1769991], bias = Float32[-0.37580913; 0.9432594; 2.3850703;;], Οƒ = NoTangent()), Tangent{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(weight = Float32[0.6786697 0.19200592 1.4292961], bias = Float32[3.0;;], Οƒ = NoTangent())),)
dx3 = (D = (layers = ((weight = Float32[-0.20781982 -0.12835869 -0.25800195 -0.15170442 -0.18545659; 0.15666027 0.119531386 0.5309863 0.13393076 0.75968456; 1.3189272 0.81462765 1.6374078 0.96279114 1.1769991], bias = Float32[-0.37580913; 0.9432594; 2.3850703;;], Οƒ = nothing), (weight = Float32[0.6786697 0.19200592 1.4292961], bias = Float32[3.0;;], Οƒ = nothing)),),)
primal = var"#50#52"{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}(Chain(Dense(5 => 3, leakyrelu), Dense(3 => 1)))
dx1 = Tangent{var"#50#52"{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float3...}(D = Tangent{Chain{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{t...}(layers = Tangent{Tuple{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(...}(Tangent{Dense{typeof(leakyrelu), Matrix{Float32}, Vector{Float32}}}(weight = Float32[-0.20781982 -0.12835869 -0.25800195 -0.15170442 -0.18545659; 0.15666027 0.119531386 0.5309863 0.13393076 0.75968456; 1.3189272 0.81462765 1.6374078 0.96279114 1.1769991], bias = Float32[-0.37580913; 0.9432594; 2.3850703;;], Οƒ = NoTangent()), Tangent{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(weight = Float32[0.6786697 0.19200592 1.4292961], bias = Float32[3.0;;], Οƒ = NoTangent())),),)
  2 dx = Float32[-0.12526971 -0.12526971 -0.12526971; 0.009247641 0.009247641 0.92476416; 0.79502344 0.79502344 0.79502344] => ddx = Float32[0.06965187 0.06965187 -0.0055287923; 0.21922433 0.21922433 -0.015013097; -0.18856493 -0.18856493 0.0010649858]
  2 dx = Float32[-0.12526971 -0.12526971 -0.12526971; 0.92476416 0.92476416 0.92476416; 0.79502344 0.79502344 0.79502344] => ddx = Float32[0.06965187 0.06965187 -0.0055287923; 0.0021922432 0.0021922432 -0.015013097; -0.18856493 -0.18856493 0.0010649858]
  2 dx = Float32[1.0 1.0 1.0] => ddx = Float32[-0.1566115 -0.1566115 -0.012344295]
  2 dx = 1.0 => ddx = -0.3255673
IdDict{Any, Any} with 5 entries:
  Float32[0.525176 0.0750574 … 0.… => Float32[-0.0153391 -0.0110689 … 0.0481348 0.00963335; 0.00363…
  Float32[0.0]                     => nothing
  Float32[-0.12527 0.924764 0.795… => Float32[0.133775 -0.0106286 -0.376065]
  Float32[0.0, 0.0, 0.0]           => Float32[0.0; 0.0; 0.0;;]
  Context(IdDict{Any, Any}())      => RefValue{Any}((cache = nothing,))

julia> run_isolated_code_on(gpu).grads
dx2 = 1.0f0
dx3 = (weight = nothing, bias = Float32[3.0;;], σ = nothing)
primal = Dense(3 => 1)
dx1 = Tangent{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Flo...}(weight = NoTangent(), bias = Float32[3.0;;], σ = NoTangent())
dx2 = Float32[1.0 1.0 1.0]
dx3 = (weight = Float32[-0.016965736 0.661252 1.6989294], bias = nothing, σ = nothing)
primal = Dense(3 => 1)
dx1 = Tangent{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Flo...}(weight = Float32[-0.016965736 0.661252 1.6989294], bias = NoTangent(), σ = NoTangent())
dx2 = Float32[-0.8001081 -0.8001081 -0.8001081; -0.07912698 -0.07912698 -0.07912698; -0.7474133 -0.7474133 -0.7474133]
dx3 = (weight = nothing, bias = Float32[-0.024003241; -0.15904522; -2.24224;;], σ = nothing)
primal = Dense(5 => 3, leakyrelu)
dx1 = Tangent{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Fl...}(weight = NoTangent(), bias = Float32[-0.024003241; -0.15904522; -2.24224;;], σ = NoTangent())
dx2 = Float32[-0.008001081 -0.008001081 -0.008001081; -0.07912698 -0.00079126976 -0.07912698; -0.7474133 -0.7474133 -0.7474133]
dx3 = (weight = Float32[-0.014291735 -0.005607192 -0.011424695 -0.0070731747 -0.009476184; -0.09679431 -0.049017284 -0.10978342 -0.02157648 -0.07517205; -1.3350488 -0.5237904 -1.0672269 -0.6607339 -0.88520867], bias = nothing, σ = nothing)
primal = Dense(5 => 3, leakyrelu)
dx1 = Tangent{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Fl...}(weight = Float32[-0.014291735 -0.005607192 -0.011424695 -0.0070731747 -0.009476184; -0.09679431 -0.049017284 -0.10978342 -0.02157648 -0.07517205; -1.3350488 -0.5237904 -1.0672269 -0.6607339 -0.88520867], bias = NoTangent(), σ = NoTangent())
dx3 = (layers = ((weight = Float32[-0.014291735 -0.005607192 -0.011424695 -0.0070731747 -0.009476184; -0.09679431 -0.049017284 -0.10978342 -0.02157648 -0.07517205; -1.3350488 -0.5237904 -1.0672269 -0.6607339 -0.88520867], bias = Float32[-0.024003241; -0.15904522; -2.24224;;], σ = nothing), (weight = Float32[-0.016965736 0.661252 1.6989294], bias = Float32[3.0;;], σ = nothing)),)
primal = Chain(Dense(5 => 3, leakyrelu), Dense(3 => 1))
dx1 = Tangent{Chain{Tuple{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}...}(layers = Tangent{Tuple{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuAr...}(Tangent{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Fl...}(weight = Float32[-0.014291735 -0.005607192 -0.011424695 -0.0070731747 -0.009476184; -0.09679431 -0.049017284 -0.10978342 -0.02157648 -0.07517205; -1.3350488 -0.5237904 -1.0672269 -0.6607339 -0.88520867], bias = Float32[-0.024003241; -0.15904522; -2.24224;;], Οƒ = NoTangent()), Tangent{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Flo...}(weight = Float32[-0.016965736 0.661252 1.6989294], bias = Float32[3.0;;], Οƒ = NoTangent())),)
dx3 = (D = (layers = ((weight = Float32[-0.014291735 -0.005607192 -0.011424695 -0.0070731747 -0.009476184; -0.09679431 -0.049017284 -0.10978342 -0.02157648 -0.07517205; -1.3350488 -0.5237904 -1.0672269 -0.6607339 -0.88520867], bias = Float32[-0.024003241; -0.15904522; -2.24224;;], Οƒ = nothing), (weight = Float32[-0.016965736 0.661252 1.6989294], bias = Float32[3.0;;], Οƒ = nothing)),),)
primal = var"#50#52"{Chain{Tuple{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}(Chain(Dense(5 => 3, leakyrelu), Dense(3 => 1)))
dx1 = Tangent{var"#50#52"{Chain{Tuple{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.D...}(D = Tangent{Chain{Tuple{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}...}(layers = Tangent{Tuple{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuAr...}(Tangent{Dense{typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Fl...}(weight = Float32[-0.014291735 -0.005607192 -0.011424695 -0.0070731747 -0.009476184; -0.09679431 -0.049017284 -0.10978342 -0.02157648 -0.07517205; -1.3350488 -0.5237904 -1.0672269 -0.6607339 -0.88520867], bias = Float32[-0.024003241; -0.15904522; -2.24224;;], Οƒ = NoTangent()), Tangent{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Flo...}(weight = Float32[-0.016965736 0.661252 1.6989294], bias = Float32[3.0;;], Οƒ = NoTangent())),),)
  2 dx = Float32[-0.008001081 -0.008001081 -0.008001081; -0.07912698 -0.00079126976 -0.07912698; -0.7474133 -0.7474133 -0.7474133] => ddx = Float32[-0.17363425 -0.1654667 -0.17363425; 0.05089833 -0.0006205989 0.05089833; 0.21575797 0.22180276 0.21575797]
  2 dx = Float32[-0.8001081 -0.8001081 -0.8001081; -0.07912698 -0.07912698 -0.07912698; -0.7474133 -0.7474133 -0.7474133] => ddx = Float32[-0.0017363424 -0.001654667 -0.0017363424; 0.05089833 -6.205989f-6 0.05089833; 0.21575797 0.22180276 0.21575797]
  2 dx = Float32[1.0 1.0 1.0] => ddx = Float32[-0.16389853 -0.16445392 -0.16389853]
  2 dx = 1.0 => ddx = -0.49225098
IdDict{Any, Any} with 5 entries:
  Context(IdDict{Any, Any}())      => RefValue{Any}((cache = nothing,))
  Float32[0.0, 0.0, 0.0]           => Float32[0.0; 0.0; 0.0;;]
  Float32[-0.475968 -0.167645 … 0… => Float32[-0.00354436 -0.00331468 … -0.00409469 -0.00290958; -0…
  Float32[-0.800108 -0.079127 -0.… => Float32[-0.00512735 0.10179 0.653319]
  Float32[0.0]                     => nothing

mcabbott avatar Jul 24 '22 21:07 mcabbott