
chunk-wise causal mask

Open zyy-fc opened this issue 8 months ago • 2 comments

Great work!

I'm very interested in Section 3.2.2 of the paper. Could you open-source the training code related to the chunk-wise causal mask?

zyy-fc avatar Mar 30 '25 02:03 zyy-fc

Yes. See the code below.

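import torch

# Example sizes (assumed values; the original snippet leaves these undefined).
seq_len = 8                    # length of x0, and of xt
min_chunk, max_chunk = 2, 5    # chunk size is sampled from [min_chunk, max_chunk)

# x = [x0, xt]: x0 occupies the first seq_len positions, xt the last seq_len.
attend_value = 0.0
mask_value = float("-inf")
attn_mask = torch.full((seq_len * 2, seq_len * 2), mask_value)
index = torch.arange(seq_len * 2)

# Sample a chunk size and a random offset for the first chunk boundary.
chunk = torch.randint(min_chunk, max_chunk, (1,)).item()
block_shift = torch.randint(0, chunk, size=(1,))
block_idx = (torch.arange(seq_len) + block_shift) // chunk
block_idx = block_idx.repeat(2)  # same chunk ids for x0 and xt

# x0 -> x0: attend to the current chunk and all earlier chunks.
z_mask = block_idx[:seq_len].unsqueeze(0) <= block_idx[:seq_len].unsqueeze(1)

# xt -> xt: same chunk only; xt -> x0: strictly earlier chunks only.
zt_mask_1 = (block_idx.unsqueeze(0) == block_idx[seq_len:].unsqueeze(1)) & (index >= seq_len)
zt_mask_2 = (block_idx.unsqueeze(0) < block_idx[seq_len:].unsqueeze(1)) & (index < seq_len)
zt_mask = zt_mask_1 | zt_mask_2

# Additive mask: 0.0 where attention is allowed, -inf elsewhere (x0 never attends to xt).
attn_mask[:seq_len, :seq_len] = torch.where(z_mask, attend_value, mask_value)
attn_mask[seq_len:, :] = torch.where(zt_mask, attend_value, mask_value)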
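
For a quick sanity check, the same rule can be wrapped in a function with a fixed chunk size and shift so the resulting pattern is deterministic. This is only a sketch: the helper name `chunkwise_causal_mask` and the toy sizes are made up for illustration and are not part of the MoonCast code, and passing the result to `F.scaled_dot_product_attention` as an additive float mask is an assumption about how one might consume it.

```python
import torch
import torch.nn.functional as F


def chunkwise_causal_mask(seq_len: int, chunk: int, block_shift: int = 0) -> torch.Tensor:
    """Hypothetical helper: additive mask of shape (2*seq_len, 2*seq_len) for x = [x0, xt]."""
    attend_value, mask_value = 0.0, float("-inf")
    block = (torch.arange(seq_len) + block_shift) // chunk  # chunk id per position
    q = block.unsqueeze(1)  # chunk ids of the queries (rows)
    k = block.unsqueeze(0)  # chunk ids of the keys (columns)

    allowed = torch.zeros(2 * seq_len, 2 * seq_len, dtype=torch.bool)
    allowed[:seq_len, :seq_len] = k <= q  # x0 -> x0: current and earlier chunks
    allowed[seq_len:, :seq_len] = k < q   # xt -> x0: strictly earlier chunks
    allowed[seq_len:, seq_len:] = k == q  # xt -> xt: same chunk only
    # x0 -> xt stays False, so x0 never attends to xt.

    mask = torch.full((2 * seq_len, 2 * seq_len), mask_value)
    mask[allowed] = attend_value
    return mask


# Toy case: 4 positions split into two chunks of 2, no shift.
mask = chunkwise_causal_mask(seq_len=4, chunk=2)
print((mask == 0).int())
# tensor([[1, 1, 0, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0, 0, 0],
#         [0, 0, 0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0, 1, 1],
#         [1, 1, 0, 0, 0, 0, 1, 1]], dtype=torch.int32)

# Assumed usage: the float mask is added to the attention scores.
x = torch.randn(1, 1, 8, 16)  # (batch, heads, 2 * seq_len, head_dim), toy values
out = F.scaled_dot_product_attention(x, x, x, attn_mask=mask)
```

Every row keeps at least one attended position (each xt position can always see its own chunk), so the softmax never runs over a row that is entirely `-inf`.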

jzq2000 avatar Mar 31 '25 05:03 jzq2000

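Hello. Now that I have gone through the inference code, could the training code for the chunk-wise autoregressive speech detokenizer be open-sourced?

If not, could you recommend an open-source training project that I could use as a reference?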

zyy-fc avatar Apr 09 '25 03:04 zyy-fc