Flux.jl

Make `outputsize` work with `Embedding`

Open · mcabbott opened this issue 3 years ago · 1 comment

Like https://github.com/FluxML/Flux.jl/pull/1656, this aims to make `outputsize(Embedding(3 => 4), (5,)) == (4, 5)`. That is, it treats the size passed to `outputsize` as the size of the array of vocabulary indices, not the size of the one-hot representation (which would be `(3, 5)`).
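To make the shape convention concrete, here is a minimal, self-contained Julia sketch. `ToyEmbedding` and `toy_outputsize` are hypothetical names, not Flux API. The toy layer stores an `(out, in)` weight matrix and looks up one column per index, which is assumed here to mirror how `Embedding(3 => 4)` behaves; the point is only that the input size is the index array's size.

```julia
# Hypothetical stand-in for Flux's `Embedding`, only to illustrate the shape convention.
struct ToyEmbedding{M<:AbstractMatrix}
    weight::M  # size (out, in): one column per vocabulary entry
end

# `ToyEmbedding(3 => 4)`: vocabulary of 3 entries, embedding dimension 4.
ToyEmbedding(p::Pair{<:Integer,<:Integer}) = ToyEmbedding(randn(Float32, p.second, p.first))

# An index array of size `sz` maps to an output of size `(out, sz...)`.
(m::ToyEmbedding)(x::AbstractArray{<:Integer}) =
    reshape(m.weight[:, vec(x)], size(m.weight, 1), size(x)...)

# Shape-only rule: the input size is the size of the index array.
toy_outputsize(m::ToyEmbedding, inputsize::Tuple) = (size(m.weight, 1), inputsize...)

m = ToyEmbedding(3 => 4)
x = [1, 3, 2, 2, 1]                       # 5 vocabulary indices
onehot = [k == i for k in 1:3, i in x]    # equivalent one-hot matrix, size (3, 5)

size(m(x))                 # (4, 5)
toy_outputsize(m, (5,))    # (4, 5): input size is (5,), not the one-hot size (3, 5)
m.weight * onehot ≈ m(x)   # true: same values as multiplying by the one-hot matrix
```

Higher-dimensional index arrays follow the same rule, e.g. a `2×2` matrix of indices gives an output of size `(4, 2, 2)`.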

But rather than overloading indexing or `gather` (as in https://github.com/FluxML/Flux.jl/pull/1656/files#diff-0dfa3b94337acdaa714025f5198f6907e6a50a59aac03ba1230fbcb681126da2R172), this just adds methods to `(::Embedding)`. I think that's the least likely to cause surprises. If indexing shows up in other layers, we can decide then whether to extend it.

Restricting the method to `(m::Embedding)(x::AbstractArray{<:Integer})` also seems like the right thing to do, so that non-integer input errors right away.
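A minimal sketch of the dispatch restriction, using a hypothetical `IntOnlyEmbedding` rather than Flux's own type. Because the only call method is defined for `AbstractArray{<:Integer}`, a `Float64` array is rejected at dispatch time with a `MethodError`, before any indexing happens.

```julia
# Hypothetical minimal layer illustrating only the dispatch restriction.
struct IntOnlyEmbedding{M<:AbstractMatrix}
    weight::M  # size (out, in)
end

# Only integer index arrays match; there is deliberately no method for other element types.
(m::IntOnlyEmbedding)(x::AbstractArray{<:Integer}) =
    reshape(m.weight[:, vec(x)], size(m.weight, 1), size(x)...)

m = IntOnlyEmbedding(randn(Float32, 4, 3))
y = m([1, 2, 3])   # works: a 4×3 matrix

# A Float64 array has no matching method, so the call fails immediately.
err = try
    m([1.0, 2.0, 3.0])
    nothing
catch e
    e
end
err isa MethodError   # true
```

One consequence of this signature: `Bool <: Integer` in Julia, so a `Bool` array would still match it.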

mcabbott · Oct 16 '22 14:10

Would landing this make https://github.com/FluxML/Flux.jl/pull/1656 obsolete?

ToucheSir · Jan 05 '23 04:01