
The use of teacher_forcing_seq

VityaVitalich opened this issue 5 months ago · 1 comment

Dear authors,

Thank you for your work — I really liked it and have started delving into the code. However, it is unclear to me where `teacher_forcing_seq` is used and why.

I can see that it is an argument to generation in the vanilla case, and that it is only loaded when the argument is passed, but the README never passes it, which is a bit confusing.

It is also a bit unclear what it does and whether I really need to understand it :)

            if teacher_forcing_seq is None:
                # argmax
                next_tokens = torch.argmax(next_tokens_scores, dim=-1)
            else:
                if new_token_idx >= teacher_forcing_seq.shape[1]:
                    # use eos token as next token
                    next_tokens = torch.tensor(
                        [eos_token_id[0] if eos_token_id is not None else pad_token_id] * input_ids.shape[0],
                    )
                    new_token_idx += 1
                    this_peer_finished = True
                else:
                    next_tokens = teacher_forcing_seq[:, new_token_idx]
                    new_token_idx += 1
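For context, my current understanding of the branch above is: without `teacher_forcing_seq`, the loop does ordinary greedy (argmax) decoding; with it, the model's own prediction is ignored and the next token is replayed from the supplied sequence, with EOS emitted once that sequence runs out. A minimal sketch of that control flow (toy logits and hypothetical names, not the repo's actual API):

```python
def greedy_or_forced(logits_per_step, teacher_forcing_seq=None, eos_token_id=0):
    """Pick next tokens step by step: argmax normally, or tokens from
    teacher_forcing_seq when one is supplied; emit EOS once it runs out."""
    out = []
    new_token_idx = 0
    for logits in logits_per_step:
        if teacher_forcing_seq is None:
            # vanilla case: greedy decoding from the model's scores
            next_token = max(range(len(logits)), key=lambda i: logits[i])
        elif new_token_idx >= len(teacher_forcing_seq):
            # forced sequence exhausted: emit EOS and stop
            out.append(eos_token_id)
            break
        else:
            # ignore the model's prediction; replay the reference token
            next_token = teacher_forcing_seq[new_token_idx]
            new_token_idx += 1
        out.append(next_token)
    return out

logits = [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]]
print(greedy_or_forced(logits))                              # greedy: [1, 0, 2]
print(greedy_or_forced(logits, teacher_forcing_seq=[2, 2]))  # forced: [2, 2, 0]
```

Is that reading correct — i.e., is it just a way to replay a fixed continuation through the model so the attention maps can be collected for it?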

Thank you for your help!

VityaVitalich avatar Sep 06 '24 10:09 VityaVitalich