[In progress] Add warning for padding without attention mask
What does this PR do?
Fixes #16136
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
cc @gante
Thank you for the comment!
Based on my understanding, this line of code runs the check only once during the forward pass, so it should not significantly impact performance.
The current warning method only issues warnings when an attention_mask is necessary (due to the presence of padding tokens in the input) but none is provided. In other cases, where an attention_mask is not required, no warning is issued. The additional check on special tokens allows for a more detailed warning message.
I agree that your suggested method is more concise and efficient, but it may generate warnings even when an attention_mask is not needed.
Since this is my first time contributing to the community, I don't have a strong opinion on either solution. The original work is by @ydshieh and @patrickvonplaten; perhaps they have additional insights and can suggest a more effective solution.
Why not a simple logger.warning_once()?
This was recently introduced :-)
@anruijian It checks input_ids until it finds a batch in which a pad_token_id exists. If a user is working on a problem where there is no pad_token_id in their data and they don't pass the attention_mask, a check is made on every forward pass. I'd strongly advocate for a simple warning when the attention_mask is not passed 🤗
As a side note, we have related problems at other points in the code base. Getting into the habit of passing the attention_mask would really make everyone happier!
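For context, here is a minimal sketch contrasting the two approaches being discussed. It is illustrative only: the helper name, the _warned flag, and the config access are assumptions, not the actual PR code.

from transformers.utils import logging

logger = logging.get_logger(__name__)

# Approach in the current PR: scan each batch for pad tokens until one is
# found, and only then warn. The scan costs a tensor comparison on every
# forward pass for data that never contains pad_token_id.
def _warn_if_pad_token_and_no_attention_mask(self, input_ids, attention_mask):
    if attention_mask is not None or getattr(self, "_warned", False):
        return
    pad_token_id = self.config.pad_token_id
    if pad_token_id is not None and (input_ids == pad_token_id).any():
        logger.warning(
            "Padding tokens were found in `input_ids`, but no `attention_mask` was passed."
        )
        self._warned = True

# Suggested alternative: a one-time warning, with no per-batch tensor scan.
def forward(self, input_ids, attention_mask=None):
    if attention_mask is None:
        logger.warning_once(
            "We strongly recommend passing an `attention_mask` to avoid possibly"
            " incorrectly computing the attention weights."
        )
    ...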
@gante Just to confirm before updating the PR: we are going to remove the warn_if_pad_token_in_input_ids_no_attention_mask method and use logger.warning_once in forward():
def forward(...):
    ...
    if attention_mask is None:
        logger.warning_once(
            "\nWe strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the"
            " attention weights."
        )
    ...
@anruijian correct :) I would add a short example in the warning, such as "(e.g. to correctly mask the pad tokens)", but I'll leave that up to you!
@gante
def forward(...):
    ...
    if attention_mask is None:
        logger.warning_once(
            "\nWe strongly recommend passing an `attention_mask` to avoid possibly incorrectly computing the"
            " attention weights. Example to correctly mask the pad tokens: model(input_ids, attention_mask=attention_mask)."
            " See https://huggingface.co/docs/transformers/v4.23.1/en/troubleshooting#incorrect-output-when-padding-tokens-arent-masked for more details."
        )
    ...
Does this example look good to you? I also linked the official doc on the issue. I'm not sure if it's too long; let me know what you think. Thanks!
@anruijian sounds good to me! (A minor nit: the link is for v4.23 of the docs, should be https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked instead)
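For anyone landing here from the warning itself, a minimal usage sketch of the recommended fix (the gpt2 checkpoint and the pad-token workaround are illustrative choices, not part of this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Batching sentences of different lengths forces padding; the tokenizer
# returns the matching attention_mask, so pass both to the model.
inputs = tokenizer(["short", "a much longer sentence"], padding=True, return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])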
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.