DiT
Is adaLN applicable to text conditioning?
In DiTBlock:

    self.adaLN_modulation = nn.Sequential(
        nn.SiLU(),
        nn.Linear(hidden_size, 6 * hidden_size, bias=True)
    )

    def forward(self, x, c):
        # c: [batch_size, hidden_size] conditioning embedding (timestep + class label in DiT)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
I wonder whether adaLN_modulation could be used for text conditioning, where c has shape [batch_size, max_tokens_len, hidden_size]. Do you think adaLN can be used for text conditioning, and if so, how? Or have you tried DiT for text-to-image generation?
Hi @Darius-H. Yeah, it could be directly used for text conditioning. For example, if you encode the input text with a pre-trained LLM, you can pool the output text tokens into a single embedding and run adaLN conditioning with the pooled embedding as input. I assume that to get the best results you'd also need to add cross-attention against the LLM's output text tokens (in addition to adaLN-zero on the pooled representation), but it's still an open research question.
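A minimal sketch of the pooling idea, assuming a text encoder that returns per-token hidden states of size hidden_size (the names text_hidden_states, attention_mask, and pool_text are illustrative and not part of the DiT repo):

    def pool_text(text_hidden_states, attention_mask):
        # text_hidden_states: [batch_size, max_tokens_len, hidden_size]
        # attention_mask:     [batch_size, max_tokens_len], 1 for real tokens, 0 for padding
        mask = attention_mask.unsqueeze(-1).to(text_hidden_states.dtype)
        # Masked mean over the token dimension -> [batch_size, hidden_size]
        return (text_hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

The pooled embedding can then be combined with the timestep embedding and fed to the unchanged DiTBlock, e.g. c = t_emb + pool_text(text_hidden_states, attention_mask), then x = block(x, c).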
If applying this strategy, would you also extend the adaLN modulation to the cross-attention block (i.e., additionally output shift_crossattn, scale_crossattn, and gate_crossattn), or just apply modulation to the self-attention and MLP blocks as usual?
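To make the variant in this question concrete, here is a purely illustrative sketch (not from the DiT repo, and not an answer from the authors) of a block whose adaLN_modulation emits 9 chunks so the cross-attention output is gated and modulated the same way as self-attention and the MLP. DiTCrossAttnBlock is a hypothetical name, and nn.MultiheadAttention plus a plain MLP stand in for the repo's Attention and Mlp modules:

    import torch.nn as nn

    def modulate(x, shift, scale):
        # Same shift/scale modulation as in the DiT repo.
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    class DiTCrossAttnBlock(nn.Module):
        # Hypothetical block: adaLN_modulation emits 9 chunks so the cross-attention
        # sub-block gets its own shift/scale/gate, mirroring self-attention and the MLP.
        def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
            super().__init__()
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
            self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
            self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.mlp = nn.Sequential(
                nn.Linear(hidden_size, int(hidden_size * mlp_ratio)),
                nn.GELU(),
                nn.Linear(int(hidden_size * mlp_ratio), hidden_size),
            )
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 9 * hidden_size, bias=True),
            )

        def forward(self, x, c, text_tokens):
            # x: [B, N, D] image tokens, c: [B, D] pooled conditioning,
            # text_tokens: [B, L, D] per-token text embeddings for cross-attention.
            (shift_sa, scale_sa, gate_sa,
             shift_ca, scale_ca, gate_ca,
             shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(9, dim=1)
            h = modulate(self.norm1(x), shift_sa, scale_sa)
            x = x + gate_sa.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
            h = modulate(self.norm2(x), shift_ca, scale_ca)
            x = x + gate_ca.unsqueeze(1) * self.cross_attn(h, text_tokens, text_tokens, need_weights=False)[0]
            x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm3(x), shift_mlp, scale_mlp))
            return x

Whether the extra per-block parameters (a 9x instead of 6x modulation head) are worth it over modulating only the self-attention and MLP sub-blocks is exactly the open question raised here.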
same question