Flux.jl
Make `outputsize` work with `Embedding`
Like https://github.com/FluxML/Flux.jl/pull/1656, this aims to make `outputsize(Embedding(3 => 4), (5,)) == (4, 5)`. That is, it treats the size passed to `outputsize` as the size of the array of vocabulary indices, not the size of the one-hot representation.
But rather than overloading indexing or `gather` (as done here: https://github.com/FluxML/Flux.jl/pull/1656/files#diff-0dfa3b94337acdaa714025f5198f6907e6a50a59aac03ba1230fbcb681126da2R172), this just adds methods to `(::Embedding)`. I think that's the least likely to cause surprises. If indexing shows up elsewhere, we can decide then whether to extend it.
Restricting the method to `(m::Embedding)(x::AbstractArray{<:Integer})` also seems like the right thing to do, so that non-integer input errors right away.
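To illustrate the shape convention and the integer restriction, here is a minimal sketch in plain Julia. `ToyEmbedding` is a hypothetical stand-in written for this example, not Flux's `Embedding` and not the code in this PR; it assumes the weight matrix has size `(embedding_dim, vocab_size)` and uses ordinary column indexing for the lookup.

```julia
# Hypothetical stand-in for an embedding layer; not Flux's implementation.
struct ToyEmbedding{W<:AbstractMatrix}
    weight::W   # assumed layout: (embedding_dim, vocab_size)
end

# A single index selects one column of the weight matrix.
(m::ToyEmbedding)(x::Integer) = m.weight[:, x]

# A vector of indices selects columns: result is (embedding_dim, length(x)).
(m::ToyEmbedding)(x::AbstractVector{<:Integer}) = m.weight[:, x]

# Higher-dimensional index arrays: flatten, look up, then restore trailing dims.
(m::ToyEmbedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)

m = ToyEmbedding(rand(Float32, 4, 3))   # analogous to Embedding(3 => 4)

size(m([1, 3, 2, 1, 3]))    # 5 indices in, one 4-element column per index out
size(m(rand(1:3, 5, 7)))    # index array of size (5, 7)

# No method accepts non-integer arrays, so this fails immediately:
# m([1.0, 2.0])   # MethodError
```

With the methods constrained to `<:Integer`, a float array never reaches the lookup; dispatch fails with a `MethodError` instead of an indexing error deeper in the call.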
Would landing this make https://github.com/FluxML/Flux.jl/pull/1656 obsolete?