Zygote.jl
gradient of array splat should be an array
The last example behaves differently from the first two:
julia> using Zygote
julia> gradient(w -> sum(w), [1,1]) # ok, gradient is array
([1, 1],)
julia> gradient(w -> sum([w[1], w[2]]), [1,1]) # ok, gradient is array
([1, 1],)
julia> gradient(w -> sum([w...]), [1,1]) # NOT OK, gradient is tuple
((1, 1),)
One problem with returning a tuple is that it breaks Flux's update! function, which expects an array.
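To make the failure mode concrete, here is a minimal sketch of an optimiser step that, like many array-based update APIs, only accepts array gradients. The function sgd_step! and its η keyword are hypothetical and are not Flux's actual update! implementation; the point is only that a tuple gradient such as (1, 1) has no matching method.

```julia
# Hypothetical optimiser step (NOT Flux's update!), typed to accept only
# array gradients, as many array-based update APIs are.
function sgd_step!(w::AbstractArray, g::AbstractArray; η = 0.1)
    w .-= η .* g   # in-place gradient-descent step
    return w
end

w = [1.0, 1.0]

# An array gradient, as returned for sum(w) or sum([w[1], w[2]]), works:
sgd_step!(w, [1.0, 1.0])

# A tuple gradient, as returned for sum([w...]), has no matching method:
err = try
    sgd_step!(w, (1.0, 1.0))
    nothing
catch e
    e
end
println(typeof(err))
```

Typing the gradient argument as AbstractArray is what turns a tuple gradient into a MethodError rather than a silent broadcast.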
See #489, where @MikeInnes described the approach for fixing this.
BTW, this is still not fixed, although applying ProjectTo to the final results hides the problem in the example above. If the gradient needs to be passed to another pullback that expects an array, you will get errors.
To bypass the projection and see the problem, call the pullback directly:
julia> pullback(w -> sum([w...]), [1,1])[2](1.0)
((1.0, 1.0),)
julia> pullback(w -> sum([w...]), [1 2; 3 4])[2](1.0)
((1.0, 1.0, 1.0, 1.0),)
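Note that for the 2×2 matrix the pullback returns a flat 4-tuple, so the shape of the input is lost as well as its array type. The sketch below shows roughly what any fix has to do: turn the tuple cotangent back into an array shaped like the primal. The helper tuple_to_array is hypothetical, not part of Zygote's or ChainRulesCore's API, and is not how ProjectTo is actually implemented; it assumes the tuple entries arrive in the order produced by splatting.

```julia
# Hypothetical helper (not Zygote/ChainRulesCore API): convert a tuple cotangent,
# as produced by the splat pullback, into an array shaped like the primal.
function tuple_to_array(primal::AbstractArray, dx::Tuple)
    length(dx) == length(primal) ||
        throw(DimensionMismatch("expected $(length(primal)) entries, got $(length(dx))"))
    # Splatting iterates in column-major (linear-index) order, so entry k of the
    # tuple belongs to linear index k of the primal; reshape restores the shape.
    return reshape(collect(dx), size(primal))
end

x = [1 2; 3 4]
dx = (1.0, 1.0, 1.0, 1.0)   # tuple cotangent, like the pullback output above
g = tuple_to_array(x, dx)
```

Because the conversion relies on column-major order, a non-uniform tuple such as (1.0, 2.0, 3.0, 4.0) maps to [1.0 3.0; 2.0 4.0], matching the order in which [x...] visits the elements of x.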