build_cls_mask() in CoCa TextTransformer
TL;DR: the current implementation of build_cls_mask()
produces a cls_mask
as if [CLS] were the first token, but in CoCa, [CLS] is the token at the end of the sequence.
In Issue 312, build_cls_mask()
was introduced by @gpucce in the CoCa TextTransformer
for "preventing the CLS token at the end of the sequence from attending to padded tokens".
# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py#L587
def build_cls_mask(self, text, cast_dtype: torch.dtype):
    cls_mask = (text != self.pad_id).unsqueeze(1)
    cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
    additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
    additive_mask.fill_(0)
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
    return additive_mask
Taking text = torch.tensor([[1,2,3,4,0,0,0]]) as an example:
import torch
import torch.nn.functional as F
text = torch.tensor([[1,2,3,4,0,0,0]])  ### batch size 1, 4 tokens followed by 3 padding tokens (pad_id=0)
pad_id = 0
cls_mask = (text != pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.empty(cls_mask.shape)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
print(additive_mask)
This outputs:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., -inf, -inf, -inf]]])
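Note that the -inf entries in the last row land on key positions 5-7 rather than on the padding positions 4-6, so the [CLS] query can still attend to the padding token at position 4 and is blocked from attending to itself at position 7. F.pad pads dimensions starting from the last one, as (left, right, top, bottom) pairs, so the (1, 0, ...) call prepends the extra column on the left and shifts every token column one position to the right. A small snippet (reusing text and pad_id from above) illustrating the two padding directions:
# F.pad pads dims from the last one backwards: (left, right, top, bottom)
print(F.pad((text != pad_id).unsqueeze(1), (1, 0), value=True))
# -> [[[True, True, True, True, True, False, False, False]]]  (token columns shifted right)
print(F.pad((text != pad_id).unsqueeze(1), (0, 1), value=True))
# -> [[[True, True, True, True, False, False, False, True]]]  (token columns stay aligned)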
In @lucidrains' implementation,
# https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py#L384-L385
cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
taking the same text as an example:
import torch
import torch.nn.functional as F
from einops import rearrange
text = torch.tensor([[1,2,3,4,0,0,0]])
pad_id = 0
seq = text.shape[1]
cls_mask = rearrange(text!=pad_id, 'b j -> b 1 j')
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
print(attn_mask)
it produces (which I believe is the desired outcome):
tensor([[[ True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, True, True],
[ True, True, True, True, False, False, False, True]]])
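For a direct comparison with the additive mask above, this boolean mask can be converted to the same 0 / -inf form (a small sketch of my own, not code from either repository):
additive = torch.zeros(attn_mask.shape).masked_fill(~attn_mask, float("-inf"))
print(additive[0, -1])
# -> tensor([0., 0., 0., 0., -inf, -inf, -inf, 0.])
Here the -inf entries fall exactly on the padding positions 4-6, and the [CLS] position 7 stays attendable.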
Since the [CLS] token is appended at the end of the sequence,
# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py#L607
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
I feel that the current implementation in open_clip
is wrong. Am I missing anything?
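For reference, here is a minimal sketch of what I would expect build_cls_mask() to look like, padding the extra [CLS] column on the right so it lines up with the token appended at the end of the sequence (this is only my own illustration, not necessarily the actual fix):
import torch
import torch.nn.functional as F

def build_cls_mask_fixed(text, pad_id=0, cast_dtype=torch.float32):
    # [CLS] is appended at the END of the sequence, so the extra
    # always-attendable column is appended on the right as well
    cls_mask = (text != pad_id).unsqueeze(1)
    cls_mask = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=True)
    additive_mask = torch.zeros(cls_mask.shape, dtype=cast_dtype)
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    return additive_mask  # repeat over heads omitted for brevity

print(build_cls_mask_fixed(torch.tensor([[1, 2, 3, 4, 0, 0, 0]])))
# last row -> [0., 0., 0., 0., -inf, -inf, -inf, 0.]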
Yes, I also feel that the current implementation is wrong. Can someone give an update on this?
Hi, there is PR #551 to fix this, but I think nobody has had time to review it.