
Slot attention implementation does not update q?

Open · kamwoh opened this issue 1 year ago · 1 comment

Based on your implementation, why does the slot attention loop iterate with the same q and k? Since `slots` is never reassigned inside the loop, q is the initial slots on every iteration.

https://github.com/wbw520/BotCL/blob/3dde3ac20cdecd7eea8c4b7cb0e04e2bb95f639b/model/contrast/slots.py#L37
    def forward(self, inputs_pe, inputs, weight=None, things=None):
        b, n, d = inputs_pe.shape
        slots = self.initial_slots.expand(b, -1, -1)
        k, v = self.to_k(inputs_pe), inputs_pe
        for _ in range(self.iters):
            q = slots  # always taking the initial slots as q?

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
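            # rescale: L1-normalize each slot's scores over positions, then multiply by the total score mass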
            dots = torch.div(dots, torch.abs(dots).sum(2).expand_as(dots.permute([2, 0, 1])).permute([1, 2, 0])) * \
                   torch.abs(dots).sum(2).sum(1).expand_as(dots.permute([1, 2, 0])).permute([2, 0, 1])
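            # sigmoid activation, instead of the softmax over slots used in the original Slot Attention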
            attn = torch.sigmoid(dots)


            attn2 = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
            updates = torch.einsum('bjd,bij->bid', inputs, attn2)
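            # note: `slots` is never reassigned from `updates` here, so the
            # next iteration reuses the same q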

        if self.vis:
            slots_vis_raw = attn.clone()
            vis(slots_vis_raw, "vis", self.args.feature_size, weight, things)
        return updates, attn
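
For comparison, the original Slot Attention paper (Locatello et al., 2020) folds `updates` back into `slots` at the end of every iteration, so q changes each round. Below is a minimal sketch of that loop, just to illustrate the step I expected; it is not BotCL's code, and the module names (`to_q`, `gru`, etc.) follow the paper's reference implementation rather than this repo:

    import torch
    import torch.nn as nn

    class SlotAttentionSketch(nn.Module):
        # Minimal sketch of Locatello et al. (2020), shown only to
        # contrast with the BotCL loop above; names are illustrative.
        def __init__(self, num_slots, dim, iters=3, eps=1e-8):
            super().__init__()
            self.iters, self.eps, self.scale = iters, eps, dim ** -0.5
            self.initial_slots = nn.Parameter(torch.randn(num_slots, dim))
            self.to_q = nn.Linear(dim, dim)
            self.to_k = nn.Linear(dim, dim)
            self.to_v = nn.Linear(dim, dim)
            self.gru = nn.GRUCell(dim, dim)

        def forward(self, inputs):
            b, n, d = inputs.shape
            slots = self.initial_slots.expand(b, -1, -1)
            k, v = self.to_k(inputs), self.to_v(inputs)
            for _ in range(self.iters):
                q = self.to_q(slots)  # q is recomputed from the *current* slots
                dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
                attn = dots.softmax(dim=1)  # slots compete for input positions
                attn = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
                updates = torch.einsum('bjd,bij->bid', v, attn)
                # the write-back that the loop above never performs
                slots = self.gru(updates.reshape(-1, d),
                                 slots.reshape(-1, d)).reshape(b, -1, d)
            return slots

Without that write-back (or even a plain `slots = updates`), every pass of your loop computes exactly the same `dots` and `attn`, so the iteration seems to have no effect. Is this intentional?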

kamwoh · Apr 23 '23