Using scatter_reduce instead of scatter and max
Thank you for sharing your outstanding work!
Using `scatter_reduce` instead of `scatter` followed by `max` lets you allocate a tensor of shape `(bs, vocab_size)` directly, rather than the intermediate `(bs, length, vocab_size)` tensor, which reduces memory usage and makes a larger batch size possible. How about using `scatter_reduce`?
https://github.com/FlagOpen/FlagEmbedding/blob/fcdf889ec91edcd5278ba33a08c0665f4a59feb6/research/BGE_M3/modeling.py#L106
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
```python
# Accumulate token weights directly into a (bs, vocab_size) tensor,
# keeping the maximum weight when a token id occurs more than once.
sparse_embedding = torch.zeros(input_ids.size(0), self.vocab_size,
                               dtype=token_weights.dtype,
                               device=token_weights.device)
sparse_embedding = sparse_embedding.scatter_reduce(dim=-1, index=input_ids,
                                                   src=token_weights.squeeze(-1),
                                                   reduce='amax')
```
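For reference, here is a minimal self-contained sketch (toy shapes and random weights standing in for the model's ReLU outputs, so all weights are non-negative) checking that the two approaches agree, while the `scatter_reduce` path never allocates the `(bs, length, vocab_size)` intermediate:

```python
import torch

bs, length, vocab_size = 2, 8, 16
input_ids = torch.randint(0, vocab_size, (bs, length))
token_weights = torch.rand(bs, length, 1)  # non-negative, as after a ReLU

# Current approach: materializes a (bs, length, vocab_size) intermediate,
# then reduces over the length dimension with max.
dense = torch.zeros(bs, length, vocab_size)
dense = torch.scatter(dense, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
via_scatter_max = dense.max(dim=1).values

# Proposed approach: only a (bs, vocab_size) tensor is ever allocated.
sparse = torch.zeros(bs, vocab_size)
sparse = sparse.scatter_reduce(dim=-1, index=input_ids,
                               src=token_weights.squeeze(-1), reduce='amax')

print(torch.allclose(via_scatter_max, sparse))  # True
```

Note that both variants implicitly take a max against the zero-initialized buffer (`scatter_reduce` via its default `include_self=True`), so they coincide as long as the weights are non-negative, which holds for ReLU outputs.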
Hello, @lsrock1! Thanks for your insightful suggestion. I've modified the corresponding code in this PR: #1423