SASRec.pytorch

Why is the attention_mask's shape (tl, tl)?

rabbicat30 opened this issue · 2 comments

tl = seqs.shape[1]  # time dimension length, used to enforce causality

# True above the diagonal marks the future positions each query must not attend to
attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev))

I can't understand why the attention_mask has this shape. Could you give me an explanation or some references? I would be very grateful for your help!

rabbicat30 · Feb 15 '23 12:02

You should look at the original Transformer paper and other blog posts (e.g., The Illustrated Transformer is great) for more information. The reason is that in self-attention we attend from a sequence to itself: every query position is scored against every key position of the same length-tl sequence, so the mask needs one entry per (query, key) pair, hence the square (tl, tl) shape.
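
To make that concrete, here is a minimal, runnable sketch. The sequence length tl = 5 and the small MultiheadAttention layer are illustrative placeholders only, not SASRec's actual configuration:

```python
import torch

tl = 5  # hypothetical sequence length, for illustration only

# True above the diagonal marks the future (query, key) pairs to block.
causal_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool))

# Row i answers "which of the tl key positions may query position i see?",
# so the mask needs one entry per (query, key) pair: shape (tl, tl).
print(causal_mask.shape)  # torch.Size([5, 5])

# The same square boolean mask can be passed as attn_mask to
# torch.nn.MultiheadAttention when query, key, and value are the
# same sequence (self-attention); True entries are masked out.
mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(1, tl, 8)  # (batch, seq_len, embed_dim)
out, _ = mha(x, x, x, attn_mask=causal_mask)
print(out.shape)  # torch.Size([1, 5, 8])
```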

seanswyi · Feb 17 '23 01:02

I understand now. Thanks very much!

rabbicat30 · Feb 21 '23 07:02