
Extracting attention maps

Open · vibrant-galaxy opened this issue 4 years ago · 3 comments

Hi there,

Excellent project!

I'm using axial-attention on video of shape (1, 5, 128, 256, 256) with sum_axial_out=True, and I wish to visualise the attention maps.

Essentially, given my video and two frame indices frame_a_idx and frame_b_idx, I need to extract the attention map over frame_b for a chosen query pixel (x, y) in frame_a (after the axial sum).

My understanding is that I should be able to reshape the dots (after softmax) according to the permutations in calculate_permutations, then sum these permuted dots together to form a final attention score tensor of an accessible shape, thus ready for visualisation.
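
For the frame axis on its own, for example, I picture something like this (just a rough sketch on my part; the names b, h, w, heads, f and the ordering of the collapsed dims are my assumptions about what SelfAttention returns for that axis):

# assumed: frame-axis dots come back as (b * h * w * heads, f, f),
# with the remaining dims in (b, h, w) order after the permutation
dots_frame = dots.reshape(b, h, w, heads, f, f)
# per-axis attention from pixel (x, y) in frame_a to the same pixel in frame_b
attn_pixel = dots_frame[0, y, x, :, frame_a_idx, frame_b_idx]  # (heads,)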

I am slightly stuck due to the numerous axial permutations and shape mismatches. What I am doing is as follows:

In SelfAttention.forward():

# dots (after softmax): (collapsed_batch * heads, t, t) -> (collapsed_batch, heads, t, t)
dots_reshaped = dots.reshape(b, h, t, t)
return out, dots_reshaped

In PermuteToFrom.forward():

# attention
axial, dots = self.fn(axial, **kwargs)

# restore to original shape and permutation
axial = axial.reshape(*shape)
axial = axial.permute(*self.inv_permutation).contiguous()

# prepend the collapsed leading dims back onto the attention map
dots = dots.reshape(*shape[:3], *dots.shape[1:])

However, I am unsure how to un-permute the dots so that the resulting per-axis maps, which have different shapes, can be summed. If you have suggestions or code for doing so, it would be very much appreciated, thanks!
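
To make the shape mismatch concrete, here is what I believe the three un-permuted maps look like for my video (the exact ordering of the leading dims, and heads = 8, are my assumptions):

# assumed per-axis shapes after the reshape above, for the (1, 5, 128, 256, 256) video
frame_dots_shape  = (1, 256, 256, 8, 5, 5)    # (b, h, w, heads, frames, frames)
height_dots_shape = (1, 5, 256, 8, 256, 256)  # (b, frames, w, heads, h, h)
width_dots_shape  = (1, 5, 256, 8, 256, 256)  # (b, frames, h, heads, w, w)
# the trailing (t, t) blocks index different axes, so the three maps cannot
# simply be summed without first broadcasting them to a common shape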

vibrant-galaxy avatar Jan 28 '21 12:01 vibrant-galaxy

@vibrant-galaxy i'm not actually sure it will be very interpretable as it is, since attention is done along each axis separately, and information can take up to two steps to be routed.

however, i think what may be worth trying (and I haven't built it into this repo yet) is to do axial attention, then expand the attention map of each axis along the other axes, sum them, softmax, and aggregate the values. perhaps it could lead to something more interpretable, as you would have the full attention map. would you be interested in trying this if i were to build it?

lucidrains avatar Jan 30 '21 05:01 lucidrains

That sounds like a good approach to get the full map. Yes, I am very much interested in trying that!

vibrant-galaxy avatar Jan 31 '21 02:01 vibrant-galaxy

I tried to do something like the below, but it actually goes out of memory when you try to expand and sum the pre-softmax attention maps.

So basically I don't think it's possible lol, unless you see a way to make it work.

import torch
from torch import einsum, nn
from einops import rearrange

class AxialAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        # separate key projections for the frame, height and width axes
        self.to_height_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_width_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_frame_k = nn.Linear(dim, inner_dim, bias = False)

        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        heads, b, f, c, h, w = self.heads, *x.shape

        x = rearrange(x, 'b f c h w -> b f h w c')

        q = self.to_q(x)

        k_height = self.to_height_k(x)
        k_width = self.to_width_k(x)
        k_frame = self.to_frame_k(x)

        v = self.to_v(x)

        # fold heads into the batch dimension
        q, k_height, k_width, k_frame, v = map(lambda t: rearrange(t, 'b f x y (h d) -> (b h) f x y d', h = heads), (q, k_height, k_width, k_frame, v))

        # scale queries
        q *= q.shape[-1] ** -0.5

        # similarity along the frame axis, broadcast over the (height, width) key positions
        sim_frame = einsum('b f h w d, b j h w d -> b f h w j', q, k_frame)
        sim_frame = sim_frame[..., :, None, None].expand(-1, -1, -1, -1, -1, h, w)

        # similarity along the height axis, broadcast over the (frame, width) key positions
        sim_height = einsum('b f h w d, b f k w d -> b f h w k', q, k_height)
        sim_height = sim_height[..., None, :, None].expand(-1, -1, -1, -1, f, -1, w)

        # similarity along the width axis, broadcast over the (frame, height) key positions
        sim_width = einsum('b f h w d, b f h l d -> b f h w l', q, k_width)
        sim_width = sim_width[..., None, None, :].expand(-1, -1, -1, -1, f, h, -1)

        # sum the three broadcast maps, then softmax jointly over every (frame, height, width) key position
        sim = rearrange(sim_frame + sim_height + sim_width, 'b f h w j k l -> b f h w (j k l)')
        attn = sim.softmax(dim = -1)

        # aggregate values with the full attention map
        attn = rearrange(attn, 'b f h w (j k l) -> b f h w j k l', j = f, k = h, l = w)
        out = einsum('b f h w j k l, b j k l d -> b f h w d', attn, v)

        # split heads back out of the batch dimension and project out
        out = rearrange(out, '(b h) f x y d -> b f x y (h d)', h = heads)
        out = self.to_out(out)
        out = rearrange(out, 'b f x y d -> b f d x y')

        return out, attn

layer = AxialAttention(dim = 16)
video = torch.randn(1, 5, 16, 32, 32)
out, attn = layer(video)  # attn: (batch * heads, f, h, w, f, h, w) = (8, 5, 32, 32, 5, 32, 32)
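
that said, if all that's needed is the map for one chosen query pixel (as in the original question), memory might be manageable by only ever building that pixel's row of the expanded map. a rough, untested sketch (attention_for_query is a made-up name; it assumes the already-scaled q and the three per-axis keys from forward above, each of shape ((b * heads), f, h, w, d), are exposed):

from torch import einsum

def attention_for_query(q, k_frame, k_height, k_width, fq, yq, xq):
    # query vector for the single position (frame fq, row yq, column xq)
    qv = q[:, fq, yq, xq]                                            # (b * heads, d)

    # axial similarities for that one query: over frames, height, width
    sim_f = einsum('b d, b j d -> b j', qv, k_frame[:, :, yq, xq])   # (b * heads, f)
    sim_h = einsum('b d, b k d -> b k', qv, k_height[:, fq, :, xq])  # (b * heads, h)
    sim_w = einsum('b d, b l d -> b l', qv, k_width[:, fq, yq, :])   # (b * heads, w)

    # broadcast-sum the three, then softmax jointly over every (frame, height, width) key position
    sim = sim_f[:, :, None, None] + sim_h[:, None, :, None] + sim_w[:, None, None, :]
    attn = sim.reshape(sim.shape[0], -1).softmax(dim = -1)
    return attn.reshape(sim.shape)                                   # (b * heads, f, h, w)

for the original question, attn[:, frame_b_idx] would then be the per-head attention map over frame_b for the chosen pixel in frame_a.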

lucidrains avatar Feb 01 '21 01:02 lucidrains