Zygote.jl
CUDA gradient error with broadcasting
This used to work, but it currently fails on CI. I reproduced it with Julia v1.11, Zygote v0.7.10 and CUDA v5.8.5. The CPU gradient below is correct; the CuArray gradient is not.
julia> using Zygote, CUDA
julia> a = Float32.(1:9)
9-element Vector{Float32}:
1.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
julia> a_gpu = a |> cu
9-element CuArray{Float32, 1, CUDA.DeviceMemory}:
1.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
julia> g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]
9-element Vector{Float32}:
0.5
2.0
4.5
8.0
12.5
18.0
24.5
32.0
40.5
julia> gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] # WRONG GRADIENT
9-element CuArray{Float32, 1, CUDA.DeviceMemory}:
0.42857146
1.7142859
3.8571432
6.8571434
10.714287
15.428573
21.000002
27.428574
34.714287
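As a sanity check on the numbers above, here is a minimal plain-Julia sketch (no AD or GPU packages) that compares the analytic gradient, 3x²/n with n = count(x .> 3) treated as a constant, against the values printed for the CuArray. The names `expected`, `reported_gpu` and `ratio` are only for illustration. It shows that the GPU values are consistent with a denominator of 7 rather than 6; it does not explain why that happens.

```julia
# Plain-Julia check of the numbers quoted above; no AD or GPU packages needed.
a = Float32.(1:9)

# Analytic gradient of sum(x .^ 3) / count(x .> 3), treating the count as a constant.
n = count(a .> 3)                 # number of elements greater than 3
expected = 3 .* a .^ 2 ./ n       # should agree with the CPU result g3

# Values copied verbatim from the CuArray output in this report.
reported_gpu = Float32[0.42857146, 1.7142859, 3.8571432, 6.8571434, 10.714287,
                       15.428573, 21.000002, 27.428574, 34.714287]

# Observation only: every entry is off by the same factor, close to 7/6.
ratio = expected ./ reported_gpu

println("n = ", n)
println("expected = ", expected)
println("ratio is 7/6 everywhere: ", all(isapprox.(ratio, 7 / 6; rtol = 1e-5)))
```

Since the error is a uniform scale factor across all elements, including those with x ≤ 3, the `x .^ 3` part of the gradient looks right and only the scaling by the count seems to be affected.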
FWIW the gradient is also wrong with other GPUArrays backends, e.g. JLArrays:
julia> using JLArrays, Zygote
julia> a = Float32.(1:9)
9-element Vector{Float32}:
julia> a_gpu = a |> jl
9-element JLArray{Float32, 1}:
julia> g3 = Zygote.gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]
9-element Vector{Float32}:
0.5
2.0
julia> Zygote.gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] # WRONG
9-element JLArray{Float32, 1}:
0.42857146
1.7142859
julia> Zygote.gradient(x -> sum(x .^ 3) / 6, a_gpu)[1] # fine!
9-element JLArray{Float32, 1}:
0.5
2.0
4.5
julia> Zygote.gradient(x -> count(x .> 3), a_gpu) # fine
(nothing,)
julia> Zygote.gradient(x -> sum(x .> 3), a_gpu)
(nothing,)
Zygote v0.7.10 + Julia 1.11.6
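The JLArrays reproduction above can be condensed into a small script that needs no GPU, which might be handy as a starting point for a regression test. This is only a sketch: it assumes Zygote and JLArrays are installed in the active environment, and the names `loss` and `g_ref` are made up for illustration.

```julia
using Zygote, JLArrays   # assumption: both packages are available in the environment

# Loss from the report; the count only scales the sum and has no gradient of its own.
loss(x) = sum(x .^ 3) / count(x .> 3)

a = Float32.(1:9)
g_cpu = Zygote.gradient(loss, a)[1]        # reference result on a plain Array
g_jl  = Zygote.gradient(loss, jl(a))[1]    # same computation on a JLArray

# Analytic value, with the count treated as a constant.
g_ref = 3 .* a .^ 2 ./ count(a .> 3)

println("CPU matches analytic: ", g_cpu ≈ g_ref)
println("JLArray matches CPU:  ", Array(g_jl) ≈ g_cpu)
```

Going by the outputs in this thread, the second comparison would be expected to fail on Zygote v0.7.10; once the bug is fixed, both lines should report `true`.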