Geek_Huang
So why do we only need the last token instead of 2 or more, since there are still a lot of tokens in the last row that have not been predicted yet? And I think...
Thank you for your kind response! Still one question:

Code:
# forward input preparation
new_input_ids.append(input_ids[:, idx_in_input_ids].unsqueeze(-1))
new_position_ids.append(global_idx)
local_position_ids.append(idx_in_input_ids)
new_input_ids = torch.cat(new_input_ids, dim=1)
new_position_ids = torch.tensor(new_position_ids, device=input_ids.device).unsqueeze_(0)
local_position_ids = torch.tensor(local_position_ids, device=input_ids.device).unsqueeze_(0)
...
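For reference, here is a minimal, self-contained sketch (not the repository's code) of the shape bookkeeping the snippet above performs; the selected positions and the loop that produces global_idx / idx_in_input_ids are hypothetical stand-ins for whatever the real preparation logic computes:

import torch

# toy inputs: batch size 1, sequence length 8
input_ids = torch.randint(0, 1000, (1, 8))

new_input_ids, new_position_ids, local_position_ids = [], [], []
# hypothetical selected positions; the real loop derives these differently
for global_idx, idx_in_input_ids in enumerate([2, 5, 7]):
    # take one token column at a time, keeping the batch dimension -> shape (1, 1)
    new_input_ids.append(input_ids[:, idx_in_input_ids].unsqueeze(-1))
    new_position_ids.append(global_idx)
    local_position_ids.append(idx_in_input_ids)

new_input_ids = torch.cat(new_input_ids, dim=1)  # (1, num_selected)
new_position_ids = torch.tensor(new_position_ids, device=input_ids.device).unsqueeze_(0)      # (1, num_selected)
local_position_ids = torch.tensor(local_position_ids, device=input_ids.device).unsqueeze_(0)  # (1, num_selected)

print(new_input_ids.shape, new_position_ids.shape, local_position_ids.shape)
# torch.Size([1, 3]) torch.Size([1, 3]) torch.Size([1, 3])

Note that both position tensors get a fixed batch dimension of 1 via unsqueeze_(0), which is presumably where the question below about running with bs other than 1 comes from.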
Did you try this afterwards? Does it run correctly when bs is not 1?