NNlib.jl
[Feature Request] Support index padding in `scatter`/`gather`
It would be convenient to support index padding in `scatter`/`gather`, i.e. to ignore every index equal to a specific padding value (`0`, for example).
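To make the `gather` half of the request concrete, here is a minimal plain-Julia sketch. It assumes that a padded index should produce a zero column in the output; the request does not specify the fill value. `gather_padded` is a hypothetical helper, not an NNlib function.

```julia
# Hypothetical helper, not part of NNlib: gather columns of `src`,
# leaving a zero column wherever `idx` holds the padding value `pad`.
function gather_padded(src::AbstractMatrix{T}, idx::AbstractVector{<:Integer};
                       pad::Integer = 0) where {T}
    dst = zeros(T, size(src, 1), length(idx))
    for (k, i) in enumerate(idx)
        i == pad && continue          # padded slot: keep the zero column
        dst[:, k] = @view src[:, i]   # regular slot: copy column i of src
    end
    return dst
end

src = [1.0 2.0 3.0;
       4.0 5.0 6.0]
gather_padded(src, [3, 0, 1])   # → [3.0 0.0 1.0; 6.0 0.0 4.0]
```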
I'm not sure I understand. Can you give an example or define the expected behavior more precisely?
Like this:
```julia
using NNlib, Statistics  # Statistics provides `mean`

# sample 1
emb1 = randn(10, 5);
idx1 = [1, 2, 3, 3, 4];
# sample 2
emb2 = randn(10, 7);
idx2 = [1, 2, 2, 3, 3, 3, 4];
# batched sample: sample 1 is padded with two zero columns and two zero indices
emb = cat(hcat(emb1, zeros(10, 2)), emb2; dims = 3);
idx = hcat(vcat(CartesianIndex.(idx1, 1), zeros(CartesianIndex{2}, 2)),  # padding: CartesianIndex(0, 0)
           CartesianIndex.(idx2, 2));
# requested feature: padded indices are ignored
y = NNlib.scatter(mean, emb, idx; dstsize = (10, 7, 2))
# expected result (should equal y)
cat(NNlib.scatter(mean, emb1, idx1; dstsize = (10, 7)),
    NNlib.scatter(mean, emb2, idx2; dstsize = (10, 7)); dims = 3)
```
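A minimal plain-Julia sketch of the expected `mean` behavior, assuming a single feature dimension and `CartesianIndex(0, 0)` as the padding value. `scatter_mean_padded` is hypothetical and not part of NNlib; destination slots that receive no contributions are left at zero.

```julia
# Hypothetical helper, not part of NNlib: scatter-mean that skips every
# index equal to `pad` (CartesianIndex(0, 0) by default).
# Assumes `src` has exactly one leading feature dimension.
function scatter_mean_padded(src::AbstractArray{T}, idx::AbstractArray{CartesianIndex{N}};
                             dstsize::NTuple, pad = zero(eltype(idx))) where {T,N}
    @assert ndims(src) == ndims(idx) + 1 "sketch assumes a single feature dimension"
    acc = zeros(T, dstsize)           # running sums per destination slot
    cnt = zeros(Int, dstsize[2:end])  # number of contributions per destination slot
    for k in CartesianIndices(idx)
        j = idx[k]
        j == pad && continue          # padded entry: excluded from sum and count
        acc[:, j] .+= @view src[:, k]
        cnt[j] += 1
    end
    for j in CartesianIndices(cnt)    # turn sums into means; empty slots stay zero
        cnt[j] > 0 && (acc[:, j] ./= cnt[j])
    end
    return acc
end

# Same batch layout as in the example above
emb1 = randn(10, 5); idx1 = [1, 2, 3, 3, 4]
emb2 = randn(10, 7); idx2 = [1, 2, 2, 3, 3, 3, 4]
emb = cat(hcat(emb1, zeros(10, 2)), emb2; dims = 3)
idx = hcat(vcat(CartesianIndex.(idx1, 1), zeros(CartesianIndex{2}, 2)),
           CartesianIndex.(idx2, 2))
y = scatter_mean_padded(emb, idx; dstsize = (10, 7, 2))
y[:, 3, 1] == (emb1[:, 3] + emb1[:, 4]) / 2   # true
```

Because padded entries are skipped before both the sum and the count, the zero columns appended to `emb1` do not affect sample 1's means.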