Zygote.jl
Change adjoint for _apply
Fixes #399. Proposed by @mohamed82008 on Slack.
This could do with a minimal test case so it's clear what it's fixing.
Done
I think some of the tests assumed the previous behavior, which is why they are failing. Disclaimer: as I said on Slack, I am not sure this "fix" is even correct, but it does fix the issue of splatting arrays.
Ok, I understand what's going on here now. This is a tricky case and we need to think carefully about what the right fix is.
The general problem is that we can splat non-tuple objects (arrays, sets, custom types, etc.), and these currently get a tuple gradient. This is potentially surprising and doesn't give a good error message if it goes wrong.
julia> gradient(x -> +(x...), [1, 2, 3])[1]
(1, 1, 1)
There are a few different options here:
- Define a way for custom types to participate in `_empty` and `_unapply` (similar to what you have here, but user-extensible and without closing over inputs). This has the downside that any user-defined iterable type will break if splatted, unless people hook into a somewhat obscure Zygote API.
- Define tuples as valid gradients for iterables, in which case the fix is closer to #404. Still potentially surprising, especially in the REPL, but possibly more natural to support.
- Rewrite `Core._apply(f, xs)` to `Core._apply(f, collect_tuple(xs))`. The idea would be to differentiate through the `collect_tuple` call, which uses iteration etc. as normal, meaning it will always produce a gradient of the correct type. This has the least fragile behaviour but potentially adds overhead (though possibly only in cases that would be slow anyway; it's pretty easy to make the tuple case fast).
PoC for the last option, which I'm leaning towards:
julia> function collect_tuple(xs, state...)
         next = iterate(xs, state...)
         next === nothing && return ()
         return (next[1], collect_tuple(xs, next[2])...)
       end
julia> gradient(x -> +(collect_tuple(x)...), [1, 2, 3])[1]
3-element Array{Int64,1}:
1
1
1
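To make it clearer why the PoC above should work for any iterable, here is a minimal sketch in plain Julia (no Zygote involved) that runs the same `collect_tuple` on an array, a range, and a custom type. The `Countdown` type is hypothetical and exists only for illustration. This checks only that the forward pass goes through the ordinary `iterate` protocol; it says nothing about what gradient Zygote would produce.

```julia
# Same definition as the PoC: walk the iteration protocol recursively
# and gather the elements into a tuple.
function collect_tuple(xs, state...)
    next = iterate(xs, state...)
    next === nothing && return ()
    return (next[1], collect_tuple(xs, next[2])...)
end

# Hypothetical user-defined iterable (illustration only): counts n, n-1, ..., 1.
struct Countdown
    n::Int
end
Base.iterate(c::Countdown, i = c.n) = i < 1 ? nothing : (i, i - 1)

from_array  = collect_tuple([1, 2, 3])      # Vector input
from_range  = collect_tuple(1:3)            # range input
from_custom = collect_tuple(Countdown(3))   # custom iterable input

# Splatting the collected tuple gives the same result as splatting the iterable.
sum_direct    = +(Countdown(3)...)
sum_collected = +(collect_tuple(Countdown(3))...)
```

Because the recursion builds the tuple one element at a time, this form is only reasonable for short iterables, which matches the overhead caveat in the third option.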
Should this be revived in order to fix https://github.com/FluxML/Zygote.jl/issues/599?