
update Embedding layer

Open manikyabard opened this issue 3 years ago • 21 comments

Updates the Embedding layer to use gather for AbstractVector which earlier had an issue with repeated indices. Also adds special case for outputsize.
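For context, a minimal sketch of what the gather-based path looks like (illustrative only, with hypothetical names; not the exact diff):

using NNlib

weight = randn(Float32, 5, 26)   # one 5-dimensional embedding per token
x = [8, 5, 12, 12, 15]           # note the repeated index 12

# gather selects columns of weight; its gradient rule scatter-adds into
# the weight matrix, so repeated indices are handled correctly
emb = NNlib.gather(weight, x)    # 5×5 Matrix{Float32}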

manikyabard avatar Jul 09 '21 12:07 manikyabard

Needs tests

DhairyaLGandhi avatar Jul 09 '21 12:07 DhairyaLGandhi

ref. #1516

CarloLucibello avatar Jul 09 '21 13:07 CarloLucibello

I rebased #1516 on master since there were some conflicts, so I guess you should do git rebase -i origin/cl/embed here.

CarloLucibello avatar Jul 09 '21 16:07 CarloLucibello

Or maybe it is just easier if I merge #1516 and then this targets master

CarloLucibello avatar Jul 09 '21 16:07 CarloLucibello

Or maybe it is just easier if I merge #1516 and then this targets master

Yeah sure that works

manikyabard avatar Jul 09 '21 16:07 manikyabard

Maybe labels other than integers starting at 1 deserve some thought. Flux's one-hot encoding lets you specify the set of labels, but these are not stored; onecold returns integers starting at 1.

julia> Flux.onehotbatch(collect("hello"), 'a':'z') 
26×5 Flux.OneHotArray{26,2,Vector{UInt32}}:
 0  0  0  0  0
 0  0  0  0  0
 0  0  0  0  0
 0  0  0  0  0
 0  1  0  0  0
 0  0  0  0  0
 0  0  0  0  0
 1  0  0  0  0
 0  0  0  0  0
 0  0  0  0  0
 ⋮           

julia> Flux.onecold(ans) # does not remember
5-element Vector{Int64}:
  8
  5
 12
 12
 15

Should this do something similar? Then m = Embedding(0:9 => 5) would have to store something, so that m(0) can give the first vector. Or is this an unnecessary complication, better handled by composing onehot and Embedding?

mcabbott avatar Jul 10 '21 22:07 mcabbott

Mapping to an index or one-hot space is a standard data transformation when using a layer like Embedding, so I would say it's not a necessary feature for the layer itself. And we would always need to support 1:N, which could get tricky if the "labels" are also an integer range.

If we did want to include it, I would store a "transform" function within the layer.

PS: we do support onecold(x, labels)
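
For example (a minimal sketch):

using Flux

y = Flux.onecold(Flux.onehotbatch(collect("hello"), 'a':'z'), 'a':'z')
# y == ['h', 'e', 'l', 'l', 'o'], recovering the original labels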

darsnack avatar Jul 10 '21 23:07 darsnack

Ok, indeed this should probably be written as something like Chain(onehot('a':'z'), Embedding(26 => 5)) rather than duplicating that logic in the layer.
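
A hedged sketch of that composition (onehot('a':'z') is not an existing layer, so a closure over Flux.onehotbatch stands in for it here; Embedding(26 => 5) assumes the Pair constructor):

using Flux

m = Chain(x -> Flux.onehotbatch(x, 'a':'z'), Flux.Embedding(26 => 5))
m(collect("hello"))   # 5×5 matrix: one 5-dimensional embedding per character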

mcabbott avatar Jul 11 '21 00:07 mcabbott

Should this be re-targeted for master?

darsnack avatar Jul 14 '21 13:07 darsnack

Yes. Filing a new PR is probably easier.

CarloLucibello avatar Jul 14 '21 13:07 CarloLucibello

I changed the base branch to master.

manikyabard avatar Jul 14 '21 13:07 manikyabard

Looks like you need a rebase too

darsnack avatar Jul 14 '21 13:07 darsnack

@manikyabard are you still up for rebasing this and moving forward? We ought to close up this loose end.

darsnack avatar Jan 27 '22 21:01 darsnack

@manikyabard are you still up for rebasing this and moving forward? We ought to close up this loose end.

Yeah, I can continue working on this, although I'm not sure about the approach we should take for outputsize. Maybe we can discuss this further in the next community call.

manikyabard avatar Jan 28 '22 09:01 manikyabard

Maybe we can discuss this further in the next community call.

Yeah, sounds good!

darsnack avatar Jan 28 '22 13:01 darsnack

Summarizing what was discussed on the call:

  • the issue is that the AbstractVector{<:Integer} paths, which call getindex, don't make sense for outputsize, since a Vector{Nil} is not a vector of indices
  • Michael's PR making Nil a subtype of Real helps in the cases where Nil should only be hitting the AbstractArray{<:Real} paths
  • it still won't solve the issue when outputsize is used with a model that takes the AbstractVector{<:Integer} path
    • this would need an outputsize override rule for Embedding + AbstractVector{Nil}
    • (this wasn't brought up during the call, but I just thought of it) a better solution would be to define NNlib.gather for Nil, which would cover all indexing cases beyond just Embedding

darsnack avatar Feb 01 '22 18:02 darsnack

  • (this wasn't brought up during the call, but I just thought of it) a better solution would be to define NNlib.gather for Nil, which would cover all indexing cases beyond just Embedding

You mean something like this?

# fill the destination in place with nil (fill!, not fill, so gather! actually mutates dst)
NNlib.gather!(dst::AbstractArray, ::AbstractArray, ::AbstractArray{<:Nil}) = fill!(dst, nil)
(m::Embedding)(x::AbstractVector{<:Nil}) = NNlib.gather(m.weight, x)

manikyabard avatar Feb 01 '22 19:02 manikyabard

That looks right, but you won't need the special case for Embedding anymore; it should go through the NNlib.gather rule automatically.
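
A sketch of the intended end result, assuming the gather rule above is in place (Embedding's Pair constructor assumed; the exact constructor may differ across Flux versions):

using Flux

m = Flux.Embedding(26 => 5)
Flux.outputsize(m, (10,))   # expected to give (5, 10) without any real data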

darsnack avatar Feb 01 '22 19:02 darsnack

This needs a rebase on master.

CarloLucibello avatar Feb 13 '22 16:02 CarloLucibello

Codecov Report

Merging #1656 (de43bf5) into master (3cc9067) will decrease coverage by 0.07%. The diff coverage is 60.00%.


@@            Coverage Diff             @@
##           master    #1656      +/-   ##
==========================================
- Coverage   84.58%   84.51%   -0.08%     
==========================================
  Files          21       21              
  Lines        1486     1485       -1     
==========================================
- Hits         1257     1255       -2     
- Misses        229      230       +1     
Impacted Files        Coverage Δ
src/onehot.jl         95.29% <ø> (-0.06%) ↓
src/outputsize.jl     82.05% <0.00%> (-2.16%) ↓
src/layers/basic.jl   80.99% <75.00%> (-0.16%) ↓

Continue to review the full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 3cc9067...de43bf5.

codecov-commenter avatar Feb 13 '22 17:02 codecov-commenter

Can you add a test for outputsize of Embedding?
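
Something like this might do (a sketch, assuming the gather-for-Nil rule discussed above):

using Flux, Test

m = Flux.Embedding(26 => 5)
@test Flux.outputsize(m, (3,)) == (5, 3)
@test Flux.outputsize(m, (3, 7)) == (5, 3, 7)   # higher-rank integer indices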

darsnack avatar Feb 14 '22 23:02 darsnack