
Potential bug in cross-attention layer for class-conditioning

Open · jhbastek opened this issue 2 years ago

While trying to work through the details of the code, I noticed a potential bug in the forward call of the CrossAttention class (in ldm/modules/attention.py, code provided below).

Namely, since the context tensor for class conditioning has shape (b, 1, h * d), the corresponding k and v have shape (b * h, 1, d), and sim thus has shape (b * h, n, 1). Taking sim.softmax over the last dim then yields an all-ones tensor! Perhaps this is indeed expected behavior, since we do not attend over a sequence of tokens here, but then why does the paper emphasize that the 'class-conditional model [...] is also implemented via cross-attention...'?

I would greatly appreciate it if the authors could provide some clarity. Thanks in advance!

# Excerpt from CrossAttention.forward in ldm/modules/attention.py.
# Relies on: import torch; from torch import einsum; from einops import rearrange, repeat;
# exists() and default() are helpers defined in the same file.
def forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)  # falls back to self-attention when no context is given
    k = self.to_k(context)
    v = self.to_v(context)

    # split heads: (b, n, h * d) -> (b * h, n, d)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    # scaled dot-product similarities, shape (b * h, n_q, n_k)
    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if exists(mask):
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    # for class conditioning n_k == 1, so this softmax runs over a single element
    attn = sim.softmax(dim=-1)
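
For reference, a minimal standalone sketch (not repo code) reproducing the shapes described above: with a single context token the softmax runs over one element, every attention weight is exactly 1, and each query position simply receives the (value-projected) context token.

import torch

b_h, n, d = 2, 4, 8                     # (batch * heads), query tokens, head dim
q = torch.randn(b_h, n, d)
k = torch.randn(b_h, 1, d)              # a single context token, as in class conditioning
v = torch.randn(b_h, 1, d)

sim = torch.einsum('bid,bjd->bij', q, k)             # shape (b_h, n, 1)
attn = sim.softmax(dim=-1)                           # softmax over one element
print(torch.allclose(attn, torch.ones_like(attn)))   # True: all weights are exactly 1

out = torch.einsum('bij,bjd->bid', attn, v)          # every query just receives v
print(torch.allclose(out, v.expand(b_h, n, d)))      # True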

jhbastek · Sep 12 '22

You need to set the n_classes input in your cond_stage_config; see ldm.modules.encoders.modules.ClassEmbedder if you are using that. It defaults to 1000 because the authors only used that module on ImageNet. Still, I think it is a bug, because I believe it's supposed to take one-hot-encoded (OHE) vectors as input, so batch[key][:, None] needs to be changed to torch.nn.functional.one_hot(batch[key], n_classes), assuming your class label is an int (a sketch of the change follows below). This would make the softmax output make sense, but I'm still experimenting.
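
For concreteness, a hedged sketch of that change against the stock ClassEmbedder; the class and argument names follow ldm/modules/encoders/modules.py, but the one-hot variant is the commenter's proposal, not the repository's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.n_classes = n_classes
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        key = key if key is not None else self.key
        # stock behavior: a single token per image, i.e. context length 1
        # c = batch[key][:, None]                      # (B, 1), int class ids
        # proposed: a length-n_classes sequence of 0/1 indicators
        c = F.one_hot(batch[key], self.n_classes)      # (B, n_classes)
        return self.embedding(c)                       # (B, n_classes, embed_dim)

With this change the context passed to CrossAttention is a sequence of n_classes tokens rather than a single token, so the softmax in the snippet above is no longer taken over a single element.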

I'm guessing conditional image generation wasn't a focus of their work, which is a bit odd since ~90% of the code is there... You could also technically turn this into a multi-task problem by setting the values for the positive classes accordingly (e.g., cond 1 -> True, cond 2 -> False, cond 3 -> True would yield [1, 0, 1], which is then used to generate the embedding); see the small sketch below.
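
A hypothetical illustration of such a multi-hot conditioning vector (variable names are illustrative only):

import torch
import torch.nn as nn

# hypothetical per-image condition flags: cond 1 -> True, cond 2 -> False, cond 3 -> True
flags = torch.tensor([[True, False, True]])      # (B, n_conds)
multi_hot = flags.long()                         # tensor([[1, 0, 1]])

# each 0/1 indicator could then be embedded per position, as in the sketch above
emb = nn.Embedding(2, 64)(multi_hot)             # (B, n_conds, 64)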

It is unclear to me how their result could distinguish between classes; my guess is that the cond embedding is used elsewhere, and that given the learned embedding weights it's still able to distinguish between classes. (Note that even with all-ones attention weights, each cross-attention layer still outputs the value projection of the single class-embedding token, so class information does enter the features; the weighting is just not content-dependent.) My guess is that they did this to reduce the dimensionality of the cond embeddings; with OHE on ImageNet you would have a tensor of shape [B, 1001, n_feature_emb]. Still, I would think that the additional cross-attention step with the OHE embeddings should improve results.

pcicales · Oct 08 '22

@jhbastek I ended up generating the OHE with an input of num_classes + 1, with the last position representing the null class used for unconditional embedding generation. I always set the null-class entry to 1; I am still validating my results. A sketch of the scheme is below.
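
A hypothetical sketch of that encoding (variable names are illustrative, not from the repo): one extra slot acts as the null/unconditional class, and its indicator is always on.

import torch
import torch.nn.functional as F

num_classes = 1000
labels = torch.tensor([3, 17])                   # (B,) int class ids

# one extra slot for the null class used in unconditional generation
ohe = F.one_hot(labels, num_classes + 1)         # (B, num_classes + 1)
ohe[:, -1] = 1                                   # null-class indicator always set to 1

# fully unconditional samples activate only the null class
uncond = torch.zeros_like(ohe)
uncond[:, -1] = 1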

pcicales · Oct 11 '22

Hey @pcicales, could you please share any insights you got from your experiment? I've been thinking about this too, and would really appreciate any pointers!

jaivardhankapoor · Feb 01 '23