Zygote.jl
gradient returns nothing for `sum(abs2, x)` with a complex CuArray
`gradient` returns `nothing` for `CuArray{ComplexF32}`, but works fine with `Array{ComplexF32}`:
julia> a = CUDA.rand(ComplexF32, 2)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
0.040781975f0 + 0.31357694f0im
0.0050712824f0 + 0.464033f0im
julia> gradient(t -> sum(abs2, t), a)
(nothing,)
julia> a = rand(ComplexF32, 2)
2-element Vector{ComplexF32}:
0.6999197f0 + 0.343145f0im
0.4877541f0 + 0.47177994f0im
julia> gradient(t -> sum(abs2, t), a)
(ComplexF32[1.3998394f0 + 0.68629f0im, 0.9755082f0 + 0.9435599f0im],)
`real` and `imag` behave similarly to `abs2`.
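For anyone wanting a reference value to compare a GPU result against, here is a minimal, dependency-free sketch of a finite-difference check. The helper name `fd_gradient` is made up for illustration, and the convention `df/d(re) + im * df/d(im)` is an assumption chosen because it matches the CPU output above (the gradient of `sum(abs2, x)` comes out as `2x`).

```julia
# Central finite-difference estimate of the gradient of a real-valued function
# with respect to a complex vector, using the convention
# df/d(re) + im * df/d(im). `fd_gradient` is a hypothetical helper name.
function fd_gradient(f, x::AbstractVector{<:Complex}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        # Unit vector selecting the i-th entry.
        e = zeros(eltype(x), length(x))
        e[i] = 1
        # Perturb the real part, then the imaginary part, of x[i].
        d_re = (f(x .+ h .* e) - f(x .- h .* e)) / (2h)
        d_im = (f(x .+ (h * im) .* e) - f(x .- (h * im) .* e)) / (2h)
        g[i] = d_re + im * d_im
    end
    return g
end

x = ComplexF64[0.7 + 0.3im, 0.5 - 0.4im]
g = fd_gradient(t -> sum(abs2, t), x)

# For sum(abs2, x) the estimate should agree with 2x up to rounding.
println(isapprox(g, 2 .* x; atol = 1e-6))
```

The check runs on the CPU in `Float64`; a result pulled back with `Array(...)` from a CuArray gradient could be compared against it with a looser tolerance.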
Definitely a bug! Could you check on older Zygote releases, say v0.6.3?
Yes, on v0.6.3 it works correctly:
julia> a = CUDA.rand(ComplexF32, 2)
2-element CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}:
0.12610757f0 + 0.47263396f0im
0.4873352f0 + 0.42900836f0im
julia> gradient(t -> sum(abs2, t), a)
(ComplexF32[0.25221515f0 + 0.9452679f0im, 0.9746704f0 + 0.8580167f0im],)
julia> b = Array(a)
2-element Vector{ComplexF32}:
0.12610757f0 + 0.47263396f0im
0.4873352f0 + 0.42900836f0im
julia> gradient(t -> sum(abs2, t), b)
(ComplexF32[0.25221515f0 + 0.9452679f0im, 0.9746704f0 + 0.8580167f0im],)
I wonder if it's related to the recent projection-related issues too.
The problem occurs when using Flux, where Zygote is v0.6.30. If you install Zygote without Flux, then Zygote is v0.6.12 and everything works correctly.
This used to go here: https://github.com/FluxML/Zygote.jl/blob/v0.6.12/src/lib/array.jl#L299
After #990 and #1004 it goes here, which calls the adjoint for broadcasting: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L278-L283
And that won't work, because broadcasting doesn't handle complex CuArrays at all; it treats them as if they were non-differentiable: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L224-L226
You could fix `sum(abs2, x)` by making it avoid broadcasting (either copying the explicit rule, or sending it to ChainRules), or by adding a rule for `@adjoint broadcasted(::typeof(abs2), x::Numeric)`, which wouldn't be a bad idea anyway. It would be easy to make broadcasting over complex CuArrays an error instead of a silent wrong answer, and that should be done regardless. Finally, it would also not be very hard to make broadcasting over complex CuArrays just work instead.
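As a user-side stopgap in the spirit of the "explicit rule" option, one could wrap the reduction in a named function and attach a hand-written pullback to it. This is only a sketch: `sumabs2` is a hypothetical helper, not something Zygote provides, and the pullback `2 .* Δ .* x` is written to match the CPU result shown at the top of the thread rather than copied from Zygote's source.

```julia
using Zygote

# Hypothetical helper (not part of Zygote): a named wrapper lets us attach a
# rule without overriding any of Zygote's own definitions.
sumabs2(x) = sum(abs2, x)

# Explicit pullback that does not go through Zygote's broadcast adjoint.
# For y = sum(abs2, x), the gradient is 2x scaled by the real sensitivity Δ.
Zygote.@adjoint sumabs2(x) = sumabs2(x), Δ -> (2 .* Δ .* x,)

x = ComplexF32[1 + 2im, 3 - 1im]
g = gradient(sumabs2, x)[1]
println(g)
```

The example runs on a plain `Vector`. The pullback is an ordinary broadcast evaluated after differentiation, so I would expect it to carry over to a complex CuArray, but that is untested here.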
Xref #961. No relation to projection.
Good shout on the broadcasting changes. It's hard to predict which functions may silently break; the answer is probably to make sure the non-broadcasted adjoint is hit in this case anyway. It would be tedious to have two adjoint rules and forward between them manually, so maybe we should take a look at which functions are broken by the changes to the map adjoint.
Bumping this: Is there any update on this issue other than to use older Zygote and Flux versions pre-broadcasting changes?
This is causing problems for me as well. Zygote gradients of other things like `f_exp(x) = sum(real(exp.(x)))` also return `nothing` for complex `CuArray`s, similar to issue #961. That makes me hesitant to go down the custom adjoint rabbit hole. Is there hope for a fix that will cover these cases generally? Thanks for all your hard work - really love Zygote and CUDA.jl!
Broadcasting with complex numbers has never worked on the GPU, sadly, and this hasn't changed. It should be an error but isn't. It could certainly be made to work, but someone has to do it. A few special cases could more easily be made to work, too; I guess `.+` and `.*` probably already do. And the headline issue is the special case `sum(abs2, x)`.
Closed in #1324.