
How to share embeddings of some features between two EmbeddingBagCollections?

tiankongdeguiji opened this issue on Apr 23 '24 · 6 comments

In a two-tower retrieval model, it is essential to randomly sample negative items. Typically, this means that the batch size for the item tower is larger than that for the user tower. Consequently, a single EmbeddingBagCollection is inadequate for this setup. When employing two separate EmbeddingBagCollections, how can embeddings of some features be shared between them?

tiankongdeguiji · Apr 23 '24

Hi @henrylhtsang @IvanKobzarev @joshuadeng @PaulZhang12, could you take a look at this problem?

tiankongdeguiji · Apr 23 '24

  • [not recommended] use padding so they have the same length (see the sketch after this list)
  • use VBE in the KJT
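
As a concrete illustration of the padding option, here is a minimal sketch (not from the original discussion; the feature name 'user_f', table name 'e_user', and sizes are placeholders): the user-side KJT is padded from batch size 2 up to the item batch size 4 with empty bags, and the padded rows are dropped after the lookup. With the default sum pooling, an empty bag produces a zero vector.

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=16,
            name='e_user',
            feature_names=['user_f'],
        )
    ]
)

# Pad the user KJT from batch size 2 to 4 by appending empty bags (length 0).
kt_u_padded = KeyedJaggedTensor(
    keys=['user_f'],
    values=torch.tensor([0, 1]),
    lengths=torch.tensor([1, 1, 0, 0], dtype=torch.int64),
)

out = ebc(kt_u_padded).values()  # shape [4, 16]; rows 2-3 are zero padding
user_emb = out[:2]               # drop the padded rows before the user-tower MLP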

henrylhtsang · Apr 24 '24

>   • [not recommended] use padding so they have the same length
>   • use VBE in the KJT

@henrylhtsang If we use VBE, the output of the user tower and item tower will also be padded to the same batch size. Is this approach efficient?

tiankongdeguiji · Apr 25 '24

@tiankongdeguiji oh I misspoke. I meant to say you either need to pad the inputs, or use VBE. iirc you shouldn't need to pad the outputs of VBE, but admittedly I am not familiar with that part

henrylhtsang · Apr 25 '24

> @tiankongdeguiji oh I misspoke. I meant to say you either need to pad the inputs, or use VBE. iirc you shouldn't need to pad the outputs of VBE, but admittedly I am not familiar with that part

@henrylhtsang For example, suppose the batch_size of the user tower is 2 and the batch_size of the item tower is 4, and the user tower and item tower do not need to share embeddings. We could create:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# User tower input: batch size 2.
kt_u = KeyedJaggedTensor(
    keys=['user_f'],
    values=torch.tensor([0, 1]),
    lengths=torch.tensor([1, 1], dtype=torch.int64),
)
# Item tower input: batch size 4 (includes sampled negatives).
kt_i = KeyedJaggedTensor(
    keys=['item_f'],
    values=torch.tensor([2, 3, 4, 5]),
    lengths=torch.tensor([1, 1, 1, 1], dtype=torch.int64),
)
eb_config_u = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['user_f'],
    )
]
eb_config_i = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e2',
        feature_names=['item_f'],
    )
]
ebc_u = EmbeddingBagCollection(tables=eb_config_u)
ebc_i = EmbeddingBagCollection(tables=eb_config_i)
print('user:', ebc_u(kt_u).values().shape)
print('item:', ebc_i(kt_i).values().shape)
# output:
# user: torch.Size([2, 16])
# item: torch.Size([4, 16])

If we use VBE to implement embedding sharing, the outputs of the user tower and item tower are padded to the same batch_size.

# (same imports as the previous snippet)
kt = KeyedJaggedTensor(
    keys=['user_f', 'item_f'],
    values=torch.tensor([0, 1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int64),
    # variable batch sizes: 2 for 'user_f', 4 for 'item_f'
    stride_per_key_per_rank=[[2], [4]],
    # inverse indices expand each key's output back to the max batch size (4)
    inverse_indices=(['user_f', 'item_f'], torch.tensor([[0, 1, 1, 1], [0, 1, 2, 3]])),
)
eb_configs = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['user_f', 'item_f'],
    )
]
ebc = EmbeddingBagCollection(tables=eb_configs)
print('user+item:', ebc(kt).values().shape)
# output:
# user+item: torch.Size([4, 32])

The batch_size of the user tower output is 4 rather than 2.
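
For reference, a possible workaround under the setup above (a sketch, assuming the per-key outputs are expanded according to the inverse_indices constructed above): take the per-key tensors out of the returned KeyedTensor and drop the rows duplicated for 'user_f'.

d = ebc(kt).to_dict()    # dict of per-key pooled embeddings
user_out = d['user_f']   # [4, 16], rows expanded via inverse indices [0, 1, 1, 1]
item_out = d['item_f']   # [4, 16], rows follow inverse indices [0, 1, 2, 3]
user_out = user_out[:2]  # for this inverse-index pattern, rows 0-1 are the original user batch
print('user:', user_out.shape)  # expected: torch.Size([2, 16])
print('item:', item_out.shape)  # expected: torch.Size([4, 16])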

tiankongdeguiji · Apr 25 '24

not an expert, but can you try the sharded version of ebc? not sure if the unsharded ebc supports VBE very well
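
For anyone trying that suggestion, a rough, unverified sketch of wrapping the EBC from the VBE snippet above with DistributedModelParallel is below; the single-rank process-group setup, addresses, and port are placeholder assumptions for illustration only.

import os
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel

# single-process, single-GPU process group, for illustration only
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='nccl', rank=0, world_size=1)
device = torch.device('cuda:0')

# wrap the unsharded EBC from the VBE example above
sharded_ebc = DistributedModelParallel(module=ebc, device=device)

# forward with the same VBE KeyedJaggedTensor; the sharded module returns
# an awaitable that resolves to a KeyedTensor
out = sharded_ebc(kt.to(device)).wait()
print(out.values().shape)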

henrylhtsang · May 03 '24