Zygote.jl

gradient of array splat should be an array

Open · CarloLucibello opened this issue on Apr 14 '20 · 2 comments

The last example behaves differently from the first two:

julia> gradient(w -> sum(w), [1,1])  # ok, gradient is array
([1, 1],)   

julia> gradient(w -> sum([w[1], w[2]]), [1,1]) # ok, gradient is array
([1, 1],)

julia> gradient(w -> sum([w...]), [1,1]) # NOT OK, gradient is tuple
((1, 1),)

One problem with returning a tuple is that it breaks Flux's update! function, which expects an array.

CarloLucibello, Apr 14 '20 11:04

See #489, where @MikeInnes described the approach for fixing this.

AzamatB, Apr 14 '20 11:04

BTW, this is not fixed, although applying ProjectTo to the final results hides the problem in the example above. If the gradient is passed to another pullback that expects an array, you will get errors.

To bypass the projection and see the problem, call the pullback directly:

julia> pullback(w -> sum([w...]), [1,1])[2](1.0)
((1.0, 1.0),)

julia> pullback(w -> sum([w...]), [1 2; 3 4])[2](1.0)
((1.0, 1.0, 1.0, 1.0),)

mcabbott, Jul 22 '22 14:07
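As a stopgap at the call site, a tuple like the ones above can be turned back into an array by hand. This is a minimal sketch that assumes the tuple holds exactly one cotangent per element of the primal array, in column-major (linear-index) order, which is the order in which `[w...]` iterates over `w`. `tuple_to_array` is a hypothetical helper, not a Zygote or ChainRulesCore function.

```julia
# Hypothetical helper (not part of Zygote or ChainRulesCore): rebuild an
# array-shaped gradient from the tuple Zygote returns for a splatted array.
# Assumes `dx` has one entry per element of `x`, in column-major order,
# which is the order `[x...]` iterates over `x`.
function tuple_to_array(dx::Tuple, x::AbstractArray)
    length(dx) == length(x) ||
        throw(DimensionMismatch("got $(length(dx)) cotangents for $(length(x)) elements"))
    # collect turns the tuple into a Vector; reshape restores the primal's shape
    return reshape(collect(dx), size(x))
end

# The tuples from the pullback examples above, converted back to array shapes.
tuple_to_array((1.0, 1.0), [1, 1])                 # 2-element Vector{Float64}
tuple_to_array((1.0, 1.0, 1.0, 1.0), [1 2; 3 4])   # 2×2 Matrix{Float64}
```

This only patches the symptom where the gradient is consumed; it does not change what Zygote itself returns for the splat.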