Add sliding window attention to sdpa in mistral
Feature request
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L1006-L1023
In the code linked above, the latest version of transformers cannot use the sliding window feature in the Mistral model's SDPA path.
I suspect the reason is the one mentioned in this comment:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L687-L688
This PyTorch issue causes problems when a custom attn_mask, such as a sliding window attention mask, is used:
https://github.com/pytorch/pytorch/issues/112577
This issue has been fixed in torch 2.2.0, which was released two weeks ago, so could you add this feature back to the SDPA kernel in Mistral?
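For illustration, here is a minimal, self-contained sketch (not code from the Mistral model) of the kind of call the SDPA path needs to support: torch.nn.functional.scaled_dot_product_attention with an explicit sliding-window mask instead of is_causal=True. The shapes and window size below are made up.

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim, window = 1, 4, 16, 8, 4
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Sliding-window causal pattern: position i may attend to j only if
# i - window < j <= i. Disallowed positions get -inf in an additive mask.
idx = torch.arange(seq_len)
allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
attn_mask = torch.zeros(seq_len, seq_len).masked_fill(~allowed, float("-inf"))

# Passing attn_mask (rather than is_causal=True) is what hits the
# pytorch/pytorch#112577 problem on older torch releases.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 4, 16, 8])
```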
Motivation
I cannot use sliding window attention with SDPA right now: my GPU is a V100, so I cannot use FlashAttention-2.
Your contribution
I think we can pass the sliding_window parameter to the _prepare_4d_causal_attention_mask_for_sdpa function.
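Something like the following sketch, assuming the current signature of that internal helper (the tensor shapes and the window size are only illustrative, not the actual Mistral code):

```python
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch_size, seq_length, hidden_size = 2, 16, 32
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask[0, :2] = 0  # simulate left padding so an explicit 4D mask is returned
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)

# The helper already accepts sliding_window; the proposal is for the Mistral
# SDPA path to forward config.sliding_window here.
mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length=0,
    sliding_window=8,
)
print(mask_4d.shape)  # (batch, 1, q_len, kv_len)
```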
cc @fxmarty
Hi, thank you for the suggestion. SDPA support for Mistral was added by @ArthurZucker in https://github.com/huggingface/transformers/pull/28133, so maybe he has more insight.
I think it comes down to just adding sliding_window to the call to _prepare_4d_causal_attention_mask_for_sdpa.
Yes. Would you like to open a PR?
Sure, and I'll open a PR later this week.
Any plan for the PR?
#29407 should fix this issue
@ArthurZucker Oh you are right. Thanks.
Fixed in https://github.com/huggingface/transformers/pull/30127