
Is adaLN applicable to text conditioning?

Open • Darius-H opened this issue • 3 comments

In DiTBlock:

```python
self.adaLN_modulation = nn.Sequential(
    nn.SiLU(),
    nn.Linear(hidden_size, 6 * hidden_size, bias=True),
)

def forward(self, x, c):
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
    x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
    x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
    return x
```
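(For reference, `modulate` here is the small helper in the DiT codebase that applies the per-sample shift and scale to the normalized activations:)

```python
def modulate(x, shift, scale):
    # x: [batch, tokens, hidden]; shift, scale: [batch, hidden]
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```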

I wonder whether adaLN_modulation could be used for text conditioning, where c has shape [batch_size, max_tokens_len, hidden_size]. Do you think adaLN can be used with a text condition, and if so, how? Or have you tried DiT for text-to-image generation?

Darius-H • Feb 28 '23

Hi @Darius-H. Yeah, it could be used directly for text conditioning. For example, if you encode the input text with a pre-trained LLM, you can pool the output text tokens into a single embedding and run adaLN conditioning with the pooled embedding as input. I assume that to get the best results you'd also need to add cross-attention against the LLM's output text tokens (in addition to adaLN-zero on the pooled representation), but it's still an open research question.

wpeebles • Mar 01 '23
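A minimal sketch of the pooled-embedding route described above, assuming a pre-trained text encoder has already produced text_emb of shape [batch, max_tokens_len, text_dim]; the pooling choice and the projection layer are illustrative assumptions, not part of the DiT repo:

```python
import torch
import torch.nn as nn

class PooledTextEmbedder(nn.Module):
    """Pools text-token embeddings into one vector for adaLN conditioning (illustrative sketch)."""

    def __init__(self, text_dim, hidden_size):
        super().__init__()
        # Project the pooled text embedding into the DiT conditioning space.
        self.proj = nn.Linear(text_dim, hidden_size)

    def forward(self, text_emb, text_mask=None):
        # text_emb: [batch, max_tokens_len, text_dim]
        # text_mask: [batch, max_tokens_len], 1 for real tokens, 0 for padding.
        if text_mask is not None:
            mask = text_mask.unsqueeze(-1).to(text_emb.dtype)
            pooled = (text_emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        else:
            pooled = text_emb.mean(dim=1)  # simple mean pooling over tokens
        return self.proj(pooled)           # [batch, hidden_size]

# Usage: add the pooled text embedding to the timestep embedding and pass the sum
# into each DiTBlock as c, the same way the class-conditional model uses its
# label embedding, e.g.
#   c = t_emb + text_embedder(text_emb, text_mask)
#   x = dit_block(x, c)
```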

If applying this strategy, would you also extend the adaLN modulation to the cross-attention block (i.e., additionally output shift_crossattn, scale_crossattn, gate_crossattn), or would you just apply modulation to the self-attention and MLP blocks as usual?

tomresan • Nov 08 '23
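For concreteness, the extension described in the question (nine adaLN chunks instead of six, so the cross-attention sub-block gets its own shift/scale/gate) could look roughly like the hypothetical sketch below. It is not from the DiT repo, and the plain nn.MultiheadAttention modules are stand-ins for whatever attention implementation is actually used:

```python
import torch
import torch.nn as nn

def modulate(x, shift, scale):
    # Same shift-and-scale helper as in the DiT snippet above.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlockWithCrossAttn(nn.Module):
    """Hypothetical DiT block where adaLN-zero also modulates a cross-attention sub-block."""

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm_xa = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, int(hidden_size * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(hidden_size * mlp_ratio), hidden_size),
        )
        # 9 * hidden_size instead of 6 * hidden_size: shift/scale/gate for
        # self-attention, cross-attention, and the MLP.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 9 * hidden_size, bias=True),
        )

    def forward(self, x, c, text_tokens):
        (shift_msa, scale_msa, gate_msa,
         shift_xa, scale_xa, gate_xa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(9, dim=1)
        h = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
        # Cross-attention against the encoder's text tokens, gated the same way.
        q = modulate(self.norm_xa(x), shift_xa, scale_xa)
        x = x + gate_xa.unsqueeze(1) * self.cross_attn(q, text_tokens, text_tokens,
                                                       need_weights=False)[0]
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
```

Whether modulating and gating the cross-attention path this way actually helps is, as noted above, still an open question.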

Same question.

zhuyy1168 • Mar 01 '24