Zygote.jl

Change adjoint for _apply

Open AzamatB opened this issue 5 years ago • 5 comments

Fixes #399. Proposed by @mohamed82008 on Slack.

AzamatB, Jan 30 '20 13:01

This could do with a minimal test case so it's clear what it's fixing.

MikeInnes, Jan 30 '20 14:01

Done

AzamatB, Jan 30 '20 16:01

I think some of the tests assumed the previous behavior, which is why they are failing. Disclaimer: as I said on Slack, I am not sure this "fix" is even correct, but it does fix the issue of splatting arrays.

mohdibntarek, Jan 31 '20 20:01

Ok, I understand what's going on here now. This is a tricky case and we need to think carefully about what the right fix is.

The general problem is that we can splat non-tuple objects (arrays, sets, custom types, etc.), and they currently get a tuple gradient. This is potentially surprising, and it doesn't give a good error message when something goes wrong.

julia> gradient(x -> +(x...), [1, 2, 3])[1]
(1, 1, 1)

There are a few different options here:

  1. Define a way for custom types to participate in _empty and _unapply (similar to what you have here, but user-extensible and without closing over inputs). This has the downside that any user-defined iterable type will break if splatted, unless people hook into a somewhat obscure Zygote API.
  2. Define tuples as valid gradients for iterables, in which case the fix is closer to #404. Still potentially surprising, especially in the REPL, but possibly more natural to support.
  3. Rewrite Core._apply(f, xs) to Core._apply(f, collect_tuple(xs)). The idea would be to differentiate through the collect_tuple call, which uses iteration etc. as normal, meaning it will always produce a gradient of the correct type. This has the least fragile behaviour but potentially adds overhead (though possibly only in cases that would be slow anyway; it's pretty easy to make the tuple case fast).

PoC for the last option, which I'm leaning towards:

julia> function collect_tuple(xs, state...)
         next = iterate(xs, state...)
         next === nothing && return ()
         return (next[1], collect_tuple(xs, next[2])...)
       end

julia> gradient(x -> +(collect_tuple(x)...), [1, 2, 3])[1]
3-element Array{Int64,1}:
 1
 1
 1
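
For reference, the recursion in collect_tuple relies only on Base's iterate protocol, so the forward behaviour can be checked in plain Julia without Zygote. A minimal sketch, with sample inputs chosen purely for illustration; it says nothing about what gradient each type would get:

```julia
# Same definition as the PoC above: walk the iterator and build a tuple.
function collect_tuple(xs, state...)
    next = iterate(xs, state...)      # iterate(xs) on the first call, iterate(xs, st) afterwards
    next === nothing && return ()     # iterator exhausted: end the tuple
    return (next[1], collect_tuple(xs, next[2])...)
end

collect_tuple(1:3)                    # (1, 2, 3)
collect_tuple("abc")                  # ('a', 'b', 'c')
collect_tuple(i^2 for i in 1:3)       # (1, 4, 9)
collect_tuple(Int[])                  # ()
```

Since this recurses once per element, it is probably only practical for short iterables.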

MikeInnes, Feb 04 '20 16:02

Should this be revived in order to fix https://github.com/FluxML/Zygote.jl/issues/599?

CarloLucibello, Jan 24 '23 06:01