Lookback-Lens
The use of teacher_forcing_seq
Dear authors,
Thank you for your work; I really liked it and have started digging into the code. However, it is unclear to me where "teacher_forcing_seq" is used and why.
I can see that it is an argument for generation in the vanilla case and that it is only loaded when the argument is passed, but the README never passes this argument, which is a bit confusing.
It is also unclear to me what it does, and whether I really need to understand it :)
if teacher_forcing_seq is None:
    # greedy decoding: pick the highest-scoring token (argmax)
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)
else:
    if new_token_idx >= teacher_forcing_seq.shape[1]:
        # forced sequence exhausted: use the EOS token (or PAD if there is no EOS) as next token
        next_tokens = torch.tensor(
            [eos_token_id[0] if eos_token_id is not None else pad_token_id] * input_ids.shape[0],
        )
        new_token_idx += 1
        this_peer_finished = True
    else:
        # feed the given token at position new_token_idx instead of the model's own prediction
        next_tokens = teacher_forcing_seq[:, new_token_idx]
        new_token_idx += 1
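To check my understanding, here is a minimal, framework-free sketch of what I think the branch above does, under my own assumptions. The function and the toy scorer (`greedy_or_forced_decode`, `toy_scores`) are hypothetical and are not the repo's API. My reading is that the model still runs a forward pass at every step (so attention maps could be recorded), but the emitted token comes from the given sequence instead of the argmax; once that sequence runs out, EOS is emitted and decoding stops. If that is right, my guess is that it is meant for replaying existing (for example, annotated) responses through the model rather than for normal generation, but please correct me.

```python
from typing import Callable, List, Optional, Sequence


def greedy_or_forced_decode(
    score_fn: Callable[[List[int]], Sequence[float]],
    prompt: List[int],
    eos_token_id: int,
    max_new_tokens: int,
    teacher_forcing_seq: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical sketch: greedy decoding, optionally forced to follow a given sequence."""
    ids = list(prompt)
    new_token_idx = 0
    for _ in range(max_new_tokens):
        # "forward pass": the scores are computed in both modes
        scores = score_fn(ids)
        if teacher_forcing_seq is None:
            # greedy: take the highest-scoring token
            next_token = max(range(len(scores)), key=scores.__getitem__)
        elif new_token_idx >= len(teacher_forcing_seq):
            # forced sequence exhausted: emit EOS, which ends decoding below
            next_token = eos_token_id
        else:
            # ignore the scores and feed the given token
            next_token = teacher_forcing_seq[new_token_idx]
        new_token_idx += 1
        ids.append(next_token)
        if next_token == eos_token_id:
            break
    return ids[len(prompt):]


def toy_scores(ids: List[int]) -> List[float]:
    """Hypothetical 5-token 'model' that always prefers (last token + 1) % 5."""
    preferred = (ids[-1] + 1) % 5
    return [1.0 if t == preferred else 0.0 for t in range(5)]


print(greedy_or_forced_decode(toy_scores, [0], eos_token_id=4, max_new_tokens=10))
# [1, 2, 3, 4]
print(greedy_or_forced_decode(toy_scores, [0], 4, 10, teacher_forcing_seq=[2, 2, 1]))
# [2, 2, 1, 4]
```

In the forced run, the toy model's preferences are ignored, and EOS is appended only after the forced tokens are used up, which mirrors the `new_token_idx >= teacher_forcing_seq.shape[1]` branch.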
Thank you for your help!