
self.offset=2 in Bart position_embedding

Open ShiyuNee opened this issue 1 year ago • 1 comment

System Info

I think the code in BartLearnedPositionalEmbedding is not the same as the code used for BART pretraining. It seems offset = 2 is only there to stay consistent with the pretrained checkpoints, and that offset = 2 has nothing to do with the ids used as positions.

Who can help?

No response

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

import torch
from torch import nn


class BartLearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` shape is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)  # [bsz, seq_len]
        print(f"positions: {positions}")
        # print(f"positions + self.offset: {positions + self.offset}")

        # e.g. encoder.embed_positions.weight has shape [1026, 768]: 1024 positions + offset of 2
        return super().forward(positions + self.offset)
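
For context, here is a minimal usage sketch of the class above (toy shapes, not a real checkpoint), showing which rows the lookup actually touches:

import torch

# Toy instantiation: 1024 positions, hidden size 768 -> weight table of shape [1026, 768]
embed_positions = BartLearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=768)

# Only the shape of input_ids matters here; the token values are never looked up
input_ids = torch.ones(2, 5, dtype=torch.long)   # [bsz=2, seq_len=5]
pos_embeds = embed_positions(input_ids)

# Internally the lookup indices are positions + offset = [2, 3, 4, 5, 6],
# so rows 0 and 1 of the table are never used for regular positions
print(pos_embeds.shape)                          # torch.Size([2, 5, 768])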

Expected behavior

There is no need for offset = 2, because the lookup is built from position indices, not from the input_ids themselves.

ShiyuNee avatar Mar 24 '23 13:03 ShiyuNee

cc @ArthurZucker

amyeroberts avatar Mar 24 '23 14:03 amyeroberts

Hey, I am not sure I understand your question. Could you clarify what you mean by "we just want to keep with the pretrained results and offset=2 is not relevant to ids in positions"?

ArthurZucker avatar Mar 27 '23 11:03 ArthurZucker

Here is an answer:

The reason why we are not using

>>> embed_tokens = nn.Embedding(vocab_dim, hidden_dim, padding_idx=padding_idx)

is that this makes the positional embedding at index padding_idx un-learnable and zeros it out.

What if you change the padding index to something bigger, say 4? Then the embedding at index 4 is zeroed out (basically erased), and the model never receives the embedding that should correspond to position 4 (which, with the offset, now lives at index 6). The offset prevents that. A small standalone illustration follows.
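
As a quick illustration of that point (plain PyTorch, not transformers code), padding_idx pins the corresponding row to zero and keeps it from being learned:

import torch
from torch import nn

# padding_idx=4 zeroes row 4 of the table and keeps it from ever being updated
emb = nn.Embedding(10, 3, padding_idx=4)
print(emb.weight[4])        # tensor([0., 0., 0.], ...)

out = emb(torch.tensor([4, 5]))
out.sum().backward()
print(emb.weight.grad[4])   # all zeros: the row at padding_idx never receives gradient
print(emb.weight.grad[5])   # non-zero: ordinary rows are learned as usual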

→ Potential usage: imagine you need a new starting token in your BartModel, so the padding token is no longer 2 but 3. That means you just want to shift the learned input positions by 1, not zero out the learned position embedding at index 3. The position embedding for 3 will simply appear as if it were 4.

Snippet:

# during training
>>> input_ids  = [  3, 13, 25,  1,  1,  1,  1]
>>> pad_token_id = 1
>>> positions  = [  0,  1,  2,  3,  4,  5,  6]
>>> pw_offset  = [  2,  3,  4,  5,  6,  7,  8]   # positions + offset
>>> embedding  = [ X2, X3, X4, X5, X6, X7, X8]   # Xk = learned row k of the position table

# finetuning with one more token (if the position table used padding_idx)
>>> new_pad_token_id = 4  # but the position of the padding token is not necessarily 2
>>> input_ids  = [  1,  2, 13, 25,  1,  1,  1,  1]
>>> positions  = [  0,  1,  2,  3,  4,  5,  6,  7]
>>> pw_offset  = [  2,  3,  4,  5,  6,  7,  8,  9]
>>> embedding  = [ X2, X3,  0, X5, X6, X7, X8, X9]   # row 4 is zeroed, so position 2 loses its embedding

# With the code fix:
# finetuning with one more token
>>> new_pad_token_id = 4  # but the position of the padding token is not necessarily 2
>>> input_ids  = [  1,  2, 13, 25,  1,  1,  1,  1]
>>> positions  = [  0,  1,  2,  3,  4,  5,  6,  7]
>>> pw_offset  = [  2,  3,  4,  5,  6,  7,  8,  9]
>>> embedding  = [ X2, X3, X4, X5, X6, X7, X8, X9]   # every position keeps a learned embedding
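
A runnable toy reconstruction of the two tables above (hypothetical sizes; Xk just means row k of the position table) could look like this:

import torch
from torch import nn

max_positions, hidden, offset, new_pad_token_id = 16, 4, 2, 4
positions = torch.arange(8).unsqueeze(0)               # [0, 1, ..., 7]

# If the position table were tied to the padding token id via padding_idx,
# row 4 would be zeroed, and position 2 (looked up at 2 + offset = 4) loses its embedding
tied = nn.Embedding(max_positions + offset, hidden, padding_idx=new_pad_token_id)
print(tied(positions + offset)[0, 2])                  # tensor([0., 0., 0., 0.]) -> the 0 above

# With the code fix (no padding_idx on the position table), every row 2..9 stays learned,
# whatever the pad token id is
fixed = nn.Embedding(max_positions + offset, hidden)
print(fixed(positions + offset)[0, 2])                 # a learned, non-zero embedding (X4 above)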

If you zero out the embedding row corresponding to the index of the padding token, changing the id of the padding token changes which inputs are affected: whatever sits at that position index loses its learned embedding.

The subtle difference is that, with the offset, it does not matter whether your padding token has id 0, 1, or 999.

The token that sits at a given position (say the 999th position) should not have a zeroed-out positional embedding. If the token at that position happens to be a padding token, it is the attention that should take that into account.

If we zero out at index 4, the 4th position will never have a learned positional embedding. Longer thread and more info in #19240.
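
A hedged sketch of that last point (toy values reusing the first table above): padding is handled by the attention mask built from the pad token id, not by zeroing positional rows:

import torch

pad_token_id = 1
input_ids = torch.tensor([[3, 13, 25, 1, 1, 1, 1]])

# The pad positions are excluded via the attention mask; their positional
# embeddings stay learned like any other position's
attention_mask = (input_ids != pad_token_id).long()
print(attention_mask)   # tensor([[1, 1, 1, 0, 0, 0, 0]])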

ArthurZucker avatar Apr 13 '23 14:04 ArthurZucker