
[Feature Request] Support index padding in `scatter`/`gather`

Open chengchingwen opened this issue 3 years ago • 2 comments

It would be convenient to support index padding in `scatter` / `gather`, so that any index equal to a specific padding value (0, for example) is ignored.
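To make the intended semantics concrete, here is a minimal sketch in plain Julia of what "ignoring a padded index" could mean for a mean-scatter over matrix columns. The function name `padded_scatter_mean` and its `pad` keyword are hypothetical and are not part of NNlib; this is a reference for the behavior only, not a proposed implementation.

```julia
# Hypothetical reference implementation, not NNlib's API: a mean-scatter over
# the columns of `src` that skips every column whose index equals `pad`.
function padded_scatter_mean(src::AbstractMatrix, idx::AbstractVector{<:Integer},
                             n::Integer; pad::Integer = 0)
    size(src, 2) == length(idx) ||
        throw(DimensionMismatch("src needs one column per index"))
    dst = zeros(float(eltype(src)), size(src, 1), n)
    counts = zeros(Int, n)
    for (col, i) in enumerate(idx)
        i == pad && continue              # padded entry: contributes nothing
        dst[:, i] .+= view(src, :, col)   # accumulate the column into slot i
        counts[i] += 1
    end
    for i in 1:n
        counts[i] > 0 && (dst[:, i] ./= counts[i])  # untouched slots stay zero
    end
    return dst
end

src = [1.0 3.0 5.0 7.0;
       2.0 4.0 6.0 8.0]
idx = [1, 1, 2, 0]                 # the last column is padding
y = padded_scatter_mean(src, idx, 3)
```

Here the fourth column of `src` is dropped because its index is the padding value, and slot 3 stays zero because nothing is scattered into it.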

chengchingwen, Oct 02 '21 06:10

I'm not sure I understand. Can you give an example or define the expected behavior more precisely?

CarloLucibello, Oct 02 '21 07:10

Like this:

using NNlib, Statistics  # `mean` is exported by Statistics

# sample 1
emb1 = randn(10, 5);
idx1 = [1,2,3,3,4];
# sample 2
emb2 = randn(10, 7);
idx2 = [1,2,2,3,3,3,4];
# batched sample: sample 1 is padded to length 7 with zero columns and zero indices
emb = cat(hcat(emb1, zeros(10, 2)), emb2; dims=3);
idx = hcat(vcat(map(x->CartesianIndex(x, 1), idx1), zeros(CartesianIndex{2}, 2)), map(x->CartesianIndex(x, 2), idx2));

# requested feature: entries whose index equals the padding value are ignored
y = NNlib.scatter(mean, emb, idx; dstsize = (10, 7, 2))

# expected behavior ( == y)
cat(NNlib.scatter(mean, emb1, idx1; dstsize = (10, 7)), NNlib.scatter(mean, emb2, idx2; dstsize = (10, 7)); dims=3)
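Until something like this is supported, one possible workaround is to redirect the padded indices to an extra "dummy" slot and slice that slot off afterwards. The sketch below assumes the padding value is `CartesianIndex(0, 0)` and that the destination is three-dimensional, as in the example above; `scatter_ignoring_pad` is a hypothetical helper, not an NNlib function, and it only relies on the existing `NNlib.scatter(op, src, idx; dstsize)` call.

```julia
using NNlib, Statistics

# Hypothetical helper, not part of NNlib: emulate index padding by sending
# padded entries to an extra slot that is sliced off afterwards.
function scatter_ignoring_pad(op, src, idx::AbstractArray{CartesianIndex{2}},
                              dstsize::NTuple{3,Int};
                              pad = zero(CartesianIndex{2}))
    d, n, b = dstsize
    dummy = CartesianIndex(n + 1, 1)              # slot outside the real range
    routed = map(i -> i == pad ? dummy : i, idx)  # redirect padded entries
    y = NNlib.scatter(op, src, routed; dstsize = (d, n + 1, b))
    return y[:, 1:n, :]                           # drop the dummy slot
end

# Same batched data as in the example above.
emb1, idx1 = randn(10, 5), [1, 2, 3, 3, 4]
emb2, idx2 = randn(10, 7), [1, 2, 2, 3, 3, 3, 4]
emb = cat(hcat(emb1, zeros(10, 2)), emb2; dims = 3)
idx = hcat(vcat(CartesianIndex.(idx1, 1), zeros(CartesianIndex{2}, 2)),
           CartesianIndex.(idx2, 2))

y = scatter_ignoring_pad(mean, emb, idx, (10, 7, 2))

# Per-sample reference, built the same way as the expected behavior above.
expected = cat(NNlib.scatter(mean, emb1, idx1; dstsize = (10, 7)),
               NNlib.scatter(mean, emb2, idx2; dstsize = (10, 7)); dims = 3)
```

The cost of this approach is one extra slot along the scattered dimension plus a copy when slicing, so it is a stopgap rather than a substitute for native padding support.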

chengchingwen, Oct 02 '21 07:10