
Wrong attention implementation?

Open KohakuBlueleaf opened this issue 2 years ago • 4 comments

In restormer.py, we can see attention is implemented as below:

qkv = self.qkv_dwconv(self.qkv(x))     # (b, 3*c, h, w)
q, k, v = qkv.chunk(3, dim=1)          # each (b, c, h, w)

q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

# L2-normalize q and k along the spatial axis (h*w)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1) * self.temperature

attn = (q @ k.transpose(-2, -1))       # (b, head, c//head, c//head)
attn = attn.softmax(dim=-1)

out = (attn @ v)                       # (b, head, c//head, h*w)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

out = self.project_out(out)

We can see that q, k, v have shape (b, c, h, w) at first, and then we split the channels into num_heads heads, which means the c in "b (head c) h w" is definitely the "head dim" for attention.

BUT why do we use "b head c (h w)", which puts the "sequence" (h w) at the end, and then do normal attention on it?

I think this is wrong (you want "b head (h w) c"), and it could explain why some people run into problems when training with images of different sizes.
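
For reference, here is a minimal shape check (my own sketch, with made-up sizes, assuming torch and einops) contrasting the two rearrangements: with "b head c (h w)" the tokens are channels and the attention map is (c/head) x (c/head), while with "b head (h w) c" the tokens are pixels and the attention map is (h*w) x (h*w).

import torch
from einops import rearrange

# illustrative sizes only: 2 heads, 16 total channels (8 per head), 16x24 feature map
b, num_heads, dim, h, w = 1, 2, 16, 16, 24
x = torch.randn(b, dim, h, w)

# current rearrangement: attention across channels
q = rearrange(x, 'b (head c) h w -> b head c (h w)', head=num_heads)
k = rearrange(x, 'b (head c) h w -> b head c (h w)', head=num_heads)
print((q @ k.transpose(-2, -1)).shape)    # torch.Size([1, 2, 8, 8]), independent of h and w

# rearrangement suggested here: attention across pixels
q2 = rearrange(x, 'b (head c) h w -> b head (h w) c', head=num_heads)
k2 = rearrange(x, 'b (head c) h w -> b head (h w) c', head=num_heads)
print((q2 @ k2.transpose(-2, -1)).shape)  # torch.Size([1, 2, 384, 384]), grows with resolution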

KohakuBlueleaf · Oct 12 '23

So we need to change the code to this?

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head (h w) c', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        # restore (b, c, h, w) so project_out (a Conv2d) receives a 4D feature map
        out = rearrange(out, 'b head (h w) c -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

morestart · Apr 01 '24

> So we need to change the code to this? (proposed code quoted above)

Yes, I think so.

KohakuBlueleaf · Apr 01 '24

But then it runs into some other errors.

morestart · Apr 01 '24

No!!! Please read paper section 3.1 ...

JuWanMaeng · Apr 30 '24

> No!!! Please read paper section 3.1 ...

You are right. Then the thing is that this architecture will be hugely affected by resolution settings, so the other issues about that become "expected", and I don't think the paper has any note about this.
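
For context, here is a weights-free sketch of just the transposed-attention step (my own simplification, assuming torch and einops, not the repo's code). The attention map is always (c/head) x (c/head) regardless of the input size, so different resolutions run without shape errors, and any resolution sensitivity shows up in the output values rather than as errors.

import torch
from einops import rearrange

def transposed_attention(x, num_heads=2):
    # Simplified version of the snippet above: no qkv projections,
    # no depthwise conv, temperature fixed to 1; only the shapes matter here.
    b, c, h, w = x.shape
    q = rearrange(x, 'b (head c) h w -> b head c (h w)', head=num_heads)
    k = rearrange(x, 'b (head c) h w -> b head c (h w)', head=num_heads)
    v = rearrange(x, 'b (head c) h w -> b head c (h w)', head=num_heads)
    q = torch.nn.functional.normalize(q, dim=-1)
    k = torch.nn.functional.normalize(k, dim=-1)
    attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)   # (b, head, c/head, c/head)
    out = attn @ v
    return rearrange(out, 'b head c (h w) -> b (head c) h w', head=num_heads, h=h, w=w)

print(transposed_attention(torch.randn(1, 16, 32, 32)).shape)   # torch.Size([1, 16, 32, 32])
print(transposed_attention(torch.randn(1, 16, 128, 96)).shape)  # torch.Size([1, 16, 128, 96])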

KohakuBlueleaf · Apr 30 '24