[BUG] CausalLanguageModeling does not mask the last input item
Bug description
The CLM masking with train-on-last-item-only does not mask the last item of the input sequence.
As a result, the model uses the embedding of the label instead of the mask embedding.
I think the following code needs to be fixed: https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/348c9636399535c566d20e8ebff2b7aa0775f136/transformers4rec/torch/masking.py#L298
Steps/Code to reproduce bug
import torch
from transformers4rec.torch import masking
item_ids = torch.tensor([[1, 2, 0], ])
mask = masking.CausalLanguageModeling(hidden_size=10, train_on_last_item_seq_only=True)
masking_info = mask.compute_masked_targets(item_ids, training=True)
print(masking_info)
Actual output:
MaskingInfo(schema=tensor([[ True, True, False]]), targets=tensor([[2, 0, 0]]))
Expected behavior
MaskingInfo(schema=tensor([[ True, False, False]]), targets=tensor([[2, 0, 0]]))
Environment details
- Transformers4Rec version: 23.08.00
Additional context
I think this line of code needs to be removed:
https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/d0cce61b988a1e923545f94ac22499fb57928d18/transformers4rec/torch/masking.py#L298
As a solution, just use the mask_label from predict_all() above.
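A minimal sketch of what that would look like, re-deriving the shift-and-mask logic of predict_all() in isolation (variable names and the standalone structure are illustrative, not the library's actual code):

```python
import torch

padding_idx = 0
item_ids = torch.tensor([[1, 2, 0]])

# Shift the item ids left by one to build the targets, padding the tail,
# as predict_all() does.
labels = torch.cat(
    [item_ids[:, 1:], torch.full((item_ids.size(0), 1), padding_idx)], dim=1
)

# Keep only the last non-padded target per row (train_on_last_item_seq_only);
# assumes each row has at least one real target.
rows = torch.arange(labels.size(0))
last = (labels != padding_idx).sum(dim=1) - 1
trg_last = torch.zeros_like(labels)
trg_last[rows, last] = labels[rows, last]

# Deriving the schema from the targets yields the expected mask, without
# the extra `item_ids != padding_idx` step on line 298.
mask_labels = trg_last != padding_idx
print(mask_labels)  # tensor([[ True, False, False]])
print(trg_last)     # tensor([[2, 0, 0]])
```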
And I think the reason the current code somehow works is this part:
https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/d0cce61b988a1e923545f94ac22499fb57928d18/transformers4rec/torch/masking.py#L318-L337
Given an input sequence without padding, [1,2,3], the mask schema generated by the current code during evaluation will be [True, True, True], which exposes the last item. However, apply_mask_to_inputs replaces the last item with a 0 embedding, and since the schema is all True, no mask embedding is applied to the input. In this case the 0 embedding effectively plays the role of the mask.
However, when the input has padding, e.g. [1,2,3,0,0], the current mask schema will be [True, True, True, False, False]. Because the last item is a padding item, apply_mask_to_inputs basically replaces that padding with a 0 embedding. The mask schema then masks the last two padding items, keeping 1, 2, 3 fully visible to the transformer.
I think that's why people encounter issues testing CLM: if the input data always contains padding, the evaluation metrics will be unrealistically high.
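A toy illustration of that interplay (the hidden size, embedding values, and mask constant are made up for the example; this is not the library's code):

```python
import torch

hidden = 4
# One embedding vector per position for the padded sequence [1, 2, 3, 0, 0].
inputs = torch.ones(1, 5, hidden)
schema = torch.tensor([[True, True, True, False, False]])

# Like apply_mask_to_inputs: drop the last position's embedding, append zeros.
inputs = torch.cat([inputs[:, :-1], torch.zeros(1, 1, hidden)], dim=1)

# Positions where the schema is False receive the mask embedding (a constant
# 9.0 here). Only the padded tail gets masked; items 1..3 remain visible.
mask_emb = torch.full((hidden,), 9.0)
out = torch.where(schema.unsqueeze(-1), inputs, mask_emb)
print(out[0, :, 0])  # tensor([1., 1., 1., 9., 9.])
```

Note that only padding positions end up masked, so the transformer still sees the full [1, 2, 3] prefix, including the item it is supposed to predict.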
I also noticed this bug. After the fix, recall drops by about 20%.
Any further updates? It seems #723 still does not solve this bug.