Zygote.jl
Complex broadcasting AD gives `nothing` when using CUDA
The following snippet works on the CPU, where it gives the correct gradient, but fails on the GPU, where `gs[x]` is `nothing`.
using Zygote

y = complex.([4, 1])
x = complex.([3, 2])

function f1215(x, y)
    x = 2 .* x
    return sum(abs2.(x .- y))
end

gs = gradient(() -> f1215(x, y), Zygote.Params([x]))
gs[x]  # correct gradient on CPU; returns nothing when x and y are on GPU

# To reproduce the failure, move x and y to the GPU and re-run the two lines above
using CUDA
x = cu(x)
y = cu(y)
[Edited not to need Flux]
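For reference, the value the CPU run should produce can be checked without Zygote or CUDA. The sketch below assumes the convention that the gradient of a real-valued function of a complex input is reported as ∂f/∂Re(x) + im * ∂f/∂Im(x), under which the closed form for this function is 4(2x - y). The helper `fd_gradient` is hypothetical (it is not part of any package), and floating-point inputs are used instead of the integer ones in the snippet so that finite differences are meaningful.

```julia
# Same function as in the report, written as a one-liner.
f1215(x, y) = sum(abs2.(2 .* x .- y))

# Hypothetical helper: central finite differences in the real and imaginary
# directions, combined as d/dRe + im * d/dIm (assumed gradient convention).
function fd_gradient(f, x; h = 1e-6)
    g = zeros(ComplexF64, length(x))
    for k in eachindex(x)
        e = zeros(ComplexF64, length(x))
        e[k] = h
        d_re = (f(x .+ e) - f(x .- e)) / (2h)
        d_im = (f(x .+ im .* e) - f(x .- im .* e)) / (2h)
        g[k] = complex(d_re, d_im)
    end
    return g
end

x = complex.([3.0, 2.0])
y = complex.([4.0, 1.0])

g_fd = fd_gradient(z -> f1215(z, y), x)   # numerical estimate
g_exact = 4 .* (2 .* x .- y)              # closed form: [8.0 + 0.0im, 12.0 + 0.0im]
```

If the convention assumed here matches the installed Zygote version, `gs[x]` on the CPU should agree with `g_exact`, whereas the affected GPU path returns `nothing` instead.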
The gradient rule for broadcasting that is used for CuArrays doesn't handle complex numbers. It's a nasty surprise, but nobody has got around to making it at least throw an error or, better, making it work.
Xref https://github.com/FluxML/Zygote.jl/issues/961, #1121 among others.
Edit: this thread has a similar problem from CuArray{SVector}.
closed in #1324