Zygote.jl

gradient returns nothing for `sum(abs2, x)` with a complex CuArray

LexaLutyi opened this issue • 9 comments

`gradient` returns `nothing` for a `CuArray{ComplexF32}`, but works fine with an `Array{ComplexF32}`:

```julia
julia> using CUDA, Zygote

julia> a = CUDA.rand(ComplexF32, 2)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
  0.040781975f0 + 0.31357694f0im
 0.0050712824f0 + 0.464033f0im

julia> gradient(t -> sum(abs2, t), a)
(nothing,)

julia> a = rand(ComplexF32, 2)
2-element Vector{ComplexF32}:
 0.6999197f0 + 0.343145f0im
 0.4877541f0 + 0.47177994f0im

julia> gradient(t -> sum(abs2, t), a)
(ComplexF32[1.3998394f0 + 0.68629f0im, 0.9755082f0 + 0.9435599f0im],)
```

`real` and `imag` behave the same way as `abs2`.
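For reference, the CPU result above is exactly `2 .* a`. For a real-valued function of complex input, Zygote returns ∂f/∂(real part) + im·∂f/∂(imag part) for each element, and since `abs2(z) = real(z)^2 + imag(z)^2`, that works out to `2z`. The sketch below checks this with central finite differences in plain Base Julia, so it needs neither CUDA nor Zygote. `fd_gradient` is a hypothetical helper, not a Zygote API; it is only meant as a quick way to confirm what the correct answer should be when comparing CPU and GPU gradients.

```julia
# Hypothetical helper (not part of Zygote): central finite differences for a
# real-valued f of a complex vector, using the convention
# gradient[k] = ∂f/∂re(x[k]) + im * ∂f/∂im(x[k]).
function fd_gradient(f, x::AbstractVector{<:Complex}; h = 1e-6)
    g = similar(x)
    for k in eachindex(x)
        e = zero(x); e[k] = 1   # unit step in the k-th element only
        d_re = (f(x .+ h .* e) - f(x .- h .* e)) / (2h)            # step along the real axis
        d_im = (f(x .+ im * h .* e) - f(x .- im * h .* e)) / (2h)  # step along the imaginary axis
        g[k] = complex(d_re, d_im)
    end
    return g
end

# Same values as the CPU example, in Float64 to keep finite-difference error small.
x = ComplexF64[0.6999197 + 0.343145im, 0.4877541 + 0.47177994im]
g = fd_gradient(t -> sum(abs2, t), x)
println(g ≈ 2 .* x)   # prints true
```

Double precision is used here on purpose: with `Float32` and a step of `1e-6`, rounding error would swamp the difference quotient.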

LexaLutyi, Nov 18 '21

Definitely a bug! Could you check on older Zygote releases, say v0.6.3?

DhairyaLGandhi, Nov 18 '21

> Definitely a bug! Could you check on older Zygote releases, say v0.6.3?

Yes, on v0.6.3 it works correctly:

```julia
julia> using CUDA, Zygote

julia> a = CUDA.rand(ComplexF32, 2)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
 0.12610757f0 + 0.47263396f0im
  0.4873352f0 + 0.42900836f0im

julia> gradient(t -> sum(abs2, t), a)
(ComplexF32[0.25221515f0 + 0.9452679f0im, 0.9746704f0 + 0.8580167f0im],)

julia> b = Array(a)
2-element Vector{ComplexF32}:
 0.12610757f0 + 0.47263396f0im
  0.4873352f0 + 0.42900836f0im

julia> gradient(t -> sum(abs2, t), b)
(ComplexF32[0.25221515f0 + 0.9452679f0im, 0.9746704f0 + 0.8580167f0im],)
```

LexaLutyi, Nov 18 '21

I wonder if it's related to the recent projection-related issues, too.

DhairyaLGandhi, Nov 18 '21

The problem occurs in an environment with Flux, where Zygote resolves to v0.6.30. If you install Zygote without Flux, you get Zygote v0.6.12 and everything works correctly.

LexaLutyi, Nov 18 '21

This used to go here: https://github.com/FluxML/Zygote.jl/blob/v0.6.12/src/lib/array.jl#L299

After #990 and #1004 it goes here, which calls the adjoint for broadcasting: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L278-L283

And that won't work, because broadcasting doesn't handle complex CuArrays at all; it treats them as if they were non-differentiable: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L224-L226

You could fix `sum(abs2, x)` by making it avoid broadcasting (either by copying the explicit rule, or by sending it to ChainRules), or by adding a rule for `@adjoint broadcasted(::typeof(abs2), x::Numeric)`, which wouldn't be a bad idea anyway. It would be easy to make broadcasting over complex CuArrays an error instead of a silent wrong answer, and that should also be done anyway. Finally, it would not be very hard to make broadcasting over complex CuArrays just work instead.
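As a user-side stopgap in the spirit of the first option (an explicit rule that never reaches the broadcast adjoint), one can wrap the loss in a function of one's own and give it a ChainRules `rrule`, which Zygote picks up. This is a minimal sketch, assuming ChainRulesCore is installed; `sumabs2` is a hypothetical name chosen so the rule does not pirate `Base.sum`, and it is not the fix that was eventually merged. The pullback returns `2 .* real(ȳ) .* x`, matching the CPU gradients shown earlier in the thread. The broadcast inside the pullback is an ordinary, non-differentiated array operation, so it should not hit the problematic broadcast adjoint even when `x` is a `CuArray`.

```julia
using ChainRulesCore

# Hypothetical wrapper: a function we own, so adding a rule for it is not
# type piracy on Base.sum.
sumabs2(x::AbstractArray{<:Number}) = sum(abs2, x)

# Explicit reverse rule. For a real-valued output of complex input, the
# gradient convention is ∂f/∂re + im * ∂f/∂im, which gives 2x for sum(abs2, x).
function ChainRulesCore.rrule(::typeof(sumabs2), x::AbstractArray{<:Number})
    y = sumabs2(x)
    function sumabs2_pullback(ȳ)
        # The output is real, so only the real part of the cotangent matters.
        # This is a plain broadcast, not one that Zygote differentiates.
        x̄ = 2 .* real(unthunk(ȳ)) .* x
        return (NoTangent(), x̄)
    end
    return y, sumabs2_pullback
end

# Exercise the rule directly on the CPU with the values from the original report.
x = ComplexF32[0.6999197f0 + 0.343145f0im, 0.4877541f0 + 0.47177994f0im]
y, pb = rrule(sumabs2, x)
_, x̄ = pb(1f0)
println(x̄ ≈ 2 .* x)   # prints true
```

With Zygote loaded, `gradient(sumabs2, a)` should then use this rule instead of tracing `sum(abs2, a)` through broadcasting.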

Xref #961. No relation to projection.

mcabbott, Nov 18 '21

Good shout on the broadcasting changes. It's hard to say which functions may silently break; the answer is probably to make sure the non-broadcasted adjoint is hit in this case anyway. Having two adjoint rules and forwarding one to the other manually would be tedious, so maybe we should look at which functions are broken by the changes to the map adjoint.

DhairyaLGandhi, Nov 19 '21

Bumping this: is there any update on this issue, other than using Zygote and Flux versions from before the broadcasting changes?

alexjaffray, Dec 06 '21

This is causing problems for me as well. Zygote gradients of other functions, such as `f_exp(x) = sum(real(exp.(x)))`, also return `nothing` for complex CuArrays, similar to issue #961. That makes me hesitant to go down the custom adjoint rabbit hole. Is there hope for a fix that covers these cases generally? Thanks for all your hard work; I really love Zygote and CUDA.jl!
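For comparison, the gradient a working rule should return for `f_exp` can be worked out by hand. Write each element as $x_k = a_k + i\,b_k$ and use Zygote's convention for real-valued functions of complex input, $\bar x_k = \partial f/\partial a_k + i\,\partial f/\partial b_k$ (the same convention that gives `2x` for `sum(abs2, x)`):

$$
f(x) = \sum_k \operatorname{Re} e^{x_k} = \sum_k e^{a_k}\cos b_k,
\qquad
\frac{\partial f}{\partial a_k} = e^{a_k}\cos b_k,
\qquad
\frac{\partial f}{\partial b_k} = -e^{a_k}\sin b_k,
$$

$$
\bar x_k = e^{a_k}\cos b_k - i\,e^{a_k}\sin b_k = e^{a_k - i b_k} = \overline{e^{x_k}}.
$$

So the expected result is `conj.(exp.(x))`, which gives a simple reference to test a GPU fix against, independent of what the CPU code path happens to do.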

omalled, Jan 19 '22

Sadly, broadcasting with complex numbers has never worked on the GPU, and this hasn't changed. It should be an error, but isn't. It could certainly be made to work, but someone has to do it. A few special cases could be made to work more easily, too; I guess `.+` and `.*` probably already do. And the headline issue is the special case `sum(abs2, x)`.

mcabbott, Jan 19 '22

Closed in #1324.

CarloLucibello, Jan 10 '23