torchrec icon indicating copy to clipboard operation
torchrec copied to clipboard

EmbeddingCollection+KeyedJaggedTensor+vbe the inverse_indices don't work

Open yjjinjie opened this issue 1 year ago • 2 comments



import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig,EmbeddingConfig
from torchrec import EmbeddingBagCollection,EmbeddingCollection


kt = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0,0,0,0,2]),
    lengths=torch.tensor([1,1,1,1,0,1], dtype=torch.int64),
)


kt2 = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0,0,2]),
    lengths=torch.tensor([1,1,0,1], dtype=torch.int64),
    stride_per_key_per_rank=[[1], [3]],
    inverse_indices=(['t1', 't2'], torch.tensor([[0,0,0], [0,1,2]]))
)

eb_configs = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['t1']
    ),
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e2',
        feature_names=['t2']
    )
]

ebc = EmbeddingBagCollection(eb_configs)
print(ebc(kt)['t1'])
print(ebc(kt2)['t1'])



eb_configs = [
    EmbeddingConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['t1']
    ),
    EmbeddingConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e2',
        feature_names=['t2']
    )
]

ebc = EmbeddingCollection(eb_configs)


print(ebc(kt)["t1"].lengths().size())
print(ebc(kt2)["t1"].lengths().size())

结果: EmbeddingCollection 之后的结果没有根据inverse_indices 进行重新排列,长度为3,1

yjjinjie avatar Apr 18 '24 09:04 yjjinjie

ccn @joshuadeng

colin2328 avatar May 13 '24 23:05 colin2328

hi @yjjinjie, currently EmbeddingCollection does not support variable batch size per feature here. This work is being planned so stay tuned.

joshuadeng avatar May 14 '24 05:05 joshuadeng