open_clip

build_cls_mask() in CoCa TextTransformer

Open · yiren-jian opened this issue 1 year ago · 2 comments

TL;DR: the current implementation of build_cls_mask() builds a cls_mask that assumes [CLS] is the first token, but in CoCa [CLS] is the last token.

In Issue 312, @gpucce introduced build_cls_mask() in the CoCa TextTransformer for "preventing the CLS token at the end of the sequence from attending to padded tokens".

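# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py#L587
def build_cls_mask(self, text, cast_dtype: torch.dtype):
    # True where the key is a real (non-pad) token: shape (B, 1, L)
    cls_mask = (text != self.pad_id).unsqueeze(1)
    # Pad one always-allowed column on the LEFT and L all-allowed rows on top: (B, L+1, L+1)
    cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
    additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
    additive_mask.fill_(0)
    # 0 where attention is allowed, -inf where it is blocked
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    # One copy per attention head: (B * heads, L+1, L+1)
    additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
    return additive_mask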
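To see the shapes involved, here is a standalone copy of the method. It assumes `pad_id` and `heads` are passed in as arguments rather than read from `self` (this free-function signature is hypothetical), and it uses `torch.zeros` in place of `torch.empty` plus `fill_(0)`, which is equivalent. For B sequences of length L it returns a float mask of shape (B * heads, L + 1, L + 1), the per-head 3D layout that `torch.nn.MultiheadAttention` accepts as `attn_mask`.

```python
import torch
import torch.nn.functional as F


def build_cls_mask(text: torch.Tensor, pad_id: int, heads: int,
                   cast_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Hypothetical standalone copy of the open_clip method quoted above;
    # pad_id and heads are explicit arguments instead of attributes of self.
    cls_mask = (text != pad_id).unsqueeze(1)                              # (B, 1, L)
    cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)  # (B, L+1, L+1)
    additive_mask = torch.zeros(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    # One copy per attention head, grouped by sample: (B * heads, L+1, L+1)
    return torch.repeat_interleave(additive_mask, heads, 0)


text = torch.tensor([[1, 2, 3, 4, 0, 0, 0],
                     [5, 6, 0, 0, 0, 0, 0]])
mask = build_cls_mask(text, pad_id=0, heads=8)
print(mask.shape)  # torch.Size([16, 8, 8])
```

Only the last row (the [CLS] query) is ever masked; every other row is all zeros, so all other queries are left to the causal mask.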

Taking text = torch.tensor([[1,2,3,4,0,0,0]]) as an example:

import torch
import torch.nn.functional as F

text = torch.tensor([[1,2,3,4,0,0,0]])  # batch size 1: 4 real tokens followed by 3 padding tokens (pad_id=0)

pad_id = 0
cls_mask = (text != pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.empty(cls_mask.shape)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
print(additive_mask)

This outputs:

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., -inf, -inf, -inf]]])
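To make the misalignment concrete, the sketch below labels each key position of the [CLS] row. It assumes the key layout implied by the `torch.cat([x, cls_emb], dim=1)` line in TextTransformer quoted later in this issue: keys 0 to L-1 are the text tokens and key L is [CLS] itself. Under that layout the left-padded mask lets [CLS] attend to one pad token (`tok4`) while blocking [CLS] from attending to itself.

```python
import torch
import torch.nn.functional as F

pad_id = 0
text = torch.tensor([[1, 2, 3, 4, 0, 0, 0]])
seq = text.shape[1]

# open_clip-style mask: the extra always-allowed column is padded on the LEFT
cls_mask = (text != pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, seq, 0), value=1.0)

# Assumed key layout ([CLS] appended at the end): keys 0..seq-1 are text tokens, key seq is [CLS]
key_labels = [f"tok{j}({text[0, j].item()})" for j in range(seq)] + ["CLS"]
cls_row = cls_mask[0, -1]  # mask row used by the [CLS] query

allowed = [key_labels[j] for j in range(seq + 1) if cls_row[j]]
blocked = [key_labels[j] for j in range(seq + 1) if not cls_row[j]]
print("allowed:", allowed)  # allowed: ['tok0(1)', 'tok1(2)', 'tok2(3)', 'tok3(4)', 'tok4(0)']
print("blocked:", blocked)  # blocked: ['tok5(0)', 'tok6(0)', 'CLS']
```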

In @lucidrains' implementation:

# https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py#L384-L385
cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

Taking the same text as the example:

import torch
import torch.nn.functional as F
from einops import rearrange

text = torch.tensor([[1,2,3,4,0,0,0]])

pad_id = 0
seq = text.shape[1]

cls_mask = rearrange(text!=pad_id, 'b j -> b 1 j')
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
print(attn_mask)

It produces the following, which I believe is the desired outcome:

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False,  True]]])
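As a quick check of that claim, the sketch below confirms that with right-side padding the [CLS] row equals "non-pad text tokens, followed by [CLS] itself", which is exactly what an end-positioned [CLS] needs. It uses `unsqueeze(1)` in place of einops `rearrange(..., 'b j -> b 1 j')`, which is equivalent here, and a second example sequence that is not from the original issue.

```python
import torch
import torch.nn.functional as F

pad_id = 0
text = torch.tensor([[1, 2, 3, 4, 0, 0, 0],
                     [7, 8, 0, 0, 0, 0, 0]])
seq = text.shape[1]

# lucidrains-style padding: the extra always-allowed column goes on the RIGHT
attn_mask = F.pad((text != pad_id).unsqueeze(1), (0, 1, seq, 0), value=True)

# Expected [CLS] row when [CLS] is the last key: keep non-pad tokens, plus [CLS] itself
keep = text != pad_id
expected_cls_row = torch.cat([keep, torch.ones(text.shape[0], 1, dtype=torch.bool)], dim=1)

print(torch.equal(attn_mask[:, -1], expected_cls_row))  # True
# Every row except the [CLS] row is left fully unmasked by this helper
print(bool(attn_mask[:, :-1].all()))  # True
```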

The [CLS] token is appended at the end of the sequence:

# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py#L607
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
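A minimal sketch of what an end-position variant could look like, assuming the only change needed is moving the always-allowed column from the left to the right, as in @lucidrains' code. `build_cls_mask_end` is a hypothetical name, the causal mask is built the usual way (-inf strictly above the diagonal), and this is an illustration rather than code taken from any open_clip patch. Combined with the causal mask, the softmax gives the [CLS] query exactly zero weight on the pad keys and a nonzero weight on itself.

```python
import torch
import torch.nn.functional as F


def build_cls_mask_end(text: torch.Tensor, pad_id: int, heads: int,
                       cast_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical variant of build_cls_mask for a [CLS] token appended at the END.

    Only the F.pad arguments differ from the method quoted above: (0, 1, ...) appends
    the always-allowed column on the right, where the [CLS] key sits.
    """
    cls_mask = (text != pad_id).unsqueeze(1)                              # (B, 1, L)
    cls_mask = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=True)  # (B, L+1, L+1)
    additive_mask = torch.zeros(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    return torch.repeat_interleave(additive_mask, heads, 0)               # (B*heads, L+1, L+1)


pad_id, heads = 0, 2
text = torch.tensor([[1, 2, 3, 4, 0, 0, 0]])
L = text.shape[1] + 1  # sequence length after appending [CLS]

# Causal mask: -inf strictly above the diagonal, 0 elsewhere
causal = torch.full((L, L), float("-inf")).triu_(1)
attn_mask = causal[None] + build_cls_mask_end(text, pad_id, heads)

# Random attention scores; softmax over the key dimension gives attention weights
scores = torch.randn(heads, L, L)
weights = torch.softmax(scores + attn_mask, dim=-1)

cls_weights = weights[:, -1]   # attention of the [CLS] query over all keys
print(cls_weights[:, 4:7])     # weights on the three pad keys: all zeros
print(cls_weights[:, -1] > 0)  # [CLS] attends to itself in every head
```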

So I believe the current implementation in open_clip is wrong. Am I missing anything?

yiren-jian, Jun 24 '23 04:06

Yes, I also think the current implementation is wrong. Can someone give an update on this?

Mypathissional, Sep 20 '23 17:09

Hi, there is a PR (#551) to fix this, but I think nobody has had time to review it.

gpucce, Sep 20 '23 17:09