Zygote.jl icon indicating copy to clipboard operation
Zygote.jl copied to clipboard

Gradient fails to deal with GPU movements

Open renatobellotti opened this issue 3 years ago • 8 comments

Hi,

I'm trying to calculate a "masked minimum" of an array, i.e. take the minimum among all values fulfilling a certain condition. On the CPU, this works:

using Zygote


function masked_minimum(x::AbstractArray{T, N}, mask::AbstractArray{T, N}) where {T, N}
    # Boolean selector: true wherever the mask entry equals `one(T)`.
    selected = mask .== one(T)
    # Logical indexing keeps only the selected entries; reduce those to their minimum.
    return minimum(x[selected])
end

# 200×100 random sparse matrix (`sprand` comes from SparseArrays, which must be loaded).
A = sprand(200, 100, 0.8)
x = rand(100)
# Random mask of 0.0/1.0 values; `masked_minimum` keeps the entries equal to one.
mask = round.(rand(200))

# Forward evaluation, then the gradient of the masked minimum with respect to `x`.
masked_minimum(A * x, mask);
Zygote.gradient(x -> masked_minimum(A * x, mask), x)[1];

However, on the GPU it does not work:

function masked_minimum(x::AbstractArray{T, N}, mask::AbstractArray{T, N}) where {T, N}
    # Bring both inputs over to plain `Array`s (CPU) so the indexing below is fast.
    host_mask = Array(mask)
    host_x = Array(x)
    keep = host_mask .== one(T)
    return minimum(host_x[keep])
end

# Same setup as the CPU example, but every array is passed through `cu`.
A = cu(sprand(200, 100, 0.8))
x = cu(rand(100))
mask = cu(round.(rand(200)))

# Forward evaluation first; the `Zygote.gradient` call is the one reported as failing.
masked_minimum(A * x, mask);
Zygote.gradient(x -> masked_minimum(A * x, mask), x)[1]

Error message:

