zero_nlp
zero_nlp copied to clipboard
get_masks_and_position_ids 问题请教
请问下,get_masks_and_position_ids 这个函数里面, attention_mask = torch.ones((1, context_length, context_length), device=device)
为什么 mask的shape 是 (1,context_length, context_length)呢? 之前只用过bert,attention mask 都是tokenizer自己返回的。 这里的实现原理从哪可以学习下嘛?
感谢~