
eos_indices = input_ids.argmin(dim=1) - 1

Open xpdd123 opened this issue 1 year ago • 2 comments

In script/main.py, class DataCollatorForLMDataset(object):

def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
    input_ids, labels = tuple([instance[key].unsqueeze(0) for instance in instances] for key in ("input_ids", "labels"))
    input_ids = torch.cat(input_ids, dim=0)
    labels = torch.cat(labels, dim=0)
    eos_indices = input_ids.argmin(dim=1) - 1
    max_position = eos_indices.max()
    if max_position < 0:
        return dict(
            input_ids=input_ids,
            labels=labels
        )
    return dict(
        input_ids=input_ids[:, :max_position+1],
        labels=labels[:, :max_position+1]
    )

Why does this code use "eos_indices = input_ids.argmin(dim=1) - 1",

while sort_and_group.py uses eos_indice = (input_id == EOS_ID).int().argmax().item()?

xpdd123 · Aug 29 '24 08:08

These two lines do the same thing; they can be unified as eos_indice = (input_id == EOS_ID).int().argmax().item()

bys0318 · Aug 29 '24 14:08
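For readers comparing the two expressions, here is a minimal sketch that mirrors both index computations on plain Python lists instead of tensors. The values of `PAD_ID` and `EOS_ID` and the helper names are hypothetical, chosen only for illustration. It assumes the padding ID is the smallest ID in the vocabulary and that EOS sits immediately before the first padding token. Under those assumptions the two forms agree on a padded row, but they can differ on a row with no padding.

```python
# Hypothetical token IDs for illustration; real values depend on the tokenizer.
PAD_ID = 0  # assumed to be the smallest ID in the vocabulary
EOS_ID = 2  # assumed end-of-sequence ID


def eos_via_argmin(row):
    """Mirror `input_ids.argmin(dim=1) - 1` for a single row."""
    # Index of the first occurrence of the minimum value in the row.
    first_min = min(range(len(row)), key=row.__getitem__)
    return first_min - 1


def eos_via_match(row):
    """Mirror `(input_id == EOS_ID).int().argmax().item()` for a single row."""
    mask = [int(tok == EOS_ID) for tok in row]
    # Index of the first occurrence of the maximum of the 0/1 mask.
    return mask.index(max(mask))


padded = [5, 9, 7, EOS_ID, PAD_ID, PAD_ID]  # EOS followed by padding
unpadded = [5, 9, 3, 7, 8, EOS_ID]          # row fills the whole length

padded_result = (eos_via_argmin(padded), eos_via_match(padded))
unpadded_result = (eos_via_argmin(unpadded), eos_via_match(unpadded))

print(padded_result)
print(unpadded_result)
```

On the padded row both functions return the same position. On the unpadded row the minimum-based form lands one position before EOS, which is one reason the explicit `== EOS_ID` match is easier to reason about.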

Thanks for the reminder, I will update the code in the next couple of days.

bys0318 · Aug 29 '24 14:08