Why was the DiT architecture modified?

Open qiuyang163 opened this issue 11 months ago • 20 comments

image The original DiT structure uses AdaLN; the class label in the condition carries no sequence dimension, so cross attention has to be added to relate a complex text sequence to the patch sequence. This project's implementation modifies that structure: the patch self-attention is replaced outright by cross attention between the patch sequence and the text condition, and the patch sequence's self-attention is dropped. What is the purpose of this? Without patch self-attention, won't there be frame-generation quality problems?
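
For context, a minimal sketch of the two conditioning paths under discussion; the names and sizes are illustrative assumptions, not this project's code:

    import torch
    import torch.nn as nn

    d = 512                                # hidden size (illustrative)
    x = torch.randn(2, 1024, d)            # patch tokens: [batch, num_patches, d]
    y = torch.randn(2, d)                  # class-label embedding: no sequence dimension
    text = torch.randn(2, 77, d)           # text condition: [batch, num_text_tokens, d]

    # AdaLN only needs per-channel shift/scale from y, so one vector per sample is enough.
    shift, scale = nn.Linear(d, 2 * d)(y).chunk(2, dim=-1)
    x_ada = nn.LayerNorm(d, elementwise_affine=False)(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    # A text prompt is a token sequence, so cross attention lets each patch attend to every text token.
    xattn = nn.MultiheadAttention(d, num_heads=8, batch_first=True)
    x_txt, _ = xattn(query=x, key=text, value=text)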

qiuyang163 avatar Mar 06 '24 09:03 qiuyang163

@binmakeswell @ver217

qiuyang163 avatar Mar 06 '24 09:03 qiuyang163

I don't follow. Isn't patch self-attention still being done?

af-74413592 avatar Mar 06 '24 11:03 af-74413592

image It's written exactly the same way as the original DiT.

af-74413592 avatar Mar 06 '24 12:03 af-74413592

image It's written exactly the same way as the original DiT.

It actually is different... image

af-74413592 avatar Mar 06 '24 12:03 af-74413592

image They really did change it all... Dropping self-attn and keeping only cross-attn probably won't work well.

af-74413592 avatar Mar 06 '24 12:03 af-74413592

image In the original, t + y is used as c to produce a shift and scale, and then self-attn is applied.
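
For reference, a compact sketch of that original adaLN-zero DiT block (c = t + y drives shift/scale/gate, followed by patch self-attention); simplified, not this repository's code:

    import torch.nn as nn

    def modulate(x, shift, scale):
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    class DiTBlockSketch(nn.Module):
        """Original DiT block: adaLN-zero conditioning plus patch self-attention."""
        def __init__(self, d, num_heads):
            super().__init__()
            self.norm1 = nn.LayerNorm(d, elementwise_affine=False)
            self.attn = nn.MultiheadAttention(d, num_heads, batch_first=True)
            self.norm2 = nn.LayerNorm(d, elementwise_affine=False)
            self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(d, 6 * d))

        def forward(self, x, c):                       # c = t + y, shape [batch, d]
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
                self.adaLN_modulation(c).chunk(6, dim=1)
            h = modulate(self.norm1(x), shift_msa, scale_msa)
            x = x + gate_msa.unsqueeze(1) * self.attn(h, h, h)[0]   # patch self-attention
            x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
            return x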

af-74413592 avatar Mar 06 '24 12:03 af-74413592

image Here it has been changed entirely to cross attention.

af-74413592 avatar Mar 06 '24 12:03 af-74413592

Yes, I have the same question here; hopefully the authors can clarify.

qiuyang163 avatar Mar 06 '24 12:03 qiuyang163

image The implementation on the Open-Sora-Plan side.

af-74413592 avatar Mar 06 '24 12:03 af-74413592

https://github.com/neonsecret/DiTFusion?tab=readme-ov-file This repository has a cross-attention implementation and ships pretrained DiT weights at 256 resolution.

qiuyang163 avatar Mar 06 '24 12:03 qiuyang163

Yes, I have the same question here; hopefully the authors can clarify.

image Most likely it's because with a text latent the KV cache becomes very small and only Q stays large, but forcing the complexity down like that probably won't give good results.
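
A back-of-the-envelope comparison of the two attention costs (token counts are illustrative assumptions, not the project's actual configuration):

    # Rough attention-cost comparison for one block.
    N = 16 * 32 * 32        # video patch tokens: 16 frames x 32 x 32 spatial patches
    M = 120                 # text tokens
    d = 1152                # hidden size

    self_attn_flops = 2 * N * N * d     # QK^T over the patch sequence itself
    cross_attn_flops = 2 * N * M * d    # QK^T against the short text sequence only
    print(self_attn_flops // cross_attn_flops)   # ~136x cheaper, but patches never attend to each other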

af-74413592 avatar Mar 06 '24 12:03 af-74413592

https://github.com/neonsecret/DiTFusion?tab=readme-ov-file This repository has a cross-attention implementation and ships pretrained DiT weights at 256 resolution.

image No, that one does self attention first and then cross attention.

af-74413592 avatar Mar 06 '24 12:03 af-74413592

https://github.com/neonsecret/DiTFusion?tab=readme-ov-file This repository has a cross-attention implementation and ships pretrained DiT weights at 256 resolution.

image No, that one does self attention first and then cross attention.

Yes, that should be the standard implementation of DiT with cross attention: self-attn first, then cross-attn.

qiuyang163 avatar Mar 06 '24 13:03 qiuyang163

With cross attention alone, the features of the 3D image patches themselves are barely learned. It's not that it can't be trained at all, but the quality drop is probably severe.

af-74413592 avatar Mar 06 '24 13:03 af-74413592

The PixArt-alpha structure should be the right one.

quantumiracle avatar Mar 07 '24 02:03 quantumiracle

We support both methods. Just set use_cross_attn=False and it will use in-context conditioning (token concatenation). With adaLN-zero we cannot use sequential text features, only pooled text features; we may implement this if we find it useful.
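
A minimal sketch of what in-context conditioning (token concatenation) means here, under the assumption that text tokens are simply prepended to the patch sequence before plain self-attention:

    import torch
    import torch.nn as nn

    d = 512
    patches = torch.randn(2, 1024, d)    # video/image patch tokens
    text = torch.randn(2, 77, d)         # text tokens projected to the same width

    # Concatenate text tokens with patch tokens and run plain self-attention over
    # the joint sequence; no separate cross-attention module is needed.
    tokens = torch.cat([text, patches], dim=1)           # [2, 77 + 1024, d]
    attn = nn.MultiheadAttention(d, num_heads=8, batch_first=True)
    out, _ = attn(tokens, tokens, tokens)
    patch_out = out[:, text.shape[1]:]                   # drop the text positions afterwards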

ver217 avatar Mar 07 '24 02:03 ver217

I've updated the training script; you can choose the model architecture with the -x option

ver217 avatar Mar 07 '24 12:03 ver217

I guess people are looking for (in this thread):

        if self.add_self_atten:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(t).chunk(6, dim=1)
            )
            # multi-head self-attention
            x = x + gate_msa.unsqueeze(1) * self.attn_self(
                modulate(self.norm0(x), shift_msa, scale_msa),
                context=None,
                mask=self_attention_mask,
            )
            # multi-head cross-attention
            x = x + self.attn(self.norm1(x), context, attention_mask)
            # feed-forward
            x = x + gate_mlp.unsqueeze(1) * self.mlp(
                modulate(self.norm2(x), shift_mlp, scale_mlp)
            )

in DitBlock.forward() when -x='cross-attn'

quantumiracle avatar Mar 07 '24 16:03 quantumiracle

I guess people are looking for (in this thread): [the DitBlock.forward() snippet quoted above, when -x='cross-attn']

Actually, I found that this requires huge memory (~10^3 GB of VRAM) in the expand_mask_4d() function when creating self_attention_mask. Another finding: with the argument -x='token-concat' it also requires huge VRAM and is thus not feasible. Any idea how to solve this?
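
That order of magnitude is plausible if the mask is materialized as a dense [batch, heads, N, N] float tensor (an assumption about expand_mask_4d(), not verified against the code). A quick estimate:

    # Rough size of a dense 4D attention mask (assumes float32 and full materialization).
    def mask_gib(batch, heads, seq_len, bytes_per_el=4):
        return batch * heads * seq_len * seq_len * bytes_per_el / 1024**3

    # e.g. 64 frames x 32 x 32 spatial patches = 65536 tokens
    print(mask_gib(batch=1, heads=16, seq_len=64 * 32 * 32))   # 256 GiB for a single sample

A boolean key-padding mask of shape [batch, seq_len], or an additive bias broadcast from [batch, 1, 1, seq_len], carries the same padding information without materializing the N x N tensor.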

quantumiracle avatar Mar 07 '24 20:03 quantumiracle

Any update?

Even though the code was updated last week, the 'cross-attn' structure is still incorrect: it lacks 'self-attn' before 'cross-attn'

howardgriffin avatar Mar 14 '24 07:03 howardgriffin

We have updated our code. Now we use a PixArt-based DiT structure with temporal attention inserted.
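
For later readers, a rough sketch of that block ordering (spatial self-attention, then temporal attention, then text cross-attention, then MLP); the timestep/adaLN modulation is omitted and the names are illustrative, not the repository's actual implementation:

    import torch.nn as nn

    class STBlockSketch(nn.Module):
        """Spatial self-attn -> temporal attn -> text cross-attn -> MLP (modulation omitted)."""
        def __init__(self, d, heads):
            super().__init__()
            self.spatial = nn.MultiheadAttention(d, heads, batch_first=True)
            self.temporal = nn.MultiheadAttention(d, heads, batch_first=True)
            self.cross = nn.MultiheadAttention(d, heads, batch_first=True)
            self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
            self.n1, self.n2, self.n3, self.n4 = (nn.LayerNorm(d) for _ in range(4))

        def forward(self, x, text):
            # x: [B, T, S, d] video patch tokens, text: [B, L, d] text tokens
            b, t, s, d = x.shape
            xs = x.reshape(b * t, s, d)                        # self-attention within each frame
            h = self.n1(xs)
            xs = xs + self.spatial(h, h, h)[0]
            xt = xs.reshape(b, t, s, d).transpose(1, 2).reshape(b * s, t, d)
            h = self.n2(xt)
            xt = xt + self.temporal(h, h, h)[0]                # attention across frames per patch location
            x = xt.reshape(b, s, t, d).transpose(1, 2).reshape(b, t * s, d)
            x = x + self.cross(self.n3(x), text, text)[0]      # every patch attends to the text sequence
            x = x + self.mlp(self.n4(x))
            return x.reshape(b, t, s, d)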

zhengzangw avatar Mar 18 '24 05:03 zhengzangw