Zygote.jl

Splatting gives an incorrect tuple return type on the gradient

Open • ChrisRackauckas opened this issue 1 year ago • 1 comment

MWE:

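```julia
using Flux
ann = Chain(Dense(5, 50, tanh), Dense(50, 4))
p, st = Flux.destructure(ann)  # flat parameter vector and a rebuilder
u0 = rand(4)                   # 4 states; appending 1f1 gives the 5 inputs the first layer expects

using Zygote
function dudt_(u, p, t)
    st(p)([u..., 1f1])         # splat u into an array literal, then append 1f1 (= 10.0f0)
end
out, back = Zygote.pullback(dudt_, u0, p, 0f0)
d_u, d_p, d_t = back(rand(4))
typeof(d_u) # NTuple, although a Vector was expected
```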
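A Flux-free reduction may help confirm that the tuple comes from Zygote's handling of splatting rather than from `destructure`. This is a minimal sketch, assuming only Zygote is installed; the function name `f_splat`, the input `u`, and the use of `1.0` in place of `1f1` are made up for illustration.

```julia
using Zygote

# Hypothetical reduction of the MWE: splat `u` into an array literal, as dudt_ does.
f_splat(u) = sum(abs2, [u..., 1.0])

u = [1.0, 2.0]
g = Zygote.gradient(f_splat, u)[1]

println(typeof(g))   # with the bug present, this is a Tuple type rather than a Vector
println(collect(g))  # the values are 2 .* u either way, i.e. [2.0, 4.0]
```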

`d_u` should be a vector, like `u0`, but instead it comes back as a tuple. This goes away if the function uses `vcat` syntax instead:

```julia
function dudt_(u, p, t)
    st(p)([u; 1f1])  # vcat syntax instead of splatting
end
```
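The two spellings can be compared side by side on a Flux-free reduction. This is a minimal sketch, again assuming only Zygote; `f_splat`, `f_vcat` and `as_vector` are hypothetical names. `as_vector` is one possible user-side stopgap that collects a tuple tangent back into a `Vector`; it is not part of Zygote.

```julia
using Zygote

f_splat(u) = sum(abs2, [u..., 1.0])  # splatting: lowers to Core._apply_iterate(iterate, Base.vect, u, (1.0,))
f_vcat(u)  = sum(abs2, [u; 1.0])     # vcat syntax: the form reported to avoid the tuple

# Hypothetical stopgap: turn a tuple tangent into a Vector, pass anything else through.
as_vector(x::Tuple) = collect(x)
as_vector(x) = x

u = [1.0, 2.0]
g_splat = as_vector(Zygote.gradient(f_splat, u)[1])
g_vcat  = Zygote.gradient(f_vcat, u)[1]

println(typeof(g_vcat))    # expected to be an AbstractVector, like u
println(g_splat ≈ g_vcat)  # true: both equal 2 .* u
```

The stopgap only fixes the container type at the call site; the underlying tuple tangent is still produced inside Zygote.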

This was isolated from https://github.com/SciML/SciMLSensitivity.jl/issues/1082

ChrisRackauckas, Sep 08 '24 05:09

Dup of #599, I think. There's a PR which could be revived.

mcabbott, Sep 09 '24 19:09