Zygote.jl

Complex broadcasting AD gives `nothing` when using CUDA

pawbz opened this issue · 1 comment

The following snippet works on the CPU (it returns the correct gradient) but fails on the GPU, where `gs[x]` comes back as `nothing`.

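using Zygote

y = complex.([4, 1])
x = complex.([3, 2])

function f1215(x, y)
    x = 2 .* x
    return sum(abs2.(x .- y))
end

# CPU arrays: returns the correct gradient
gs = gradient(() -> f1215(x, y), Zygote.Params([x]))
gs[x]

# Same data moved to the GPU: the gradient is lost
using CUDA
x = cu(x)
y = cu(y)
gs = gradient(() -> f1215(x, y), Zygote.Params([x]))
gs[x]  # returns nothing when x and y are CuArrays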

[Edited not to need Flux]
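For reference, the value the CPU run should produce can be checked by hand. This assumes Zygote's convention for a real-valued function of complex input, where the gradient is the partial derivative with respect to the real part plus i times the partial derivative with respect to the imaginary part (so the gradient of `abs2(z)` is `2z`):

$$
f(x) = \sum_k \lvert 2x_k - y_k \rvert^2, \qquad
\nabla_{x_k} f = \frac{\partial f}{\partial \operatorname{Re} x_k} + i\,\frac{\partial f}{\partial \operatorname{Im} x_k} = 4\,(2x_k - y_k)
$$

With x = (3, 2) and y = (4, 1), 2x - y = (2, 3), so the expected gradient is (8, 12) with zero imaginary part. On the GPU, `gs[x]` is `nothing` instead.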

pawbz · Apr 28 '22 09:04

The broadcasting gradient that Zygote uses for CuArrays (a forward-mode path based on dual numbers) doesn't handle complex numbers. It's a nasty surprise, but nobody has yet got around to making it at least an error or, better, to making it work.

Xref https://github.com/FluxML/Zygote.jl/issues/961 and #1121, among others.

Edit: this thread describes a similar problem with `CuArray{SVector}`.

mcabbott · Apr 30 '22 15:04

Closed in #1324.

CarloLucibello · Jan 10 '23 17:01