incorrect causal mask in global attention
I am using GlobalAttention and got an abnormally low loss even with causal set to True. Upon inspection, I found that the causal mask is not applied at all.
At the line linked below, won't shape[1] always be 1 given the assert?
https://github.com/facebookresearch/xformers/blob/68b7fd14df5eb1d2558c52842b4206a14d2d20e9/xformers/components/attention/global_tokens.py#L73
After changing it to shape[0], the loss looks more reasonable.
This has been around for quite a while and no one seems to have reported it. Can anybody confirm?
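
For context, here is a minimal sketch of what I think is happening. causal_1d_pattern below is a local stand-in for the xformers helper (a lower-triangular boolean mask), and N is just an example sequence length:

```python
import torch


def causal_1d_pattern(seq_len: int) -> torch.Tensor:
    # Local stand-in for xformers' causal_1d_pattern helper:
    # a lower-triangular boolean mask of shape (seq_len, seq_len).
    return torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool)


N = 8  # example sequence length
# attention_query_mask is documented as having shape [N, 1]
attention_query_mask = torch.rand(N, 1) > 0.5

# Current code: shape[1] is always 1, so the "causal" pattern is a 1x1 tensor
buggy = causal_1d_pattern(attention_query_mask.shape[1])
print(buggy.shape)  # torch.Size([1, 1]) -- effectively no causal masking

# Proposed fix: shape[0] gives the full N x N lower-triangular causal mask
fixed = causal_1d_pattern(attention_query_mask.shape[0])
print(fixed.shape)  # torch.Size([8, 8])
```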
My understanding is that attention_query_mask has shape [N, 1].
You are right: causal_1d_pattern(attention_query_mask.shape[1]) will always be causal_1d_pattern(1), since the documentation says the shape is [N, 1].
So I think it is indeed a bug.
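
To make the effect concrete, here is a small illustration of why the 1x1 pattern disables causality entirely. I am assuming the causal pattern is combined with the attention mask by a broadcast boolean AND, which appears to be what the code around the linked line does:

```python
import torch

N = 4
attn_mask = torch.ones(N, N).to(torch.bool)  # e.g. the global-attention visibility pattern

# With causal_1d_pattern(1): a single True that broadcasts to a no-op when AND-ed,
# so every query still attends to every key, including future positions.
trivial = torch.tril(torch.ones(1, 1)).to(torch.bool)
print((attn_mask & trivial).all())  # tensor(True) -> causality not enforced

# With causal_1d_pattern(N): future positions are actually masked out.
proper = torch.tril(torch.ones(N, N)).to(torch.bool)
print(attn_mask & proper)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```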