Open-Sora
Why was the DiT architecture modified?
The original DiT uses adaLN, and the class label in its condition has no sequence dimension, so cross attention would have to be added to model the relationship between a complex text sequence and the patch sequence. This project's implementation, however, modifies that structure: the patch self attention is replaced directly with cross attention between the patch sequence and the text condition, so self attention over the patch sequence is dropped entirely. What is the purpose of doing this?
Won't dropping the patch self attention cause quality problems in the generated frames?
@binmakeswell @ver217
I don't follow: isn't the self attention over patches still being done?
It's written exactly the same as the original DiT.
It really is different...
It has indeed all been changed... Skipping self attn and keeping only cross attn probably won't work well.
In the original, t + y is used as c, a shift and scale are applied, and then self attn is done.
Here it has been changed entirely to cross.
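For reference, a minimal sketch of the original DiT block described above: t + y are pooled into a single conditioning vector c (no sequence dimension), adaLN-zero produces shift/scale/gate, and the tokens then go through plain self attention. The class name and the use of nn.MultiheadAttention here are illustrative assumptions, not this repository's code.

```python
import torch
import torch.nn as nn

def modulate(x, shift, scale):
    # adaLN modulation: per-sample shift and scale, broadcast over the token dimension
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlockOriginal(nn.Module):
    """Original DiT block: adaLN-zero conditioning + patch self attention, no cross attention."""
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, int(hidden_size * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(hidden_size * mlp_ratio), hidden_size),
        )
        # c = t + y (timestep embedding + pooled label/text embedding), a single vector per sample
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=1)
        )
        h = modulate(self.norm1(x), shift_msa, scale_msa)
        # patch self attention: every patch token attends to every other patch token
        x = x + gate_msa.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
```

The variant questioned in this thread replaces the self attention above with cross attention whose keys and values come from the text tokens, so patch tokens never attend to each other.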
Yes, I have the same question here; hoping the authors will clarify.
The implementation over at Open-Sora-Plan:
https://github.com/neonsecret/DiTFusion?tab=readme-ov-file This repository has a cross-attention implementation, with pretrained DiT weights at 256 resolution.
The most likely reason is that, with a text latent, the K/V become very small and only Q is large; but forcing the complexity down this hard probably won't give good results.
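A rough illustration of the compute argument above. All numbers (hidden size, patch count for a 256x256 image, text length) are assumptions chosen only for intuition:

```python
# Attention score cost scales with (query length) x (key length) x hidden dim.
# Cross attention to text keeps K/V tiny, so it is far cheaper than patch self attention,
# but it also means patch tokens never exchange information with each other.
d = 1152            # hidden size (assumption)
n_patch = 32 * 32   # patch tokens for a 256x256 image with patch size 8 (assumption)
n_text = 120        # text tokens (assumption)

self_attn_flops = n_patch * n_patch * d   # patch <-> patch
cross_attn_flops = n_patch * n_text * d   # patch <-> text

print(f"self-attn score FLOPs  ~ {self_attn_flops:.2e}")
print(f"cross-attn score FLOPs ~ {cross_attn_flops:.2e}")
print(f"ratio ~ {self_attn_flops / cross_attn_flops:.1f}x")
```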
No, it does self first, then cross.
Yes, that should be the standard implementation of DiT with cross attention: self attn first, then cross attn.
With cross attention only, the features of the 3D image patches themselves are barely learned at all. It's not that it can't be trained, but the quality drop will probably be severe.
The PixArt-alpha structure should be the right one.
We support both of these methods. Just set use_cross_attn=False and it will use in-context conditioning (token concatenation). With adaLN-zero we cannot use sequential text features; we can only use pooled text features. We may implement this if we find it useful.
I've updated the training script; you can choose the model architecture with the -x option.
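For readers unfamiliar with the term, a minimal sketch of in-context conditioning (token concatenation): text tokens are appended to the patch tokens and the block runs ordinary self attention over the combined sequence. This is only an illustration of the general idea, with a plain TransformerEncoderLayer standing in for a DiT block; it is not this repository's exact implementation of use_cross_attn=False.

```python
import torch
import torch.nn as nn

def in_context_forward(block, x, txt):
    # x:   patch tokens, shape (B, N_patch, D)
    # txt: text tokens,  shape (B, N_text, D), already projected to the same hidden size
    n_patch = x.shape[1]
    h = torch.cat([x, txt], dim=1)   # one combined sequence of patches + text
    h = block(h)                     # ordinary self attention sees both token types
    return h[:, :n_patch]            # keep only the patch tokens for the next layer

# toy usage: a generic self-attention block stands in for a DiT block
block = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
x, txt = torch.randn(2, 16, 64), torch.randn(2, 8, 64)
print(in_context_forward(block, x, txt).shape)  # torch.Size([2, 16, 64])
```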
I guess people are looking for (in this thread):
```python
if self.add_self_atten:
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        self.adaLN_modulation(t).chunk(6, dim=1)
    )
    # multi-head self-attention
    x = x + gate_msa.unsqueeze(1) * self.attn_self(
        modulate(self.norm0(x), shift_msa, scale_msa), context=None, mask=self_attention_mask
    )
    # multi-head cross-attention
    x = x + self.attn(self.norm1(x), context, attention_mask)
    # feed-forward
    x = x + gate_mlp.unsqueeze(1) * self.mlp(
        modulate(self.norm2(x), shift_mlp, scale_mlp)
    )
```
in DitBlock.forward() when -x='cross-attn'
Actually, I found this requires huge memory (~10^3 GB of VRAM) in the expand_mask_4d()
function when creating self_attention_mask. Another finding is that with the argument -x='token-concat'
it also requires huge VRAM and is thus not feasible. Any idea how to solve this?
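For intuition on why materializing a dense 4D self-attention mask blows up: for video, the token count is num_frames x (H/p) x (W/p), and a mask of shape (B, heads, N, N) grows quadratically in N. A back-of-the-envelope estimate with purely illustrative shapes (the real config and the internals of expand_mask_4d() may differ):

```python
# Rough memory estimate for a dense 4D attention mask of shape (B, heads, N, N).
# All shapes below are assumptions for illustration, not the repo's actual config.
frames, hw_tokens = 16, 32 * 32          # 16 frames, 256x256 image, patch size 8
n = frames * hw_tokens                   # total spatio-temporal tokens
batch, heads, bytes_per_elem = 4, 16, 4  # fp32 mask

mask_bytes = batch * heads * n * n * bytes_per_elem
print(f"N = {n} tokens, mask ~ {mask_bytes / 1e9:.1f} GB")  # grows quadratically with N

# A common workaround is to avoid the dense (N, N) mask entirely, e.g. pass a
# per-token key padding mask of shape (B, N), or use no mask when all tokens are valid.
```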
Any update?
Even though the code was updated last week, I still see that the 'cross-attn' structure is incorrect: it lacks a 'self-attn' before the 'cross-attn'.
We have updated our code. Now we use a PixArt-based DiT structure, with temporal attention inserted.
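For reference, a minimal sketch of the block structure this describes: a PixArt-alpha-style block (self attention, cross attention to text, MLP) with a temporal attention layer inserted, so spatial attention runs within each frame and temporal attention runs across frames at each spatial location. The module layout and names are assumptions for illustration (timestep/adaLN modulation omitted for brevity), not the repository's exact implementation.

```python
import torch
import torch.nn as nn
from einops import rearrange  # assumed available; used only for axis reshuffling

class SpatialTemporalBlock(nn.Module):
    """Illustrative PixArt-style block with an inserted temporal attention layer."""
    def __init__(self, dim, heads):
        super().__init__()
        self.norm_s = nn.LayerNorm(dim)
        self.attn_spatial = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm_t = nn.LayerNorm(dim)
        self.attn_temporal = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm_c = nn.LayerNorm(dim)
        self.attn_cross = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm_m = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x, text, T):
        # x: (B, T*S, D) video tokens (T frames, S spatial tokens per frame); text: (B, L, D)
        B = x.shape[0]
        # spatial self attention: attend within each frame
        h = rearrange(x, "b (t s) d -> (b t) s d", t=T)
        n = self.norm_s(h)
        h = h + self.attn_spatial(n, n, n, need_weights=False)[0]
        # temporal self attention: attend across frames at each spatial location
        h = rearrange(h, "(b t) s d -> (b s) t d", b=B)
        n = self.norm_t(h)
        h = h + self.attn_temporal(n, n, n, need_weights=False)[0]
        x = rearrange(h, "(b s) t d -> b (t s) d", b=B)
        # cross attention to the text condition, then feed-forward
        x = x + self.attn_cross(self.norm_c(x), text, text, need_weights=False)[0]
        return x + self.mlp(self.norm_m(x))
```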