Open-Sora
fix bugs in mha and MaskGenerator; improve ckpt_utils.py
- fix bug in mha when `mask=None`.
- fix bug in MaskGenerator: if the video is short, every type of mask has a chance of masking all frames.
- use `logger` in `load_checkpoint`.
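The MaskGenerator problem can be shown with a minimal sketch of one way to rule out an all-masked result. It assumes a per-frame boolean mask where `True` marks a masked frame. The function `sample_frame_mask` and its `max_masked` parameter are hypothetical, not Open-Sora's MaskGenerator API, and this is not necessarily how this PR fixes it. The idea is to clamp the number of masked frames to `num_frames - 1`, so even a one-frame video keeps a frame unmasked.

```python
import random
from typing import List


def sample_frame_mask(num_frames: int, max_masked: int, rng: random.Random) -> List[bool]:
    """Hypothetical per-frame mask; True marks a masked frame.

    Without the clamp below, a short video (num_frames <= max_masked)
    could end up with every frame masked.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    # Leave at least one frame unmasked, however short the video is.
    limit = max(0, min(max_masked, num_frames - 1))
    count = rng.randint(0, limit)  # randint includes both endpoints
    start = rng.randint(0, num_frames - count)  # masked span is [start, start + count)
    return [start <= t < start + count for t in range(num_frames)]


rng = random.Random(0)
for n in range(1, 6):
    print(n, sample_frame_mask(n, max_masked=8, rng=rng))
```

Clamping the count, rather than resampling until some frame survives, keeps the sampling loop-free and deterministic for a given seed.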
Regarding the first fix (mha with `mask=None`), another solution is to change
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
to
q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
when `mask` is `None`. The current implementation on the main branch assumes `mask` is never `None`.
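The reason the batch dimension matters can be sketched with a toy, single-head cross-attention in plain Python. This assumes identity projections (the condition tokens serve as both keys and values) and equal-length conditions per sample; all helper names are hypothetical. `cross_attn_flattened` mimics `view(1, -1, ...)`: every sample is packed into one sequence, so without a mask each query attends to every sample's condition tokens. Passing `block_diagonal=True` stands in for a block-diagonal attention mask, which is an assumption about what the masked path does. `cross_attn_per_sample` mimics `view(B, -1, ...)`.

```python
import math
from typing import List, Optional

Vector = List[float]
Batch = List[List[Vector]]  # [sample][token][feature]


def attend(queries: List[Vector], keys: List[Vector], values: List[Vector],
           allowed: Optional[List[List[bool]]] = None) -> List[Vector]:
    """Single-head scaled dot-product attention; allowed[i][j] gates query i -> key j."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    dim = len(values[0])
    out = []
    for i, q in enumerate(queries):
        idx = [j for j in range(len(keys)) if allowed is None or allowed[i][j]]
        scores = [scale * sum(a * b for a, b in zip(q, keys[j])) for j in idx]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * values[j][d] for w, j in zip(weights, idx)) / total
                    for d in range(dim)])
    return out


def cross_attn_flattened(x: Batch, cond: Batch, block_diagonal: bool = False) -> Batch:
    """Like view(1, -1, ...): pack every sample into a single sequence."""
    queries = [q for sample in x for q in sample]
    keys = [k for sample in cond for k in sample]
    q_owner = [b for b, sample in enumerate(x) for _ in sample]
    k_owner = [b for b, sample in enumerate(cond) for _ in sample]
    allowed = None
    if block_diagonal:
        # Each query may only see condition tokens from its own sample.
        allowed = [[qb == kb for kb in k_owner] for qb in q_owner]
    flat = attend(queries, keys, keys, allowed)  # keys double as values
    out, start = [], 0
    for sample in x:  # split the packed output back into samples
        out.append(flat[start:start + len(sample)])
        start += len(sample)
    return out


def cross_attn_per_sample(x: Batch, cond: Batch) -> Batch:
    """Like view(B, -1, ...): each sample attends only to its own condition."""
    return [attend(xs, cs, cs) for xs, cs in zip(x, cond)]


x = [[[1.0, 0.0]], [[0.0, 1.0]]]
cond_a = [[[1.0, 0.0]], [[0.0, 1.0]]]
cond_b = [[[1.0, 0.0]], [[5.0, 5.0]]]  # only sample 1's condition differs

print(cross_attn_per_sample(x, cond_a)[0] == cross_attn_per_sample(x, cond_b)[0])  # True
print(cross_attn_flattened(x, cond_a)[0] == cross_attn_flattened(x, cond_b)[0])  # False
```

Changing only sample 1's condition changes sample 0's output in the flattened, unmasked version but not in the per-sample one, while the flattened version with a block-diagonal mask matches the per-sample result.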