
Llama: fix custom 4D masks, v2

Open · poedator opened this issue 4 months ago · 9 comments

This is an attempt to rebase #29930, started initially by @gante.

Fixes the issue raised by @poedator in https://github.com/huggingface/transformers/pull/29753#issuecomment-2014530814.

The causal mask is now of shape [..., seq_len, full_len], as opposed to [..., full_len, full_len]. This means a custom 4D attention mask now serves as the whole causal mask, so we don't need a sliced copy -- we can copy the whole thing :)

This PR also expands support for custom 4D attention masks: we can pass either the full mask ([..., full_len, full_len]) or the partial mask ([..., seq_len, full_len]).
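For reference, a minimal sketch of the two accepted layouts (illustrative sizes; this assumes a float additive mask where 0 keeps a position and a large negative value blocks it):

```python
import torch

batch_size, seq_len, full_len = 1, 3, 8            # illustrative sizes
min_value = torch.finfo(torch.float32).min

# Full mask: one row per position in the whole sequence, [..., full_len, full_len]
full_mask = torch.triu(
    torch.full((batch_size, 1, full_len, full_len), min_value), diagonal=1
)

# Partial mask: one row per *new* query token only, [..., seq_len, full_len]
partial_mask = full_mask[:, :, -seq_len:, :]

# With this PR, either tensor should be accepted as `attention_mask`, e.g.:
# model(input_ids=new_ids, attention_mask=partial_mask, past_key_values=cache)
```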


As of 18.04.24 it is not passing the 4D mask tests because of the _ignore_causal_mask_sdpa() method, most recently edited in #30317 (merged today).

tests/models/llama/test_modeling_llama.py::Mask4DTestHard::test_partial_stacked_causal_mask - ValueError: Incorrect 4D attention_mask shape: (1, 1, 12, 12); expected: (1, 1, 9, 12). Apparently _ignore_causal_mask_sdpa() expects attention_mask.shape[-2] == query_length. This only holds if the input_ids are contiguous, which is not always the case in some intended 4D mask applications.

tests/models/llama/test_modeling_llama.py::Mask4DTestHard::test_stacked_causal_mask_static_cache - ValueError: Incorrect 4D attention_mask shape: (1, 1, 12, 16); expected: (1, 1, 12, 12). In this test, _ignore_causal_mask_sdpa() expects attention_mask.shape[-1] == key_value_length, which is derived from the past seen tokens. However, in the test I make this dimension equal to the static cache size, so that the mask always has the same shape and the whole graph can be compiled and reused.
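To make the shape mismatches concrete, a rough sketch using the sizes from the errors above (the variable names are mine, not from the library):

```python
import torch

# test_partial_stacked_causal_mask: the custom mask covers all 12 positions,
# but only 9 (non-contiguous) tokens are fed as input_ids.
attention_mask = torch.zeros(1, 1, 12, 12)   # what the test passes
query_length = 9                             # input_ids.shape[-1]
# _ignore_causal_mask_sdpa() expects attention_mask.shape[-2] == query_length
assert attention_mask.shape[-2] != query_length      # 12 != 9 -> ValueError

# test_stacked_causal_mask_static_cache: the mask's last dim is padded to the
# static cache size (16) so the compiled graph can be reused, while the check
# expects it to equal key_value_length (12 = past seen tokens + new tokens).
attention_mask = torch.zeros(1, 1, 12, 16)
key_value_length = 12
assert attention_mask.shape[-1] != key_value_length  # 16 != 12 -> ValueError
```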

I hesitate to edit _ignore_causal_mask_sdpa() because there may be broader context I'm not aware of.

Summoning @younesbelkada @ArthurZucker @gante to help

poedator · Apr 19 '24 15:04

As a solution, I added additional expected_shapes to _ignore_causal_mask_sdpa() and improved the StaticCache detection code. Note: it is inconvenient that StaticCache lives on the layer.self_attn objects while other Caches are model-level objects. Perhaps a model-level reference could avoid reaching into the layers.
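Roughly, the change amounts to checking the mask against several acceptable shapes instead of one. A hedged sketch of the idea only, not the actual diff (the function name and the static_cache_size parameter are mine):

```python
# Sketch of the idea only, not the exact diff; `static_cache_size` stands in
# for whatever max length the detected StaticCache reports.
def check_custom_4d_mask_shape(attention_mask, query_length, key_value_length,
                               static_cache_size=None):
    batch_and_heads = attention_mask.shape[:2]
    expected_shapes = [
        (*batch_and_heads, query_length, key_value_length),      # partial mask
        (*batch_and_heads, key_value_length, key_value_length),  # full mask
    ]
    if static_cache_size is not None:
        expected_shapes.append((*batch_and_heads, query_length, static_cache_size))
    if tuple(attention_mask.shape) not in expected_shapes:
        raise ValueError(
            f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; "
            f"expected one of: {expected_shapes}"
        )
```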

Please review soon - I need this for my paper code, and it has been broken for quite a while now.

The LONG tests look OK.

poedator · Apr 20 '24 23:04

All CI tests are green; SLOW tests were OK on my side yesterday.

poedator · Apr 23 '24 09:04

I noticed that Mistral's support for 4D masks stayed broken after these fixes, so I added similar lines to src/transformers/modeling_attn_mask_utils.py::_prepare_4d_causal_attention_mask_for_sdpa().
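For context, this is the user-facing path that was still broken: passing a custom 4D additive mask directly to a Mistral checkpoint running the SDPA attention implementation. A hedged usage sketch (the checkpoint name is just an example):

```python
import torch
from transformers import AutoModelForCausalLM

# Example checkpoint only; any Mistral-architecture model with SDPA applies.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", attn_implementation="sdpa"
)

input_ids = torch.tensor([[1, 2, 3, 4]])
full_len = input_ids.shape[-1]
min_value = torch.finfo(model.dtype).min

# Custom causal 4D mask of shape [batch, 1, seq_len, full_len]
mask_4d = torch.triu(torch.full((1, 1, full_len, full_len), min_value), diagonal=1)
mask_4d = mask_4d.to(model.dtype)

out = model(input_ids=input_ids, attention_mask=mask_4d)
```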

poedator · Apr 23 '24 12:04


I added Mask4DTestHard tests (without the static cache part) to tests/models/mistral/test_modeling_mistral.py to ensure that the 4D masks keep working in the models that use _prepare_4d_causal_attention_mask_for_sdpa(). These new tests would fail without the fixes from commit d488f35 just above. I ran the SLOW tests for ./tests/models/mistral/ on this branch - all fine.
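The new Mistral tests follow the same packing idea as the Llama ones. A simplified sketch of the mask construction (not the committed test code):

```python
import torch

# Several short "documents" are packed into one row, and the 4D mask keeps
# each one causal while blocking attention across documents.
doc_lengths = [3, 4, 5]
full_len = sum(doc_lengths)
min_value = torch.finfo(torch.float32).min

mask = torch.full((1, 1, full_len, full_len), min_value)
start = 0
for length in doc_lengths:
    block = torch.triu(torch.full((length, length), min_value), diagonal=1)
    mask[0, 0, start:start + length, start:start + length] = block
    start += length

# The test then checks that logits under `mask` match running each document
# through the model separately (within a tolerance).
```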

Is there anything left to do before the merge is possible? @gante @ArthurZucker

poedator · Apr 24 '24 08:04

Let's remove unrelated changes!

Sorry, but without these changes the fixes and tests will not work. I looked for related PRs; all I found was #30476, but it does not fix the relevant parts of the code.

poedator · Apr 26 '24 10:04

I tried to follow Arthur's advice to streamline the path for the 4D masks, and it seems to work. The relevant tests pass. @ArthurZucker @gante, please review.

poedator · Apr 29 '24 19:04

I combined the two very similar tests from common and added a tolerance - now Mixtral passes it OK. @ArthurZucker, @gante - please see if it is good to merge now.
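For reference, the tolerance-based comparison boils down to something like this (a sketch; the actual tolerance values in the merged test may differ):

```python
import torch

# Stand-in tensors; in the real test these are model logits from the two paths.
logits_packed = torch.randn(1, 12, 32000)
logits_reference = logits_packed + 1e-4 * torch.randn_like(logits_packed)

# Exact equality is too strict across architectures (e.g. Mixtral), so the
# combined test compares with a tolerance instead.
torch.testing.assert_close(logits_packed, logits_reference, atol=1e-3, rtol=1e-3)
```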

poedator · May 09 '24 22:05

Just waiting for the commits to be pushed so I can check the finished parts, then we can merge.

ArthurZucker · May 10 '24 11:05