Wrong attention implementation?
In restormer.py we can see that attention is implemented as below:
qkv = self.qkv_dwconv(self.qkv(x))
q,k,v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1) * self.temperature
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
We can see that q, k, v have shape (b, c, h, w) at first, and the channels are then split into num_heads, which means the c in "b (head c) h w" is definitely the per-head feature dimension for the attention.
BUT why do we use "b head c (h w)", which puts the "sequence" at the end, and then do normal attention on it?
I think this is wrong (you want "b head (h w) c"), and it could explain why some people run into problems when training with images of different sizes; see the shape check below.
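As a quick sanity check, here is a minimal shape comparison (a sketch with made-up sizes, not code from the repository) of the two rearrangements. With 'b head c (h w)' the attention map is c x c regardless of resolution; with 'b head (h w) c' it is (h*w) x (h*w):

import torch
from einops import rearrange

b, heads, c_per_head, h, w = 1, 2, 16, 32, 48   # made-up sizes
x = torch.randn(b, heads * c_per_head, h, w)

# Layout used in restormer.py: tokens are channels, (h w) is the feature axis
q1 = rearrange(x, 'b (head c) h w -> b head c (h w)', head=heads)
attn1 = q1 @ q1.transpose(-2, -1)
print(attn1.shape)   # torch.Size([1, 2, 16, 16]) -> c x c, same at any resolution

# Proposed layout: tokens are pixels, c is the feature axis
q2 = rearrange(x, 'b (head c) h w -> b head (h w) c', head=heads)
attn2 = q2 @ q2.transpose(-2, -1)
print(attn2.shape)   # torch.Size([1, 2, 1536, 1536]) -> (h*w) x (h*w), grows with resolution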
So do we need to change the code to this?
def forward(self, x):
    b, c, h, w = x.shape
    qkv = self.qkv_dwconv(self.qkv(x))
    q, k, v = qkv.chunk(3, dim=1)
    q = rearrange(q, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
    k = rearrange(k, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
    v = rearrange(v, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
    q = torch.nn.functional.normalize(q, dim=-1)
    k = torch.nn.functional.normalize(k, dim=-1)
    attn = (q @ k.transpose(-2, -1)) * self.temperature   # (b, head, h*w, h*w)
    attn = attn.softmax(dim=-1)
    out = (attn @ v)                                       # (b, head, h*w, c)
    out = rearrange(out, 'b head (h w) c -> b (head c) h w', head=self.num_heads, h=h, w=w)
    out = self.project_out(out)
    return out
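For reference, a rough back-of-the-envelope comparison (hypothetical sizes, float32, one sample; not measurements from the repository) of the attention-map memory for the two variants shows why the spatial version becomes impractical at typical restoration resolutions:

# Attention-map memory only (float32, batch of 1, hypothetical head settings).
# Channel attention stores a c x c map per head; spatial attention stores an
# (h*w) x (h*w) map per head.
heads, c_per_head = 2, 24
for h, w in [(128, 128), (256, 256), (512, 512)]:
    hw = h * w
    channel_bytes = heads * c_per_head * c_per_head * 4
    spatial_bytes = heads * hw * hw * 4
    print(f"{h}x{w}: channel attn {channel_bytes / 1e6:.4f} MB, "
          f"spatial attn {spatial_bytes / 1e9:.1f} GB")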
Yes I think so
but it runs into some other errors
No!!! Please read Section 3.1 of the paper: the attention is applied across the channel dimension on purpose ("transposed" attention), so the attention map is c x c rather than (h*w) x (h*w).
You are right. Then the thing is that this architecture will be heavily affected by the resolution settings, so other issues about that become "expected", and I don't think the paper has any note about this.
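For what it's worth, here is a minimal sketch (made-up tensors, not the repository module) of how resolution enters the original channel-attention computation: the attention map keeps the same c x c shape at any input size, and because q and k are L2-normalized over the (h w) axis the entries are cosine similarities bounded in [-1, 1]; only the number of pixels each similarity is computed over changes with resolution.

import torch
import torch.nn.functional as F

for h, w in [(64, 64), (256, 192)]:                          # two different resolutions
    q = F.normalize(torch.randn(1, 2, 16, h * w), dim=-1)    # layout 'b head c (h w)'
    k = F.normalize(torch.randn(1, 2, 16, h * w), dim=-1)
    attn = q @ k.transpose(-2, -1)                            # (1, 2, 16, 16) in both cases
    print(h, w, attn.shape, float(attn.abs().max()) <= 1.0)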