NEKO
Possible use of padding instead of max tokens to avoid a shape-mismatch error when calculating the loss.
TODO: Add details. This is just a rough draft of something Henry and I were talking about over a screen share.
import torch.nn.functional as F

def pad(predicted, target):
    # Zero-pad the shorter 1-D tensor on the right so both tensors share
    # the same length; returns (target, predicted) in both branches.
    if len(target) > len(predicted):
        return target, F.pad(predicted, (0, len(target) - len(predicted)), 'constant', 0)
    else:
        return F.pad(target, (0, len(predicted) - len(target)), 'constant', 0), predicted
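
A minimal usage sketch, assuming 1-D float tensors and an MSE loss; the tensor values and the choice of loss function here are illustrative assumptions, not NEKO's actual training setup:

import torch
import torch.nn.functional as F

predicted = torch.tensor([0.2, 0.9, 0.1, 0.4, 0.7])  # length 5
target = torch.tensor([0.0, 1.0, 0.0])               # length 3

# Note the return order: (target, predicted).
target, predicted = pad(predicted, target)
loss = F.mse_loss(predicted, target)  # both tensors are now length 5

One caveat worth folding into the details TODO: the zero-padded positions still contribute to the loss, so masking them out may eventually be needed.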