LongWriter
LongWriter copied to clipboard
eos_indices = input_ids.argmin(dim=1) - 1
# In script/main.py:
class DataCollatorForLMDataset(object):
    """Collate equal-length, pre-tokenized LM samples and trim shared trailing padding.

    Each instance is a dict with ``input_ids`` and ``labels`` tensors of the same
    shape. The batch is stacked along a new leading dimension and then cut right
    after the furthest per-row end position, so columns past every row's end are
    dropped.

    Args:
        eos_token_id: Optional end-of-sequence token id. When given, a row's end
            is the first position equal to this id; rows that contain no such
            token keep their full length. When ``None`` (the default, and the
            original behavior), a row's end is one position before the first
            occurrence of the row's smallest token id. That heuristic assumes
            right padding with the smallest id directly after EOS, and
            mislocates the end of rows that have no padding.
            NOTE(review): confirm the pad/EOS ids used by the pre-tokenization
            script.
    """

    def __init__(self, eos_token_id: "int | None" = None):
        self.eos_token_id = eos_token_id

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # torch.stack is equivalent to unsqueeze(0) on each tensor + cat(dim=0).
        input_ids = torch.stack([instance["input_ids"] for instance in instances])
        labels = torch.stack([instance["labels"] for instance in instances])

        eos_indices = self._eos_indices(input_ids)
        max_position = int(eos_indices.max())
        if max_position < 0:
            # Only reachable on the heuristic path, when argmin is 0 in every row.
            # Slicing to max_position + 1 would leave an empty batch, so keep it whole.
            return dict(input_ids=input_ids, labels=labels)
        return dict(
            input_ids=input_ids[:, : max_position + 1],
            labels=labels[:, : max_position + 1],
        )

    def _eos_indices(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return, for each row, the index of the last position to keep."""
        # getattr keeps objects built without __init__ (e.g. old pickles) working.
        eos_token_id = getattr(self, "eos_token_id", None)
        if eos_token_id is None:
            # Legacy heuristic: the first occurrence of the row minimum is taken
            # as the first pad token, so the EOS sits one position earlier.
            return input_ids.argmin(dim=1) - 1
        is_eos = input_ids == eos_token_id
        first_eos = is_eos.int().argmax(dim=1)
        # argmax returns 0 when a row has no EOS at all; such rows keep full length.
        last_index = torch.full_like(first_eos, input_ids.size(1) - 1)
        return torch.where(is_eos.any(dim=1), first_eos, last_index)
这里为什么使用 `eos_indices = input_ids.argmin(dim=1) - 1`，
而在 sort_and_group.py 中使用的是 `eos_indice = (input_id == EOS_ID).int().argmax().item()`？
在“右侧用最小的 token id（如 0）填充、且 EOS 紧挨在填充之前”的前提下，这两行的作用相同；但 argmin 写法在某一行没有填充时会定位错误，因此可以统一为 `eos_indice = (input_id == EOS_ID).int().argmax().item()`。
谢谢提醒，我这两天会更新代码。