minibatching changes and masking

Open · raj47212 opened this issue 2 years ago · 1 comment

I have a question about the recent minibatching changes from #153 in batched_forward_pass(), specifically how the masks are computed for encoder-decoder models. An extra element appears to be removed from the mask, and I am trying to understand why.

Suppose a BART model generates a response whose tensor has length 42, starting with token 2 (</s>) and ending with token 2. prepare_model_inputs() creates decoder_input_ids of shape (b, 42) and a decoder_attention_mask of all ones. The masks then have their first element set to zero (see the comment at line 622), and the returned values drop the last element, so they effectively have shape (b, 41).
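A minimal plain-Python sketch of the masking steps as described above, for a single sequence. The helper name build_response_mask is hypothetical; it only mirrors the two steps the question describes (zero out position 0, then drop the last element) and is not trl's actual code.

```python
def build_response_mask(decoder_attention_mask):
    """Hypothetical sketch of the masking described in the question.

    Starts from the decoder attention mask, zeroes position 0,
    then drops the last position (the [:-1] slice on the returned values).
    """
    masks = list(decoder_attention_mask)
    masks[0] = 0  # per the question, this masks out the leading </s>
    return masks[:-1]  # drop the last position, as the returned values do


# A 42-token BART response: </s> (id 2), 40 content tokens, </s> (id 2).
response = [2] + list(range(100, 140)) + [2]
attention = [1] * len(response)  # all-ones decoder_attention_mask
mask = build_response_mask(attention)

print(len(response), len(mask))  # 42 41
print(mask[0], mask[-1])  # 0 1
```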

So the leading 2 is masked out (mask value 0 at position 0), and the trailing 2 is dropped entirely (because we return [:-1]). This also affects compute_rewards(), which relies on last_non_masked_index. Is this by design?
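To see where the score lands, here is a simplified, hypothetical version of the reward computation the question refers to: a per-token KL penalty plus the scalar score added at last_non_masked_index. The function name compute_rewards_sketch and the kl_coef value of 0.2 are assumptions for illustration, not trl's API, and this sketch does not zero out rewards at masked positions.

```python
def compute_rewards_sketch(score, logprobs, ref_logprobs, mask, kl_coef=0.2):
    """Hypothetical sketch: KL penalty per token, score added at the last kept token.

    kl_coef=0.2 is an assumed value for illustration only.
    """
    # Per-token penalty: -kl_coef * (logprob - ref_logprob)
    rewards = [-kl_coef * (lp - rlp) for lp, rlp in zip(logprobs, ref_logprobs)]
    # Index of the last position the mask keeps
    last_non_masked_index = max(i for i, m in enumerate(mask) if m)
    rewards[last_non_masked_index] += score
    return rewards, last_non_masked_index


mask = [0] + [1] * 40  # the 41-element mask from the BART example
logprobs = [-1.0] * 41
ref_logprobs = [-1.0] * 41  # identical policies, so the KL term is zero
rewards, idx = compute_rewards_sketch(1.5, logprobs, ref_logprobs, mask)

print(idx, rewards[idx])  # 40 1.5
```

Because the mask has already been cut to 41 elements, last_non_masked_index can be at most 40; the score is attached to that position rather than to a 42nd one.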

raj47212, Feb 25 '23 03:02

When we do the forward pass, we actually predict one more token than we generated. For example, when inputting 3 tokens, the model also predicts a 4th token, which is not needed since we only want to evaluate the 3 generated ones. Does that make sense?

lvwerra, Mar 02 '23 13:03
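The reply above is about next-token alignment: the prediction at position t is for token t + 1, so n input tokens produce n predictions, and the last one looks past the end of the sequence. A minimal plain-Python sketch, assuming toy probability tables in place of real model logits; logprobs_from_predictions is a hypothetical helper, not a trl function.

```python
import math


def logprobs_from_predictions(predictions, tokens):
    """Score tokens[1:] with predictions[:-1] (next-token alignment).

    predictions[t] maps token -> probability for position t + 1.
    The final prediction is for a token that was never generated, so it is dropped.
    """
    return [math.log(pred[tok]) for pred, tok in zip(predictions[:-1], tokens[1:])]


tokens = [2, 7, 9]  # 3 input tokens
# One prediction per input position (hypothetical toy distributions)
predictions = [
    {7: 0.5, 9: 0.5},    # after token 0, predicts token 1
    {7: 0.25, 9: 0.75},  # after token 1, predicts token 2
    {2: 1.0},            # after token 2, predicts a 4th token: not needed
]
lps = logprobs_from_predictions(predictions, tokens)

print(len(predictions), len(lps))  # 3 2
```

In the BART example, 42 decoder positions give 42 predictions; dropping the last one leaves 41, which matches the (b, 41) shape in the question.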