BotCL
BotCL copied to clipboard
Slot attention implementation did not update the q?
Based on your implementation, why does the slot attention iterate with the same q & k?
https://github.com/wbw520/BotCL/blob/3dde3ac20cdecd7eea8c4b7cb0e04e2bb95f639b/model/contrast/slots.py#L37
def forward(self, inputs_pe, inputs, weight=None, things=None):
b, n, d = inputs_pe.shape
slots = self.initial_slots.expand(b, -1, -1)
k, v = self.to_k(inputs_pe), inputs_pe
for _ in range(self.iters):
q = slots # always taking the initial slots as q?
dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
dots = torch.div(dots, torch.abs(dots).sum(2).expand_as(dots.permute([2, 0, 1])).permute([1, 2, 0])) * \
torch.abs(dots).sum(2).sum(1).expand_as(dots.permute([1, 2, 0])).permute([2, 0, 1])
attn = torch.sigmoid(dots)
# print(torch.max(attn))
# dsfds()
attn2 = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
updates = torch.einsum('bjd,bij->bid', inputs, attn2)
if self.vis:
slots_vis_raw = attn.clone()
vis(slots_vis_raw, "vis", self.args.feature_size, weight, things)
return updates, attn