Zygote.jl
Moving a BitArray to the GPU fails inside `Zygote.gradient`.
When I move a BitArray to the GPU from within a gradient call, I get the following error:
using CUDA, Zygote
"""
    f(a, b)

Return the sum over all elements of the broadcast product `a .* b`.
"""
f(a, b) = sum(a .* b)
# Differentiated input: moved to the GPU with `cu` (the trace below shows CuArray{Float32,3}).
a = rand(10, 10, 10)|>cu
# Broadcast comparison `.>` produces a BitArray (packed Bool storage), not an Array{Bool}.
b = rand(10, 10, 10) .> 0.5 #BitArray
# Zygote differentiates through `cu(b)`, i.e. `convert(CuArray, ::BitArray)`, and fails
# with "try/catch is not supported" (see the error below). A plain Array does not fail.
Zygote.gradient(a) do a
f(a, cu(b))
end
ERROR: Compiling Tuple{typeof(GPUArrays.indexstyle),BitArray{3}}: try/catch is not supported.
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] instrument(::IRTools.Inner.IR) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\reverse.jl:89
[3] #Primal#20 at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\reverse.jl:170 [inlined]
[4] Zygote.Adjoint(::IRTools.Inner.IR; varargs::Nothing, normalise::Bool) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\reverse.jl:283
[5] _lookup_grad(::Type{T} where T) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\emit.jl:101
[6] #s2789#1235 at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface2.jl:39 [inlined]
[7] #s2789#1235(::Any, ::Any, ::Any) at .\none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at .\boot.jl:527
[9] convert at C:\Users\jules\.julia\packages\GPUArrays\eVYIC\src\host\construction.jl:75 [inlined]
[10] (::typeof(∂(convert)))(::CuArray{Float32,3}) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface2.jl:0
[11] adapt_storage at C:\Users\jules\.julia\packages\CUDA\dZvbp\src\array.jl:341 [inlined]
[12] (::typeof(∂(adapt_storage)))(::CuArray{Float32,3}) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface2.jl:0
[13] adapt_structure at C:\Users\jules\.julia\packages\Adapt\zeelH\src\Adapt.jl:42 [inlined]
[14] adapt at C:\Users\jules\.julia\packages\Adapt\zeelH\src\Adapt.jl:40 [inlined]
[15] cu at C:\Users\jules\.julia\packages\CUDA\dZvbp\src\array.jl:347 [inlined]
[16] (::typeof(∂(cu)))(::CuArray{Float32,3}) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface2.jl:0
[17] #11 at .\REPL[1]:2 [inlined]
[18] (::typeof(∂(#11)))(::Float32) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface2.jl:0
[19] (::Zygote.var"#41#42"{typeof(∂(#11))})(::Float32) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface.jl:45
[20] gradient(::Function, ::CuArray{Float32,3}) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface.jl:54
[21] top-level scope at REPL[1]:1
I'm not sure why this happens, but it doesn't happen with a regular array, so I'm assuming this is a bug. Thanks!
Any updates on this?
The error looks different now; it comes from the call stack of `convert(CuArray, ::BitArray)`:
julia> Zygote.gradient(x -> sum(cu(x)), ones(Bool, 1))
(Bool[1],)
julia> Zygote.gradient(x -> sum(cu(x)), BitArray(ones(1)))
ERROR: MethodError: no method matching iterate(::ErrorException)
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
...
Stacktrace:
[1] indexed_iterate(I::ErrorException, i::Int64)
@ Base ./tuple.jl:89
[2] #s3010#1217
@ ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:28 [inlined]
[3] var"#s3010#1217"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[4] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[5] _pullback
@ ~/.julia/packages/GPUArrays/8dzSJ/src/host/construction.jl:83 [inlined]
[6] _pullback(::Zygote.Context, ::typeof(convert), ::Type{CuArray}, ::BitVector)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
[7] _pullback
@ ~/.julia/packages/CUDA/fRSUT/src/array.jl:388 [inlined]
[8] _pullback
@ ~/.julia/packages/Adapt/RGNRk/src/Adapt.jl:42 [inlined]
[9] _pullback
@ ~/.julia/packages/Adapt/RGNRk/src/Adapt.jl:40 [inlined]
[10] _pullback
@ ~/.julia/packages/CUDA/fRSUT/src/array.jl:403 [inlined]
[11] _pullback
@ ./REPL[50]:1 [inlined]
[12] _pullback(ctx::Zygote.Context, f::var"#31#32", args::BitVector)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
[13] _pullback(f::Function, args::BitVector)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:34
[14] pullback(f::Function, args::BitVector)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:40
[15] gradient(f::Function, args::BitVector)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:58
[16] top-level scope
@ REPL[50]:1
[17] top-level scope
@ ~/.julia/packages/CUDA/fRSUT/src/initialization.jl:52
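Looking at the adjoint code, GPUArrays handles the dense `Array{Bool}` conversion with a plain `similar` + `copyto!`. The `BitArray` conversion instead goes through the generic iterator-based path: `IteratorSize`/`indexstyle`/`IteratorEltype` checks, then either `gpu_call(collect_kernel, …)` or `collect` followed by `convert`.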
julia> @code_adjoint convert(CuArray, ones(Bool, 1))
Zygote.Adjoint(1: (%4, %5 :: Zygote.Context, %1, %2, %3)
%6 = Core.apply_type(%2, $(QuoteNode(Bool)))
%7 = Zygote._pullback(%5, GPUArrays.size, %3)
%8 = Base.getindex(%7, 1)
%9 = Base.getindex(%7, 2)
%10 = Zygote._pullback(%5, GPUArrays.similar, %6, %8)
%11 = Base.getindex(%10, 1)
%12 = Base.getindex(%10, 2)
%13 = Zygote._pullback(%5, GPUArrays.copyto!, %11, %3)
%14 = Base.getindex(%13, 1)
%15 = Base.getindex(%13, 2)
return %14, 1: (%1)
%2 = (@15)(%1)
%3 = Zygote.gradindex(%2, 2)
%4 = Zygote.gradindex(%2, 3)
%5 = (@12)(%3)
%6 = Zygote.gradindex(%5, 2)
%7 = Zygote.gradindex(%5, 3)
%8 = (@9)(%7)
%9 = Zygote.gradindex(%8, 2)
%10 = Zygote.accum(%4, %9)
%11 = Zygote.tuple(nothing, nothing, %10)
return %11)
julia> @code_adjoint convert(CuArray, BitArray(ones(1)))
Zygote.Adjoint(1: (%4, %5 :: Zygote.Context, %1, %2, %3)
%6 = Base.IteratorSize
%7 = Zygote._pullback(%5, %6, %3)
%8 = Base.getindex(%7, 1)
%9 = Base.getindex(%7, 2)
%10 = Zygote._pullback(%5, GPUArrays.indexstyle, %3)
%11 = Base.getindex(%10, 1)
%12 = Base.getindex(%10, 2)
%13 = Base.IteratorEltype
%14 = Zygote._pullback(%5, %13, %3)
%15 = Base.getindex(%14, 1)
%16 = Base.getindex(%14, 2)
%17 = Zygote._pullback(%5, GPUArrays.isbits, %3)
%18 = Base.getindex(%17, 1)
%19 = Base.getindex(%17, 2)
br 6 (1) unless %18
br 2
2:
%20 = Base.HasShape
%21 = Zygote._pullback(%5, GPUArrays.isa, %8, %20)
%22 = Base.getindex(%21, 1)
%23 = Base.getindex(%21, 2)
br 6 (2) unless %22
br 3
3:
%24 = Zygote._pullback(%5, GPUArrays.:!=, %11, GPUArrays.nothing)
%25 = Base.getindex(%24, 1)
%26 = Base.getindex(%24, 2)
br 6 (3) unless %25
br 4
4:
%27 = Base.HasEltype
%28 = Zygote._pullback(%5, GPUArrays.isa, %15, %27)
%29 = Base.getindex(%28, 1)
%30 = Base.getindex(%28, 2)
br 6 (4) unless %29
br 5
5:
%31 = Zygote._pullback(%5, GPUArrays.eltype, %3)
%32 = Base.getindex(%31, 1)
%33 = Base.getindex(%31, 2)
%34 = Zygote._pullback(%5, GPUArrays.eltype_or, %2, %32)
%35 = Base.getindex(%34, 1)
%36 = Base.getindex(%34, 2)
%37 = Zygote._pullback(%5, GPUArrays.size, %3)
%38 = Base.getindex(%37, 1)
%39 = Base.getindex(%37, 2)
%40 = Zygote._pullback(%5, GPUArrays.similar, %2, %35, %38)
%41 = Base.getindex(%40, 1)
%42 = Base.getindex(%40, 2)
%43 = Zygote._pullback(%5, GPUArrays.gpu_call, GPUArrays.collect_kernel, %41, %3, %11)
%44 = Base.getindex(%43, 1)
%45 = Base.getindex(%43, 2)
br 7 (%41, 1)
6: (%53 :: UInt8)
%46 = Zygote._pullback(%5, GPUArrays.collect, %3)
%47 = Base.getindex(%46, 1)
%48 = Base.getindex(%46, 2)
%49 = Zygote._pullback(%5, GPUArrays.convert, %2, %47)
%50 = Base.getindex(%49, 1)
%51 = Base.getindex(%49, 2)
br 7 (%50, 2)
7: (%52, %54 :: UInt8)
return %52, 1: (%1)
%2 = @54 !== 0x01
br 3 unless %2
br 2
2:
%3 = @53 !== 0x01
%4 = @53 !== 0x02
%5 = @53 !== 0x03
%6 = (@51)(%1)
%7 = Zygote.gradindex(%6, 2)
%8 = Zygote.gradindex(%6, 3)
%9 = (@48)(%8)
%10 = Zygote.gradindex(%9, 2)
br 7 (%7, %10, nothing, nothing, nothing) unless %3
br 6 (%7, %10, nothing, nothing) unless %4
br 5 (%7, %10, nothing, nothing) unless %5
br 4 (%7, %10, nothing)
3:
%11 = (@45)(nothing)
%12 = Zygote.gradindex(%11, 3)
%13 = Zygote.gradindex(%11, 4)
%14 = Zygote.gradindex(%11, 5)
%15 = Zygote.accum(%1, %12)
%16 = (@42)(%15)
%17 = Zygote.gradindex(%16, 2)
%18 = Zygote.gradindex(%16, 3)
%19 = Zygote.gradindex(%16, 4)
%20 = (@39)(%19)
%21 = Zygote.gradindex(%20, 2)
%22 = (@36)(%18)
%23 = Zygote.gradindex(%22, 2)
%24 = Zygote.gradindex(%22, 3)
%25 = (@33)(%24)
%26 = Zygote.gradindex(%25, 2)
%27 = Zygote.accum(%17, %23)
%28 = Zygote.accum(%13, %21, %26)
br 4 (%27, %28, %14)
4: (%29, %30, %31)
%32 = (@30)(nothing)
%33 = Zygote.gradindex(%32, 2)
%34 = Zygote.gradindex(%32, 3)
br 5 (%29, %30, %33, %31)
5: (%35, %36, %37, %38)
%39 = (@26)(nothing)
%40 = Zygote.gradindex(%39, 2)
%41 = Zygote.accum(%38, %40)
br 6 (%35, %36, %37, %41)
6: (%42, %43, %44, %45)
%46 = (@23)(nothing)
%47 = Zygote.gradindex(%46, 2)
%48 = Zygote.gradindex(%46, 3)
br 7 (%42, %43, %47, %44, %45)
7: (%49, %50, %51, %52, %53)
%54 = (@19)(nothing)
%55 = Zygote.gradindex(%54, 2)
%56 = (@16)(%52)
%57 = Zygote.gradindex(%56, 1)
%58 = Zygote.gradindex(%56, 2)
%59 = (@12)(%53)
%60 = Zygote.gradindex(%59, 2)
%61 = (@9)(%51)
%62 = Zygote.gradindex(%61, 1)
%63 = Zygote.gradindex(%61, 2)
%64 = Zygote.accum(%50, %55, %58, %60, %63)
%65 = Zygote.tuple(nothing, %49, %64)
return %65)
To give you an update: this has been fixed upstream in Flux, but we ought to port over the improved implementation from https://github.com/FluxML/Flux.jl/pull/1708. In the meantime, if you happen to be using Flux anyway, `gpu` and `cpu` should work as expected.