GPU compilation of kernel #broadcast_kernel#17(CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{Float32} which is not isbits.



Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/validation.jl:88
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:417 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/jgSVI/src/TimerOutput.jl:252 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:416 [inlined]
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/utils.jl:64
  [6] cufunction_compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:354
  [7] #224
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:347 [inlined]
  [8] JuliaContext(f::CUDA.var"#224#225"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#17", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:76
  [9] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:346
 [10] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/cache.jl:90
 [11] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:299
 [12] cufunction
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:293 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:102 [inlined]
 [14] #launch_heuristic#248
    @ ~/.julia/packages/CUDA/DfvRa/src/gpuarrays.jl:17 [inlined]
 [15] _copyto!
    @ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:63 [inlined]
 [16] copyto!
    @ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:46 [inlined]
 [17] copy
    @ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:37 [inlined]
 [18] materialize
    @ ./broadcast.jl:860 [inlined]
 [19] broadcast(::typeof(*), ::Vector{Float32}, ::Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
    @ Base.Broadcast ./broadcast.jl:798
 [20] *
    @ <redacted>/julia-1.7.3/share/julia/stdlib/v1.7/LinearAlgebra/src/adjtrans.jl:297 [inlined]
 [21] #1404
    @ ~/.julia/packages/ChainRules/EyLkg/src/rulesets/Base/arraymath.jl:36 [inlined]
 [22] unthunk
    @ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:199 [inlined]
 [23] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:104 [inlined]
 [24] map
    @ ./tuple.jl:223 [inlined]
 [25] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:105 [inlined]
 [26] ZBack
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:205 [inlined]
 [27] Pullback
    @ ./In[281]:13 [inlined]
 [28] (::typeof(∂(#351)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [29] (::Zygote.var"#60#61"{typeof(∂(#351))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:41
 [30] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:76
 [31] top-level scope
    @ In[281]:13
 [32] eval
    @ ./boot.jl:373 [inlined]
 [33] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1196

Is this a bug or am I doing something wrong?

renatobellotti avatar Sep 13 '22 19:09 renatobellotti

Neither Zygote nor ChainRules defines a rule for the Array constructor on GPU arrays. However, you could copy the one Flux defines for this purpose (note: this is piracy). I do wonder if there's not a way to express your masked minimum without moving everything onto the CPU, however. minimum at least should work on CuArrays, as does x[mask .== one(T)].

ToucheSir avatar Sep 13 '22 23:09 ToucheSir

Neither Zygote nor ChainRules defines a rule for the Array constructor on GPU arrays. However, you could copy the one Flux defines for this purpose (note: this is piracy).

Thanks for the fast answer!

What do you mean by "piracy"? If I add a comment with the origin and license of the code, it should be allowed, since Flux is licensed under the MIT license. I'm not a lawyer, though.

I do wonder if there's not a way to express your masked minimum without moving everything onto the CPU, however. minimum at least should work on CuArrays, as does x[mask .== one(T)].

Good point. The indexing operations are very slow on the GPU, especially since my x is typically of size approx. 11'000. On top of that, x is the result of other operations that must happen on the GPU for performance reasons.

I will try the Flux rule.

renatobellotti avatar Sep 14 '22 06:09 renatobellotti

Unfortunately, adding the new rrule does not resolve the problem:

# The following function is taken from (2022-09-14) Flux.jl, which is licensed under the MIT "Expat" license:
# https://github.com/FluxML/Flux.jl/blob/v0.13.6/src/functor.jl#L121-L123

# Copyright (c) 2016-19: Julia Computing, INc., Mike Innes and Contributors
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

using Zygote, CUDA, ChainRulesCore, SparseArrays  # [Edit: added to make the example run]

function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
  # Pullback: pass the (unthunked) cotangent through `CUDA.cu`; the `Array`
  # type argument itself gets no tangent.
  # NOTE(review): `cu` may change the element type of `dx` — confirm this is intended.
  function array_pullback(dx)
    return (NoTangent(), CUDA.cu(unthunk(dx)),)
  end
  # Primal result is the plain-`Array` copy of `x`.
  return Array(x), array_pullback
end


function masked_minimum(x::AbstractArray{T, N}, mask::AbstractArray{T, N}) where {T, N}
    # Convert to plain `Array`s up front (CPU copies) to allow for fast indexing.
    mask_cpu = Array(mask)
    x_cpu = Array(x)
    is_selected = mask_cpu .== one(T)
    return minimum(x_cpu[is_selected])
end

# Same GPU setup as before: every array is passed through `cu`.
A = cu(sprand(200, 100, 0.8))
x = cu(rand(100))
mask = cu(round.(rand(200)))

# Forward evaluation first; the `Zygote.gradient` call still fails with the rrule defined.
masked_minimum(A * x, mask);
Zygote.gradient(x -> masked_minimum(A * x, mask), x)[1]
GPU compilation of kernel #broadcast_kernel#17(CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{Float32} which is not isbits.



Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/validation.jl:88
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:417 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/jgSVI/src/TimerOutput.jl:252 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:416 [inlined]
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/utils.jl:64
  [6] cufunction_compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:354
  [7] #224
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:347 [inlined]
  [8] JuliaContext(f::CUDA.var"#224#225"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#17", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/driver.jl:76
  [9] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:346
 [10] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/N98un/src/cache.jl:90
 [11] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Adjoint{Float32, CuDeviceVector{Float32, 1}}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:299
 [12] cufunction
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:293 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:102 [inlined]
 [14] #launch_heuristic#248
    @ ~/.julia/packages/CUDA/DfvRa/src/gpuarrays.jl:17 [inlined]
 [15] _copyto!
    @ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:63 [inlined]
 [16] copyto!
    @ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:46 [inlined]
 [17] copy
    @ ~/.julia/packages/GPUArrays/Hyss4/src/host/broadcast.jl:37 [inlined]
 [18] materialize
    @ ./broadcast.jl:860 [inlined]
 [19] broadcast(::typeof(*), ::Vector{Float32}, ::Adjoint{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
    @ Base.Broadcast ./broadcast.jl:798
 [20] *
    @ <redacted>/julia-1.7.3/share/julia/stdlib/v1.7/LinearAlgebra/src/adjtrans.jl:297 [inlined]
 [21] #1404
    @ ~/.julia/packages/ChainRules/EyLkg/src/rulesets/Base/arraymath.jl:36 [inlined]
 [22] unthunk
    @ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:199 [inlined]
 [23] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:104 [inlined]
 [24] map
    @ ./tuple.jl:223 [inlined]
 [25] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:105 [inlined]
 [26] ZBack
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:205 [inlined]
 [27] Pullback
    @ ./In[40]:28 [inlined]
 [28] (::typeof(∂(#39)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [29] (::Zygote.var"#60#61"{typeof(∂(#39))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:41
 [30] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:76
 [31] top-level scope
    @ In[40]:28
 [32] eval
    @ ./boot.jl:373 [inlined]
 [33] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1196

renatobellotti avatar Sep 14 '22 07:09 renatobellotti

Here's one way to avoid indexing:

julia> masked_minimum(A * x, mask)
13.674786f0

julia> masked_minimum_2(val, mask) = mapreduce(min, val, mask) do x, y
           ifelse(isone(y), x, typemax(typeof(x)))
       end;

julia> masked_minimum_2(A * x, mask)  # works but not with Zygote
13.674786f0

julia> function masked_minimum_3(val, mask) 
         tmp = broadcast(val, mask) do x, y
           ifelse(isone(y), x, typemax(typeof(x)))
         end
         minimum(tmp)
       end;

julia> masked_minimum_3(A * x, mask)
13.674786f0

julia> Zygote.gradient(x -> masked_minimum_3(A*x, mask), x)[1] |> summary
"100-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}"

Note BTW that Julia has adopted the word "piracy" to mean one package adding methods to a function it doesn't own, for types it doesn't own. This means that loading the package may alter the behaviour of apparently unrelated code, hence it's regarded as a bad thing. (At least in published packages, at home you can do whatever you like.)

mcabbott avatar Sep 14 '22 11:09 mcabbott

It isn't calling the rrule you defined. And I think the reason is this line https://github.com/FluxML/Zygote.jl/blob/de078c84ce0a1ee517e9e929f0bb6b97b697e23e/src/lib/array.jl#L6 since Zygote's own @adjoint rules take precedence:

julia> function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
         println("Array rrule")
         Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
       end;

julia> function masked_minimum(x::AbstractArray{T, N}, mask::AbstractArray{T, N}) where {T, N}
           println("masked_minimum")
           mask_cpu = Array(mask)
           x_cpu = Array(x)
           x_masked = x_cpu[mask_cpu .== one(T)]
           return minimum(x_masked)
       end;

julia> Zygote.gradient(y -> masked_minimum(y, mask), A*x)[1] |> summary
masked_minimum
"200-element Vector{Float32}"

julia> Zygote.@adjoint Array(xs::CUDA.CuArray) = Array(xs), ȳ -> (CUDA.cu(@show ȳ),);

julia> Zygote.refresh()

julia> Zygote.gradient(y -> masked_minimum(y, mask), A*x)[1] |> summary
masked_minimum
ȳ = Float32[0.0, 0.0, ... 0.0]
"200-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}"

That line is simply wrong, and should be fixed.

mcabbott avatar Sep 14 '22 14:09 mcabbott

Thank you very much for the suggestions and the explanations about "piracy"; I didn't know that!

A question about how to debug Zygote code in general: How did you find out that the rrule is never called? Did you add a print statement to it, or did you infer this from the error message?

renatobellotti avatar Sep 15 '22 07:09 renatobellotti

Sometimes there are hints in the stack trace, but often printing is the best way to know what's going on. (The fancier @info and friends don't play well with Zygote, BTW.)

mcabbott avatar Sep 15 '22 18:09 mcabbott

Sometimes there are hints in the stack trace, but often printing is the best way to know what's going on. (The fancier @info and friends don't play well with Zygote, BTW.)

Ok, thanks!

renatobellotti avatar Sep 21 '22 07:09 renatobellotti