axial-attention
Extracting attention maps
Hi there,
Excellent project!
I'm using axial-attention with video of shape (1, 5, 128, 256, 256) and sum_axial_out=True, and I wish to visualise the attention maps. Essentially, given my video and two frame indices frame_a_idx and frame_b_idx, I need to extract the attention map over frame_b to a chosen pixel (x, y) in frame_a (after the axial sum).
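For reference, the layer is set up along these lines (a rough sketch for my video; dim_index = 2 points at the channel dimension, and the head count is illustrative):

import torch
from axial_attention import AxialAttention

attn = AxialAttention(
    dim = 128,             # channel dimension of the video
    dim_index = 2,         # channels sit at index 2 of (batch, frames, channels, height, width)
    num_dimensions = 3,    # frames, height, width
    heads = 8,             # illustrative
    sum_axial_out = True
)

video = torch.randn(1, 5, 128, 256, 256)
out = attn(video)          # (1, 5, 128, 256, 256)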
My understanding is that I should be able to reshape the dots (after the softmax) according to the permutations in calculate_permutations, then sum these permuted dots together to form a final attention score tensor of an accessible shape, ready for visualisation.
I am slightly stuck due to the numerous axial permutations and shape mismatches. What I am doing is as follows:
In SelfAttention.forward():
dots_reshaped = dots.reshape(b, h, t, t)  # split the heads back out of the merged batch
return out, dots_reshaped
In PermuteToFrom.forward():
# attention
axial, dots = self.fn(axial, **kwargs)
# restore to original shape and permutation
axial = axial.reshape(*shape)
axial = axial.permute(*self.inv_permutation).contiguous()
dots = dots.reshape(*shape[:3], *dots.shape[1:])  # (*leading permuted dims, heads, t, t) for the video case
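For concreteness, I believe these are the permutations involved for my video layout (assuming dim_index = 2, and that calculate_permutations can be imported from axial_attention.axial_attention):

from axial_attention.axial_attention import calculate_permutations

# video is laid out as (batch, frames, channels, height, width), channels at index 2
print(calculate_permutations(3, 2))
# expected: [[0, 3, 4, 1, 2], [0, 1, 4, 3, 2], [0, 1, 3, 4, 2]]
# i.e. the frame, height and width axes are attended over in turn, so with the reshape above
# the three dots tensors come back as (b, h, w, heads, t, t), (b, t, w, heads, h, h) and
# (b, t, h, heads, w, w) - trailing axes of different sizes, hence the mismatch when summing.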
However, I am unsure of how to un-permute the dots appropriately such that all resulting “axes” (of different sizes) can be summed. If you have suggestions or code for doing so, it would be very much appreciated, thanks!
@vibrant-galaxy i'm not actually sure it will be all that interpretable as it is, since attention is done along each axis separately, and information can take up to two steps to be routed.
however, i think what may be worth trying (and I haven't built it into this repo yet) is to do axial attention and then expand the attention map of each axis along the other axes, then sum, softmax, and aggregate the values. perhaps it could lead to something more interpretable, as you would have the full attention map. would you be interested in trying this if i were to build it?
That sounds like a good approach to get the full map. Yes, I am very much interested in trying that!
I tried to do something like the below, but it actually goes out of memory when you try to expand and sum the pre-attention maps. Even for the toy 5 × 32 × 32 example below, each fused map is (batch · heads) × 5 × 32 × 32 × 5 × 32 × 32 ≈ 2 × 10⁸ elements, i.e. close to a gigabyte per tensor in fp32, and more than one of them has to be materialised.
So basically I don't think it's possible lol, unless you see a way to make it work
import torch
from torch import einsum, nn
from einops import rearrange

class AxialAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads

        # separate key projections per axis, shared queries and values
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_height_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_width_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_frame_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        heads, b, f, c, h, w = self.heads, *x.shape
        x = rearrange(x, 'b f c h w -> b f h w c')

        q = self.to_q(x)
        k_height = self.to_height_k(x)
        k_width = self.to_width_k(x)
        k_frame = self.to_frame_k(x)
        v = self.to_v(x)

        # fold the heads into the batch dimension
        q, k_height, k_width, k_frame, v = map(
            lambda t: rearrange(t, 'b f x y (h d) -> (b h) f x y d', h = heads),
            (q, k_height, k_width, k_frame, v))

        q *= q.shape[-1] ** -0.5

        # per-axis similarity logits, each broadcast over the full (f, h, w) grid -
        # these expands are what blow up memory
        sim_frame = einsum('b f h w d, b j h w d -> b f h w j', q, k_frame)
        sim_frame = sim_frame[..., :, None, None].expand(-1, -1, -1, -1, -1, h, w)

        sim_height = einsum('b f h w d, b f k w d -> b f h w k', q, k_height)
        sim_height = sim_height[..., None, :, None].expand(-1, -1, -1, -1, f, -1, w)

        sim_width = einsum('b f h w d, b f h l d -> b f h w l', q, k_width)
        sim_width = sim_width[..., None, None, :].expand(-1, -1, -1, -1, f, h, -1)

        # fuse the three maps and softmax over every (frame, row, column) key position at once
        sim = rearrange(sim_frame + sim_height + sim_width, 'b f h w j k l -> b f h w (j k l)')
        attn = sim.softmax(dim = -1)
        attn = rearrange(attn, 'b f h w (j k l) -> b f h w j k l', j = f, k = h, l = w)

        # aggregate values with the full attention map
        out = einsum('b f h w j k l, b j k l d -> b f h w d', attn, v)
        out = rearrange(out, '(b h) f x y d -> b f x y (h d)', h = heads)
        out = self.to_out(out)
        out = rearrange(out, 'b f x y d -> b f d x y')
        return out, attn

layer = AxialAttention(dim = 16)
video = torch.randn(1, 5, 16, 32, 32)
out, attn = layer(video)   # attn: ((batch * heads), f, h, w, f, h, w)
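One possible way around the memory issue, as a rough sketch only (not something built into the repo): compute the fused frame/height/width attention for a single query pixel at a time, which is all that's needed to visualise attention from a chosen (frame, y, x), so nothing of size (f·h·w)² is ever materialised. It reuses the projection layers of the AxialAttention module defined above.

import torch
from torch import einsum
from einops import rearrange

def single_query_attention(layer, x, frame_idx, y_idx, x_idx):
    # x: (b, f, c, h, w) video, layer: the AxialAttention module sketched above
    heads = layer.heads
    b, f, c, h, w = x.shape
    x = rearrange(x, 'b f c h w -> b f h w c')

    q = layer.to_q(x)
    k_frame = layer.to_frame_k(x)
    k_height = layer.to_height_k(x)
    k_width = layer.to_width_k(x)

    q, k_frame, k_height, k_width = map(
        lambda t: rearrange(t, 'b f x y (h d) -> (b h) f x y d', h = heads),
        (q, k_frame, k_height, k_width))

    # single query vector at (frame_idx, y_idx, x_idx), per (batch * head)
    q_pix = q[:, frame_idx, y_idx, x_idx] * q.shape[-1] ** -0.5

    # logits of the chosen pixel against each axis it attends over
    sim_frame = einsum('b d, b j d -> b j', q_pix, k_frame[:, :, y_idx, x_idx])        # over frames
    sim_height = einsum('b d, b k d -> b k', q_pix, k_height[:, frame_idx, :, x_idx])  # over rows
    sim_width = einsum('b d, b l d -> b l', q_pix, k_width[:, frame_idx, y_idx, :])    # over columns

    # broadcast-sum into an (f, h, w) grid and softmax over all key positions at once,
    # matching the fused map above, but for this one query only
    sim = sim_frame[:, :, None, None] + sim_height[:, None, :, None] + sim_width[:, None, None, :]
    attn = sim.reshape(sim.shape[0], -1).softmax(dim = -1).reshape(-1, f, h, w)
    return rearrange(attn, '(b h) f x y -> b h f x y', h = heads)

# e.g. attention from pixel (4, 7) of frame 1, using the layer and video above
attn_map = single_query_attention(layer, video, frame_idx = 1, y_idx = 4, x_idx = 7)  # (1, heads, 5, 32, 32)
frame_b_map = attn_map[0, :, 3]   # per-head map over frame index 3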