
Relative Positional Encoding

Open LarsHill opened this issue 3 years ago • 1 comment

Hi,

I have a quick question with respect to the relative shift operation:

    def _rel_shift(self, x, zero_triu=False):
        # x: (qlen, klen, bsz, n_head) -- prepend a zero column along the klen dim
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        # Reinterpret the memory layout: viewed back in the original shape,
        # this shifts row i left by (klen - 1 - i) positions.
        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        # Drop the padding row and restore the original shape.
        x = x_padded[1:].view_as(x)

        if zero_triu:
            # Optionally zero out the upper-right triangle explicitly.
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x
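To make the question concrete, here is a minimal sketch of the shift on a tiny 2-D tensor (batch and head dims dropped for readability; `rel_shift` is my simplified rewrite of the function above, not the repo's exact code). After the shift, the entries above the diagonal are leftover pad zeros or wrapped values from neighboring rows, i.e. junk:

```python
import torch

def rel_shift(x):
    # Prepend a zero column, then reinterpret the memory layout so that
    # row i ends up shifted left by (klen - 1 - i) positions.
    zero_pad = torch.zeros(x.size(0), 1, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)          # (qlen, klen + 1)
    x_padded = x_padded.view(x.size(1) + 1, x.size(0))  # (klen + 1, qlen)
    return x_padded[1:].view_as(x)                      # drop the pad row

qlen = klen = 3
x = torch.arange(1., 10.).view(qlen, klen)
shifted = rel_shift(x)
print(shifted)
# tensor([[3., 0., 4.],
#         [5., 6., 0.],
#         [7., 8., 9.]])
```

The lower-left triangle (including the diagonal) holds the correctly aligned relative-position terms; the upper-right triangle (`0.`, `4.`, `0.`) is exactly the region `zero_triu` would zero out.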

In the Transformer-XL paper, Appendix B (https://arxiv.org/pdf/1901.02860.pdf), the upper-right triangle of matrix B consists of zeros. In the code above, and throughout the model implementation, `zero_triu == False`, so after the relative shift the upper-right triangle is not filled with zeros as described in the paper. The Hugging Face implementation of this function drops the unused parameter entirely (see https://github.com/huggingface/transformers/blob/master/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L275).

Is the upper-right triangle masked later on regardless, or why can `zero_triu` be neglected?

LarsHill avatar Sep 08 '21 11:09 LarsHill

I was confused about this too, but I think the attn_mask functionality applied here and initialized here does the job.
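A hedged sketch of why that masking makes `zero_triu` redundant (this is not the repo's exact code, just the standard masked-softmax pattern): the causal mask sets every upper-triangle score to `-inf` before softmax, so whatever junk the relative shift left there receives exactly zero attention weight.

```python
import torch

qlen = 3
# 99. stands in for the junk values the relative shift leaves
# in the upper-right triangle.
scores = torch.tensor([[1., 99., 99.],
                       [1.,  1., 99.],
                       [1.,  1.,  1.]])

# Causal mask: True above the diagonal (future positions).
attn_mask = torch.triu(torch.ones(qlen, qlen, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(attn_mask, float('-inf'))
probs = torch.softmax(scores, dim=-1)
print(probs)
# tensor([[1.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000],
#         [0.3333, 0.3333, 0.3333]])
```

Since the masked positions end up with weight exactly 0 no matter what value they held, zeroing the triangle beforehand would be a no-op.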

duvallj avatar Nov 08 '21 14:11 duvallj