Attention-Augmented-Conv2d
Is there some inconsistency in this code?
# flat_q, flat_k, flat_v
# (batch_size, Nh, height * width, dvh or dkh)
flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W))
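For reference, here is a minimal sketch of what I mean, using made-up toy sizes (N, Nh, dk, H, W are hypothetical placeholders): the reshape puts the last two dimensions in the order (dkh, height * width), while the comment above says (height * width, dkh).

import torch

# Hypothetical toy sizes, only to check the resulting shape
N, Nh, dk, H, W = 2, 4, 8, 3, 3
q = torch.randn(N, dk, H, W)

flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
print(flat_q.shape)  # torch.Size([2, 4, 2, 9]) -> (N, Nh, dkh, H * W), not (N, Nh, H * W, dkh)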
# attn_out
# (batch, Nh, height * width, dvh)
attn_out = torch.matmul(weights, flat_v.transpose(2, 3)) # shape: (batch, Nh, height * width, dvh)
attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
I have a similar question about this part as well. Doesn't reshaping like this mess up the order of the values?
Really hope to get some clarification on this. Many thanks.
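To illustrate the ordering concern, again with made-up toy sizes (batch, Nh, dvh, H, W are hypothetical placeholders): reshaping a (batch, Nh, height * width, dvh) tensor straight into (batch, Nh, dvh, height, width) gives a different element order than first transposing so that dvh comes before the spatial axis and then reshaping.

import torch

# Hypothetical toy sizes, only to compare the two orderings
batch, Nh, dvh, H, W = 1, 2, 3, 2, 2
attn_out = torch.randn(batch, Nh, H * W, dvh)

# Direct reshape, as in the snippet above
direct = torch.reshape(attn_out, (batch, Nh, dvh, H, W))

# Transpose first so dvh precedes the spatial axis, then reshape
swapped = attn_out.transpose(2, 3).contiguous().reshape(batch, Nh, dvh, H, W)

print(torch.equal(direct, swapped))  # almost surely False for random input: the values land in different positions

If the intended layout is (batch, Nh, dvh, height, width), it seems a transpose would be needed before the final reshape, but maybe I am missing something.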