Set different window sizes and dilations for different tokens
In many tasks, not all tokens have the same importance or require the same amount of computation, so it is common to route different tokens/windows so that less computation is spent on less important regions. NAT, which implements window attention centered on each token, could support more fine-grained routing. For example, less important tokens could use a smaller window size to reduce computation, even though this may prevent all windows from being processed in parallel. I think it would be valuable if NAT allowed users to control the window size and dilation for each token.
e.g.
```python
# x: [B, H, W, d], binary_mask: [B, H, W]
x = self.neighbourhood_attention(x, binary_mask)
```
For tokens where binary_mask is 1, use attention with a window size of 7 and dilation of 1. For tokens where binary_mask is 0, use attention with a window size of 0 and dilation of 2. The two window size/dilation pairs could be configured by the user when instantiating the NeighborhoodAttention2D class.
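A minimal sketch of the requested interface, purely hypothetical: `kernel_size_alt` and `dilation_alt` are invented parameter names, not part of NATTEN's actual API (and a nonzero size stands in for the 0 above):

```python
from natten import NeighborhoodAttention2D

# Hypothetical extension, not NATTEN's actual API: a second
# (kernel_size, dilation) pair, selected per token by a binary mask.
na = NeighborhoodAttention2D(
    dim=64,
    num_heads=4,
    kernel_size=7,       # used where binary_mask == 1
    dilation=1,
    kernel_size_alt=3,   # invented parameter (the request above says 0)
    dilation_alt=2,      # invented parameter
)
out = na(x, binary_mask)  # x: [B, H, W, 64], binary_mask: [B, H, W]
```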
This feature would likely be very slow if implemented. You might have better luck just running NA with the largest window size and masking the results, or simply using standard attention with masking.
I'm generally skeptical of very fine-grained attention patterns because their applications are relatively limited, in my opinion (though I'm curious if there's a different view). I would also expect masking to be the better solution here, in that it doesn't require a re-implementation and will likely be as fast as you can get.
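For reference, a minimal sketch of the masking route in plain PyTorch: single-head attention over the flattened H*W grid, with a per-query Chebyshev-distance mask emulating two square window sizes (7 and 3 here, both illustrative choices; dilation and NA's edge handling are ignored):

```python
import torch

def masked_neighborhood_attention(q, k, v, binary_mask, win_large=7, win_small=3):
    """q, k, v: [B, H, W, d]; binary_mask: [B, H, W] (1 -> win_large, 0 -> win_small)."""
    B, H, W, d = q.shape
    N = H * W
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1)       # [N, 2]
    # Chebyshev distance between all pairs of grid positions defines
    # a square window centered on each query token.
    dist = (coords[:, None, :] - coords[None, :, :]).abs().amax(-1)  # [N, N]
    m = binary_mask.reshape(B, N).bool()
    half = torch.where(m, torch.tensor(win_large // 2),
                          torch.tensor(win_small // 2))              # [B, N]
    keep = dist[None] <= half[:, :, None]                            # [B, N, N]
    q, k, v = (t.reshape(B, N, d) for t in (q, k, v))
    scores = (q @ k.transpose(-1, -2)) / d ** 0.5                    # [B, N, N]
    scores = scores.masked_fill(~keep, float("-inf"))
    return (torch.softmax(scores, dim=-1) @ v).reshape(B, H, W, d)
```

This pays the full O((H*W)^2) cost of standard attention, which is the point: correctness comes from masking rather than from a specialized kernel.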
I'm also curious what you mean by window size 0? Window size must be greater than 1 (window size 1 means every token attends only to itself, which is a linear projection and no longer attention). I'm also confused about how window size and dilation are inferred from a binary mask, but I could be missing something there.
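To make the window-size-1 point concrete, a standalone check (illustrative, not NATTEN code): with a single key, the softmax over one score is exactly 1, so the attention output reduces to the value projection:

```python
import torch

d = 8
x = torch.randn(1, d)
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
q, k, v = x @ Wq, x @ Wk, x @ Wv
score = (q @ k.T) / d ** 0.5            # [1, 1]: each token sees only itself
out = torch.softmax(score, dim=-1) @ v  # softmax of a single score is 1
assert torch.allclose(out, v)           # output == linear projection W_v x
```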
Closing due to inactivity. Feel free to reopen if you still have questions.