Zygote.jl
`sum` with CUDA and view on array errors
Hi,
I wanted to compute a loss allocation-free by taking a view of the array and applying the loss to it. Oddly, even the forward pass allocates, and the gradient errors under CUDA:
```julia
julia> using CUDA, NNlib
julia> x = randn((4096, 4096));
julia> varinfo(r"x")
name size        summary
–––– ––––––––––– –––––––––––––––––––––––––
x    128.000 MiB 4096×4096 Matrix{Float64}
julia> iso = rand(Bool, 4096, 4096);
julia> f(x, isobject) = sum(x -> abs2(NNlib.relu(1 - x)), view(x, isobject))
f (generic function with 1 method)
julia> @time f(x, iso)
0.136329 seconds (62.44 k allocations: 68.321 MiB, 20.46% gc time, 20.26% compilation time)
1.6138125594596338e7
julia> @time f(x, iso)
0.080497 seconds (5 allocations: 63.989 MiB)
1.6138125594596338e7
julia> xc = CuArray(x)
4096×4096 CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}:
julia> isoc = CuArray(iso)
4096×4096 CuArray{Bool, 2, CUDA.Mem.DeviceBuffer}:
julia> @time f(xc, isoc)
11.031044 seconds (21.32 M allocations: 1.447 GiB, 1.89% gc time, 95.75% compilation time: 2% of which was recompilation)
1.6138125594597608e7
julia> @time f(xc, isoc)
0.005610 seconds (587 allocations: 27.328 KiB)
1.6138125594597608e7
julia> using Zygote
julia> Zygote.gradient(x -> f(x, iso), x)
([-0.7454574712706954 0.0 … 0.0 0.0; 0.0 -2.2224158864106833 … 0.0 0.0; … ; 0.0 -2.1170359064329225 … 0.0 -1.093111758301367; -0.9269511276554552 -2.56023163044169 … 0.0 -1.8369961811609032],)
julia> Zygote.gradient(x -> f(x, isoc), xc)
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] errorscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
[3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
[4] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
[5] getindex(A::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
@ GPUArrays ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:48 [inlined]
[6] reindex
@ ./subarray.jl:266 [inlined]
[7] getindex
@ ./subarray.jl:288 [inlined]
[8] iterate
@ ./abstractarray.jl:1214 [inlined]
[9] iterate
@ ./abstractarray.jl:1212 [inlined]
[10] iterate(::Base.Generator{Vector{Any}, IRTools.Inner.var"#52#53"{IRTools.Inner.var"#54#55"{IRTools.Inner.Block}}})
@ Base ./generator.jl:44 [inlined]
[11] collect(itr::Base.Generator{SubArray{Float64, 1, CuArray{…}, Tuple{…}, false}, ChainRules.var"#1636#1641"{Zygote.ZygoteRuleConfig{…}, var"#3#4"}})
@ Base ./array.jl:834
[12] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{…}}, ::typeof(sum), f::var"#3#4", xs::SubArray{Float64, 1, CuArray{…}, Tuple{…}, false}; dims::Function)
@ ChainRules ~/.julia/packages/ChainRules/vaXA9/src/rulesets/Base/mapreduce.jl:102
[13] rrule
@ ChainRules ~/.julia/packages/ChainRules/vaXA9/src/rulesets/Base/mapreduce.jl:76 [inlined]
[14] chain_rrule
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:223 [inlined]
[15] macro expansion
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0 [inlined]
[16] _pullback
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:81 [inlined]
[17] f
@ ./REPL[13]:1 [inlined]
[18] _pullback(::Zygote.Context{false}, ::typeof(f), ::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Bool, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[19] #7
@ ./REPL[21]:1 [inlined]
[20] _pullback(ctx::Zygote.Context{false}, f::var"#7#8", args::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[21] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
[22] pullback(f::Any, cx::ZygoteRules.AContext, args::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
[23] gradient(f::Function, args::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
[24] top-level scope
@ REPL[21]:1
[25] top-level scope
@ ~/.julia/packages/CUDA/htRwP/src/initialization.jl:206
Some type information was truncated. Use `show(err)` to see complete types.
```
Environment:
```julia
(@main) pkg> st
Status `~/.julia/environments/main/Project.toml`
[6e4b80f9] BenchmarkTools v1.4.0
[052768ef] CUDA v5.2.0
[d360d2e6] ChainRulesCore v1.20.1
[7a1cc6ca] FFTW v1.8.0
[872c559c] NNlib v0.9.11
[e88e6eb3] Zygote v0.6.69
julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 24 × AMD Ryzen 9 5900X 12-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 39 on 24 virtual cores
Environment:
JULIA_NUM_THREADS = 24
julia> CUDA.versioninfo()
CUDA runtime 12.3, artifact installation
CUDA driver 12.3
NVIDIA driver 545.23.8
CUDA libraries:
- CUBLAS: 12.3.4
- CURAND: 10.3.4
- CUFFT: 11.0.12
- CUSOLVER: 11.5.4
- CUSPARSE: 12.2.0
- CUPTI: 21.0.0
- NVML: 12.0.0+545.23.8
Julia packages:
- CUDA: 5.2.0
- CUDA_Driver_jll: 0.7.0+1
- CUDA_Runtime_jll: 0.11.1+0
Toolchain:
- Julia: 1.10.0
- LLVM: 15.0.7
1 device:
0: NVIDIA GeForce RTX 3060 (sm_86, 10.358 GiB / 12.000 GiB available)
```
That's because Zygote has rules for `sum(::Function, ...)` on contiguous GPU arrays, but not for views or other wrappers: https://github.com/FluxML/Zygote.jl/blob/c0daccded5b9f91d31ceb889e4a97e74dd722a4e/src/lib/broadcast.jl#L374-L384
Somebody would have to figure out how to make those rules work for `SubArray`s.
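For reference, here's a hypothetical, untested sketch of what that might look like: widening the GPU-specific adjoint from `AbstractGPUArray` to GPUArrays' `AnyGPUArray` union, which also matches wrappers such as `SubArray` whose parent is a GPU array. This is an assumption about the shape of a fix, not the actual rule, and it presumes that broadcasting `f.(xs)` over the wrapped array works on the GPU:

```julia
using Zygote
using GPUArrays: AnyGPUArray  # union of AbstractGPUArray and GPU-backed wrappers

# Hypothetical widening of Zygote's existing GPU rule (linked above).
# Untested: broadcasting over a logically-indexed SubArray may itself
# not be GPU-compatible.
Zygote.@adjoint function sum(f, xs::AnyGPUArray; kws...)
  @assert !haskey(kws, :init)  # same restriction as the existing rule
  res, back = Zygote.pullback((f, xs) -> sum(f.(xs); kws...), f, xs)
  return res, Δ -> back(Δ)
end
```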
Alternatively, you could try splitting the operations up to hit more advantageous rules:
```julia
f(x, isobject) = sum(abs2, NNlib.relu.(1 .- view(x, isobject)))
```
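Another option along the same lines (an untested assumption on my part, and `f_masked` is a hypothetical name) is to skip the view entirely and apply the `Bool` mask inside a fused broadcast, so no `SubArray` ever reaches the sum rule:

```julia
using CUDA, NNlib, Zygote

# Masking instead of viewing: the broadcast fuses into one GPU kernel
# followed by a reduction. Note it traverses the full array, including
# the masked-out entries, which are multiplied by zero.
f_masked(x, isobject) = sum(isobject .* abs2.(NNlib.relu.(1 .- x)))

# Usage sketch with the inputs from above:
# Zygote.gradient(x -> f_masked(x, isoc), xc)
```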
Yeah, your function works, but it allocates another full array, which I wanted to avoid.
It also looks like the `_pullback` allocates, because of the `f.(xs)`:
```julia
res, back = _pullback(cx, (f, xs) -> sum(f.(xs); kws...), f, xs)
```
More than one full array, in fact. If the goal is to reduce allocations from unfused operations, you could keep the original code but turn the view into a normal indexing operation instead. That may or may not be faster, but it allows using the manually fused `x -> abs2(NNlib.relu(1 - x))`.
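A minimal sketch of that last suggestion, assuming plain logical indexing (`x[isobject]`) in place of the view; `f_indexed` is a hypothetical name, and whether its pullback stays entirely on the GPU is not verified here:

```julia
using CUDA, NNlib, Zygote

# Logical indexing materializes x[isobject] as a dense array, so the
# manually fused loss hits the GPU-friendly sum(f, ::CuArray) rule.
# This costs one allocation for the gathered elements.
f_indexed(x, isobject) = sum(x -> abs2(NNlib.relu(1 - x)), x[isobject])

# Usage sketch with the inputs from above:
# Zygote.gradient(x -> f_indexed(x, isoc), xc)
```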