Zygote.jl
Incorrect derivative of `getindex()` with repeating indices on CuArrays
```julia
using Zygote
using CUDA

f(A, I) = sum(A[I])
A = rand(4)
I = [1, 3, 1]

# CPU - everything is OK
Zygote.gradient(f, A, I)
# => ([2.0, 0.0, 1.0, 0.0], nothing)

# GPU - dA[1] is incorrect
Zygote.gradient(f, cu(A), cu(I))
# => (Float32[1.0, 0.0, 1.0, 0.0], nothing)
```
I believe the CPU version comes from ChainRules.jl, which correctly accumulates the contributions from the repeated index into dA[1], but I'm not sure what code is used for the CUDA version.
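To make the difference concrete, here is a minimal CPU-only sketch of the two behaviours. The helper names `getindex_pullback_accumulate` and `getindex_pullback_overwrite` are hypothetical and written only for illustration; this is not the actual ChainRules.jl or CUDA.jl code, and it is only an assumption that the GPU path behaves like the second helper.

```julia
# Hypothetical helpers for illustration only; not the ChainRules.jl or CUDA.jl implementation.

# Scatter-add: every occurrence of an index contributes to the gradient.
function getindex_pullback_accumulate(dy, I, n)
    dA = zeros(eltype(dy), n)
    for (k, i) in enumerate(I)
        dA[i] += dy[k]
    end
    return dA
end

# Broadcasted update: `dA[I] .+= dy` reads `dA[I]` once and then writes back,
# so repeated indices overwrite each other instead of accumulating.
function getindex_pullback_overwrite(dy, I, n)
    dA = zeros(eltype(dy), n)
    dA[I] .+= dy
    return dA
end

I = [1, 3, 1]
dy = ones(3)  # cotangent of sum(A[I]) with respect to A[I]

getindex_pullback_accumulate(dy, I, 4)  # [2.0, 0.0, 1.0, 0.0], same as the CPU result above
getindex_pullback_overwrite(dy, I, 4)   # [1.0, 0.0, 1.0, 0.0], same as the GPU result above
```

The second helper reproduces the value reported for the GPU, which is consistent with writes to the same index overwriting each other rather than being summed.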
Here is how I came across this issue and how I am trying to resolve it in Yota.
Similar to #600; see also https://github.com/JuliaGPU/CUDA.jl/issues/89 and https://github.com/JuliaLang/julia/pull/31407.
My brief reading of the linked CUDA issue suggests this can't be fixed? Could we add an error here? I just spent a bunch of time discovering subtly wrong results in my code, reducing them to a MWE, and eventually finding this already-filed issue. :cry:
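One possible shape for such an error, as a sketch only: the wrapper `checked_getindex` below is hypothetical and not part of Zygote.jl or CUDA.jl. It uses `allunique` from Base to reject repeated indices before indexing. How cheaply this check can run on a GPU index array is an open question here.

```julia
# Hypothetical guard; `checked_getindex` is not part of Zygote.jl or CUDA.jl.
function checked_getindex(A, I)
    allunique(I) || throw(ArgumentError(
        "repeated indices in I; the gradient of A[I] may be wrong on GPU arrays"))
    return A[I]
end

checked_getindex(rand(4), [1, 3])        # returns the two selected elements
# checked_getindex(rand(4), [1, 3, 1])   # throws ArgumentError
```

A loud failure like this trades some generality for safety: repeated indices are valid for plain indexing, so a real fix would accumulate the gradient correctly rather than reject the input.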
I thought https://github.com/JuliaDiff/ChainRules.jl/pull/655 was the last word on this. If that works then we "just" need to do https://github.com/FluxML/Zygote.jl/pull/1328.