torchrec
torchrec copied to clipboard
EmbeddingCollection+KeyedJaggedTensor+vbe the inverse_indices don't work
import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig,EmbeddingConfig
from torchrec import EmbeddingBagCollection,EmbeddingCollection
kt = KeyedJaggedTensor(
keys=['t1', 't2'],
values=torch.tensor([0,0,0,0,2]),
lengths=torch.tensor([1,1,1,1,0,1], dtype=torch.int64),
)
kt2 = KeyedJaggedTensor(
keys=['t1', 't2'],
values=torch.tensor([0,0,2]),
lengths=torch.tensor([1,1,0,1], dtype=torch.int64),
stride_per_key_per_rank=[[1], [3]],
inverse_indices=(['t1', 't2'], torch.tensor([[0,0,0], [0,1,2]]))
)
eb_configs = [
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=16,
name='e1',
feature_names=['t1']
),
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=16,
name='e2',
feature_names=['t2']
)
]
ebc = EmbeddingBagCollection(eb_configs)
print(ebc(kt)['t1'])
print(ebc(kt2)['t1'])
eb_configs = [
EmbeddingConfig(
num_embeddings=100,
embedding_dim=16,
name='e1',
feature_names=['t1']
),
EmbeddingConfig(
num_embeddings=100,
embedding_dim=16,
name='e2',
feature_names=['t2']
)
]
ebc = EmbeddingCollection(eb_configs)
print(ebc(kt)["t1"].lengths().size())
print(ebc(kt2)["t1"].lengths().size())
结果: EmbeddingCollection 之后的结果没有根据inverse_indices 进行重新排列,长度为3,1
ccn @joshuadeng
hi @yjjinjie, currently EmbeddingCollection does not support variable batch size per feature here. This work is being planned so stay tuned.