Attention-Augmented-Conv2d

Is there an inconsistency in this code?

Open · fakerhbj opened this issue 3 years ago · 1 comment

    # flat_q, flat_k, flat_v
    # (batch_size, Nh, height * width, dvh or dkh)
    flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
    flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
    flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W))
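To make the mismatch concrete: with made-up small sizes (the names `N`, `Nh`, `dk`, `H`, `W` are placeholders, not the repo's actual config), the reshape above produces `(N, Nh, dkh, H * W)`, while the comment promises `(N, Nh, H * W, dkh)` with the last two axes the other way around:

```python
import torch

# Hypothetical sizes, just to inspect the shape the snippet produces.
N, Nh, dk, H, W = 2, 4, 8, 3, 3
q = torch.randn(N, dk, H, W)  # per-head channels still folded into dim 1

flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
# Actual result: (N, Nh, dkh, H*W) = (2, 4, 2, 9),
# not the (N, Nh, H*W, dkh) claimed by the comment.
print(flat_q.shape)  # torch.Size([2, 4, 2, 9])
```

So either the comment or the reshape order is off; the later `flat_v.transpose(2, 3)` suggests the code expects the `(…, dkh, H*W)` layout and only the comment is stale.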

fakerhbj · Mar 09 '21

# attn_out
# (batch, Nh, height * width, dvh)
attn_out = torch.matmul(weights, flat_v.transpose(2, 3)) # shape: (batch, Nh, height * width, dvh)
attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))

I have a similar question about this part as well. Doesn't reshaping like this mess up the order of the values?

Really hope to get some clarification on this. Many thanks.
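A toy check of that concern (not from the repo; identity attention weights and tiny made-up sizes are assumed): after `matmul`, `attn_out` is `(batch, Nh, H*W, dvh)`, so reshaping it straight to `(batch, Nh, dvh, H, W)` does interleave spatial and channel values, whereas transposing back to `(batch, Nh, dvh, H*W)` first preserves them:

```python
import torch

# Tiny tensor whose values encode their own positions.
batch, Nh, dvh, H, W = 1, 1, 2, 2, 2
v = torch.arange(batch * Nh * dvh * H * W).reshape(batch, Nh, dvh, H, W)

flat_v = v.reshape(batch, Nh, dvh, H * W)   # as in the snippet above
attn_out = flat_v.transpose(2, 3)           # (batch, Nh, H*W, dvh); identity weights assumed

# Direct reshape mixes the spatial and channel axes:
wrong = attn_out.reshape(batch, Nh, dvh, H, W)
# Transposing back before reshaping restores the original layout:
right = attn_out.transpose(2, 3).reshape(batch, Nh, dvh, H, W)

print(torch.equal(right, v))  # True
print(torch.equal(wrong, v))  # False
```

So the reshape in the quoted code is only safe if it is preceded by a `transpose(2, 3)`; as written, it would scramble the values exactly as the question suggests.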

linusyh · Apr 08 '21