Zygote.jl
Device-to-host copies with GPU code
Currently any code that reduces GPU arrays to a single scalar value (like sum) does a device-to-host copy of a single element at the end, in order to return ::Number instead of ::GPUArray.
Each such transfer causes a GPU synchronization, which kills performance because we can no longer feed the GPU fast enough.
Below is an MWE for Zygote.unbroadcast; in more complex code, such as Flux models, I see up to 40 such transfers in a single forward+backward pass.
One way to fix this (which is what Python GPU frameworks do) is to return ::GPUArray instead of ::Number when we reduce to a single value. This breaks things like projection, although I'd be in favor of it.
Another is to go through every use case and update it to return a ::GPUArray with sum(x; dims=1:ndims(x)) or similar.
julia> using AMDGPU, Zygote
julia> x = AMDGPU.rand(Float32, 16);
julia> Δ = ROCArray(ones(Float32, 1));
julia> y, back = Zygote.pullback(x) do x
           sum(1f0 .- x; dims=1)
       end;
julia> back(Δ);
[D to H] ROCArray{Float32, 1, AMDGPU.Runtime.Mem.HIPBuffer}: (1,) -> Vector{Float32}: (1,)
# Calling `Zygote.gradient` does one more device-to-host transfer than `Zygote.pullback` due to `sum`.
julia> ∇ = Zygote.gradient(x) do x
           sum(1f0 .- x)
       end;
[D to H] ROCArray{Float32, 1, AMDGPU.Runtime.Mem.HIPBuffer}: (1,) -> Vector{Float32}: (1,)
[D to H] ROCArray{Float32, 1, AMDGPU.Runtime.Mem.HIPBuffer}: (1,) -> Vector{Float32}: (1,)
For bigger Flux.jl models, these synchronizations may cost up to a second, and since they all happen under the hood, the user cannot easily mitigate them.
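A CPU-only sketch of the second option: reduce over every dimension with dims=1:ndims(x) so the result stays a 1-element array instead of a Number. The helper names sum_keepdims and mse_keepdims are hypothetical, and plain Arrays stand in for GPU arrays; on a ROCArray or CuArray, only the sum(x) call would need the host copy.

```julia
# Hypothetical helper: reduce over all dimensions but keep the result as a
# 1-element array of size (1, 1, ...) instead of returning a Number.
sum_keepdims(x::AbstractArray) = sum(x; dims=1:ndims(x))

# Hypothetical loss in the same style: the result stays an array, so on a GPU
# it could stay on the device until someone explicitly reads it.
mse_keepdims(ŷ, y) = sum(abs2, ŷ .- y; dims=1:ndims(y)) ./ length(y)

x = Float32[1 2; 3 4]
s_scalar = sum(x)           # Float32: for a GPU array, this is where the host copy happens
s_array = sum_keepdims(x)   # 1×1 Matrix{Float32}: no element has to be read back
```

Note that the kept-dims result has size (1, 1) rather than being 0-dimensional, so code that expects a true scalar still has to unwrap it.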
Does it work if you return the result into a 0D array? For example:
sum!(ROCArray{Float32}(undef, ()), 1f0 .- x)
If so, it would be pretty easy to define a generic function like:
sum0d(X) = sum!(similar(X,()), X)
@pxl-th does that work?
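A CPU-only check of the proposed sum0d, using a plain Array in place of a ROCArray (an assumption; the GPU behavior is what the thread is asking about). It also shows one Base Julia pitfall: broadcasting that involves only 0-dimensional arrays and scalars returns a plain scalar, while in-place updates keep the 0D array.

```julia
# sum0d as proposed above: reduce X into a freshly allocated
# 0-dimensional array of the same kind as X (here a CPU Array).
sum0d(X) = sum!(similar(X, ()), X)

X = Float32[1 2; 3 4]
s = sum0d(X)     # Array{Float32, 0}, not a Float32

s .*= 2          # in-place update: s is still a 0-dimensional array
t = s .+ 1f0     # caution: an out-of-place broadcast like this returns a plain Float32
```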
The problem is not that there is sum(...) at the end.
You can replace it with the proposed sum0d, or with sum(...; dims=1:ndims(x)), to return an array instead of a scalar.
Then switch from Zygote.gradient to Zygote.pullback, since gradient expects a scalar at the end.
That eliminates this particular synchronization.
But you can't eliminate it as easily with chain rules for broadcasting.
1f0 .- x broadcasts 1f0 to every element of x; then, during the gradient computation, Zygote's unbroadcast rule is used, which calls sum(...; dims=:), and that causes the synchronization because it returns a scalar.
To eliminate the synchronization here, you'd have to either:
- Make GPU backends always return 0D arrays and require the user to explicitly transfer results from GPU to CPU. This is breaking, since lots of applications now expect sum to return a scalar, but IMO GPU-CPU transfers should be explicit.
- Change chain rules to specifically handle GPU arrays and operate on 0D arrays. This looks more like whack-a-mole to me, since you'd have to catch each such case, and not all of them come from user code (like the chain rule here).
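To make the second option concrete, here is a hypothetical pair of pullbacks for y = a .- x with a scalar a. It does not reproduce Zygote's actual unbroadcast code; the function names are made up, and plain Arrays stand in for GPU arrays. The only difference between the two is whether the cotangent for a is a Number or a 0D array.

```julia
# Hypothetical pullbacks for y = a .- x, where a is a scalar and x an array.
# Both return (∂a, ∂x) for an incoming cotangent Δ with the same shape as y.

# Scalar version: ∂a is a plain Number. For a GPU Δ, sum(Δ) has to copy the
# result back to the host, which synchronizes the device.
function sub_pullback_scalar(Δ::AbstractArray)
    ∂a = sum(Δ)
    ∂x = -Δ
    return ∂a, ∂x
end

# 0D version: ∂a stays a 0-dimensional array of the same kind as Δ,
# so no value has to be read back at this point.
function sub_pullback_0d(Δ::AbstractArray)
    ∂a = sum!(similar(Δ, ()), Δ)
    ∂x = -Δ
    return ∂a, ∂x
end
```

Everything downstream that consumes ∂a (for example a projection onto Float32) would then have to accept a 0D array, which is where the whack-a-mole comes from.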
Interesting discussion. Some partial thoughts:
- How much could be done in Flux rather than Zygote? We could change Flux to always accept a 0D array wherever it now accepts only a scalar (e.g. change Flux.withgradient), and to always produce such arrays, at least when working on GPU (e.g. make all loss functions call sum(x .- y; dims=(1,2,3))).
- For the sum in unbroadcast, any change would have to be here. I don't quite know how breaking it would be for that to produce 0D arrays. I believe that ProjectTo{Float32} will always accept a 0D array and call only on it; of course that will synchronise, but it may avoid breakage. It's possible that many gradients like 1.0 .- x will discard the 0D array without synchronising.
- Could consider making a struct GPUScalar <: Real which wraps a one-element array, with promotion rules that turn it back into a Float32 as soon as you do anything (except broadcasting, etc.). Perhaps that would be non-breaking?
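A minimal sketch of the GPUScalar idea, under stated assumptions: it is not the implementation mentioned later in the thread, a plain Array{T,0} stands in for a device array so it runs on the CPU, and the accessor value is a hypothetical name. Scalar arithmetic promotes to the wrapped element type (materializing the value), while broadcasting hands the wrapped array to broadcast so the value could stay on the device.

```julia
# Hypothetical GPUScalar: a Real that wraps a 0-dimensional array.
# On a GPU the wrapped array would be a device array (e.g. a CuArray or
# ROCArray); a CPU Array{T,0} stands in here.
struct GPUScalar{T<:Real,A<:AbstractArray{T,0}} <: Real
    data::A
end
GPUScalar(data::AbstractArray{T,0}) where {T<:Real} = GPUScalar{T,typeof(data)}(data)

# Hypothetical accessor: reading the value is the one place where a
# device-to-host copy would happen (and GPU array packages typically warn
# about or disallow this kind of scalar indexing).
value(x::GPUScalar) = x.data[]

# Scalar arithmetic: promote to the wrapped element type, i.e. materialize.
Base.promote_rule(::Type{GPUScalar{T,A}}, ::Type{S}) where {T,A,S<:Real} = promote_type(T, S)
Base.convert(::Type{T}, x::GPUScalar) where {T<:AbstractFloat} = convert(T, value(x))

# Broadcasting: use the wrapped array directly, so no value is read back.
Base.Broadcast.broadcastable(x::GPUScalar) = x.data
```

With these definitions, x + 1f0 goes through Base's generic Number promotion and returns a Float32, whereas x .* v broadcasts the wrapped 0D array against v. Restricting convert to AbstractFloat targets avoids a method ambiguity with Base's convert(::Type{T}, x::T).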
@jpsamaroo has also suggested the GPUScalar approach, and I think that ultimately either this or returning a GPUArray is what will fix these performance issues.
We just need to avoid this device-host copy.
Got the GPUScalar approach working locally. It seems to be quite a minimal change and non-breaking (unless your code relies on the return type being e.g. Float32 rather than GPUScalar{Float32}).
All GPUArrays and CUDA tests pass; I will do more testing with Flux & Zygote and then open a PR. For the examples that I had, there are no device-to-host copies at all.
@maleadt is there a reason that sum(::GPUArray) causes device-to-host copies that we are missing here?
There's no other way to implement a reduction to a scalar: the value needs to be available, so we need to wait for the GPU to finish computing it, i.e. synchronize the device, as @pxl-th mentioned.
@pxl-th, do you think GPUNumber can speed up optimizers like Adam and gradient descent? Currently, Optimisers.jl and other packages initialize the hyperparameters as scalars on the CPU, like Adam(learning_rate, (beta1, beta2)). If they can be moved to the GPU, that should provide additional speedup in model training.
Testing this hypothesis here:
using CUDA, BenchmarkTools
function f()
    x = CUDA.rand(1024, 1024, 100)
    a1 = rand(Float32)             # scalar on the host, passed to the kernel as an argument
    a2 = CUDA.rand(Float32, 1, 1)  # 1×1 array already on the device
    @btime CUDA.@sync $a1 * $x     # 30.536 μs (70 allocations: 1.70 KiB)
    @btime CUDA.@sync $a1 .* $x    # 30.779 μs (70 allocations: 1.70 KiB)
    @btime CUDA.@sync $a2 .* $x    # 33.149 μs (105 allocations: 2.25 KiB)
    nothing
end
f()
Looks like there's no speedup to be gained there.
Yeah, because you're synchronizing. The whole point of GPUNumber (or whatever it will be named) is that pending values can be used as inputs without having to synchronize the GPU. It also only matters when acquiring such an unmaterialized value from one GPU kernel and feeding it to another. When feeding simple random numbers to a kernel, it does not matter whether you use the raw value or a 1-element array; in fact, the array will only slow things down, because the value can no longer be loaded from the parameter address space.
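The pattern that does benefit is one kernel producing a reduced value and another consuming it without an intermediate read. A CPU-only sketch of that data flow, assuming hypothetical helpers mean0d and center, with plain Arrays standing in for device arrays; on the CPU it only illustrates the flow, and a speedup would be expected only when these run as GPU kernels.

```julia
# Hypothetical helper: mean of x, kept as a 0-dimensional array.
function mean0d(x::AbstractArray)
    μ = sum!(similar(x, ()), x)  # reduction result as a 0D array; on a GPU it stays on the device
    μ ./= length(x)               # in-place, so μ remains a 0-dimensional array
    return μ
end

# Feed the unmaterialized mean straight into the next broadcast: there is no
# μ[] read, so nothing here forces a device-to-host copy.
center(x::AbstractArray) = x .- mean0d(x)
```

For example, center(Float32[1, 2, 3]) gives Float32[-1, 0, 1]. The mean is never converted to a host scalar along the way.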