Add torch.compile for Mistral
As suggested by the title, this PR adds torch.compile support for Mistral. It is not yet ready to merge; it replicates what has been done for Llama in Mistral, given the similar architecture.
- [ ] Add Sliding Window Cache
- [ ] Use Static Cache/Sliding Window Cache for torch.compile
- [ ] Move attention-mask-related logic inside `_update_causal_mask`
- [ ] Modify `prepare_inputs_for_generation` to make `generate` work
I am still not sure about the 4D mask logic inside `_update_causal_mask`, so please take a look at it.
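For reference, here is a minimal, self-contained sketch (not this PR's actual `_update_causal_mask` code) of what a sliding-window causal mask means: position `i` may attend to position `j` only if `j <= i` and `i - j < sliding_window`:

```python
import torch

def sliding_window_causal_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    # True where attention is allowed: causal AND within the sliding window.
    positions = torch.arange(seq_len)
    causal = positions[None, :] <= positions[:, None]                      # j <= i
    in_window = positions[:, None] - positions[None, :] < sliding_window   # i - j < window
    return causal & in_window

# Example: with seq_len=5 and sliding_window=3, token 4 attends to tokens 2, 3, 4 only.
print(sliding_window_causal_mask(5, 3))
```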
Also, the list of models that support static cache in the docs probably needs an update.
BTW, Mistral will need a `SlidingWindowCache` based on the implementation of RecurrentGemma!
From my understanding, `SlidingWindowCache` is for memory efficiency. Actually, I wonder how `SlidingWindowCache` would address the issue of `get_seq_length`, where the current solution relies on performing a non-zero check on all previous tokens. Another solution is to use `cache_position` so that we don't have to call `past_key_values.get_seq_length()`, but this would require that `cache_position` is always passed into `model.forward`, right?
Yes, we need to rely on cache positions that have to be passed to the model's forward. They can be initialized like in Llama. Note that it's not breaking for the dynamic cache because it will ignore the extra kwargs.
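As a rough illustration of that flow (a minimal sketch rather than this PR's actual code, with the model calls commented out), the caller initializes `cache_position` once and advances it on every decoding step, so the cache never has to infer the sequence length from the key/value tensors:

```python
import torch

# Hypothetical decoding loop showing how cache_position could be threaded through
# model.forward instead of calling past_key_values.get_seq_length().
input_ids = torch.tensor([[1, 2, 3, 4]])  # a prompt of length 4
past_seen_tokens = 0

# Prefill: positions 0..3 of the prompt.
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + input_ids.shape[1])
# outputs = model(input_ids, past_key_values=past_key_values, cache_position=cache_position)

# Decoding: one new token per step, so cache_position is a single, ever-advancing index.
past_seen_tokens += input_ids.shape[1]
cache_position = torch.tensor([past_seen_tokens])  # position 4 for the next token
# outputs = model(next_token, past_key_values=past_key_values, cache_position=cache_position)
```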
```python
def _update_cache(self, key_states, value_states, **cache_kwargs):
    """
    torch.compile compatible sliding window.
    Computes the `indices` based on `cache_position >= self.config.attention_window_size - 1`.
    The `to_shift` is only true once we are above attention_window_size. Thus with `attention_window_size==64`:

    indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
            37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
            55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

    We overwrite the cache using these, then we always write at cache_position (clamped to `attention_window_size`).
    """
    cache_position = cache_kwargs.get("cache_position")
    if cache_position.shape[0] > self.config.attention_window_size:
        # int indexing -> device sync? in compile, use tensor
        k_out = key_states[:, :, -self.config.attention_window_size :, :]
        v_out = value_states[:, :, -self.config.attention_window_size :, :]
    else:
        slicing = torch.ones(
            self.config.attention_window_size, dtype=torch.long, device=value_states.device
        ).cumsum(0)
        cache_position = cache_position.clamp(0, self.config.attention_window_size - 1)
        to_shift = cache_position >= self.config.attention_window_size - 1
        indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size

        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

    self.key_states, self.value_states = k_out, v_out
    return k_out, v_out
```
Updating the cache should look like the snippet above.
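To make the rolling-index trick concrete, here is a tiny standalone reproduction of just the `indices` computation (with a hypothetical `attention_window_size=8` and a decoding step that is already past the window):

```python
import torch

attention_window_size = 8
# Decoding step 10: the window is already full, so every slot has to shift by one.
cache_position = torch.tensor([10])

slicing = torch.ones(attention_window_size, dtype=torch.long).cumsum(0)  # [1, 2, ..., 8]
cache_position = cache_position.clamp(0, attention_window_size - 1)      # -> tensor([7])
to_shift = cache_position >= attention_window_size - 1                   # -> tensor([True])
indices = (slicing + to_shift[-1].int() - 1) % attention_window_size

print(indices)  # tensor([1, 2, 3, 4, 5, 6, 7, 0])
# The oldest entry is rotated into the last slot, where the new token then overwrites it.
```

Gathering the cached keys/values with `indices` and then writing at the clamped `cache_position` keeps all shapes static, which is what makes this compatible with `torch.compile`.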
Yeah, it's hard to debug. Merge from main and try to add a print to check where it's failing! It's most probably a `# Copied from` that is not well placed.
Hi Arthur, I have added support for the Sliding Window Cache. Please take a look at its implementation and also at the `_update_causal_mask` implementation; I have added my thoughts as comments.
Good work
- for all the `# Copied from` statements that were removed, we need to use one of the models as the new base (Mixtral, for example)
- as @gante said, let's add Phi to the list of slow-tested models
- let's revert some styling changes on `setup.py` and the run object detection example
Hi Arthur, I made some modifications, but I still can't quite work out what we should do to solve this copy-consistency issue. I have changed the base to Mixtral for most models that referred to Mistral, but there are still some cases where it doesn't quite fit, because Mixtral has its own MoE-related logic. I also see the `# Ignore copy` band-aid everywhere, so it becomes very unclear what exactly I need to do: should I add `# Ignore copy`, or should I modify the code so the sources match, and which source should be the standard if they disagree? `make fix-copies` does not help much either, because you can't use it without being really sure that you actually need exactly the same thing; otherwise it just breaks stuff!
Could you please explain more about the second point, about the Phi model? I don't quite get it.
Good work on adding support for compile! 🔥 All that is left to do is a benchmark to get the expected number of tokens per second on A100s! 🤗 cc @ydshieh this will be merged today / tomorrow!
Cool! We have 3 models to run the benchmark!
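Something along these lines should give a rough tokens-per-second number (a sketch only, not the internal benchmarking script; the checkpoint, prompt, and generation lengths are arbitrary choices):

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# Static cache + compiled forward for the decoding steps.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
model.generate(**inputs, max_new_tokens=32, do_sample=False)  # warmup / compilation

torch.cuda.synchronize()
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

new_tokens = out.shape[1] - inputs.input_ids.shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")
```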
@ArthurZucker Don't forget our `run-slow` feature 🙏
@zhenglongjiepheonix Could you push an empty commit with the message `[run-slow] mistral`? Thank you 🤗
The slow CI is failing on this. I think it's some annotation syntax not supported in Python 3.8? But it seems unrelated to my changes. @ArthurZucker @ydshieh
```
/usr/local/lib/python3.8/dist-packages/_pytest/config/__init__.py:331: PluggyTeardownRaisedWarning: A plugin raised an exception during an old-style hookwrapper teardown.
Plugin: helpconfig, Hook: pytest_cmdline_parse
ConftestImportFailure: RuntimeError: Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
unsupported operand type(s) for |: '_GenericAlias' and 'NoneType' (from /transformers/conftest.py)
For more information see https://pluggy.readthedocs.io/en/stable/api_reference.html#pluggy.PluggyTeardownRaisedWarning
  config = pluginmanager.hook.pytest_cmdline_parse(
ImportError while loading conftest '/transformers/conftest.py'.
conftest.py:26: in <module>
    from transformers.testing_utils import HfDoctestModule, HfDocTestParser
src/transformers/testing_utils.py:45: in <module>
    from .integrations import (
src/transformers/utils/import_utils.py:1520: in __getattr__
    module = self._get_module(self._class_to_module[name])
src/transformers/utils/import_utils.py:1532: in _get_module
    raise RuntimeError(
E   RuntimeError: Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
E   Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
E   Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
E   unsupported operand type(s) for |: '_GenericAlias' and 'NoneType'
```
Update: it works fine after I modified the annotations in `SlidingWindowCache`.
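For context, the `X | Y` union syntax in annotations (PEP 604) only works at runtime on Python 3.10+, so on Python 3.8 an annotation along these lines (the signature below is hypothetical, just to show the pattern) has to be written with `typing.Optional`:

```python
from typing import Any, Dict, Optional

# On Python 3.8 this raises at import time, because the annotation is evaluated eagerly:
#   def update(self, cache_kwargs: Dict[str, Any] | None = None): ...
#   TypeError: unsupported operand type(s) for |: '_GenericAlias' and 'NoneType'

# Python 3.8 compatible spelling of the same annotation:
def update(self, cache_kwargs: Optional[Dict[str, Any]] = None):
    ...
```

Alternatively, `from __future__ import annotations` defers the evaluation, but spelling it with `Optional` keeps the behaviour identical on all supported Python versions.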
Yeah, but the whole CI is not being run. We will check with @ArthurZucker tomorrow to see what happens.
Currently there are some issues related to the Mistral tests. Since my dev machine is an A100, I ran these tests on a Colab T4 using the current main branch, @ArthurZucker @ydshieh:

- `MistralIntegrationTest.test_model_7b_generation`: fails on the main branch on T4
- `MistralIntegrationTest.test_speculative_generation`: fails on the main branch on T4 and A100, related to #30778
- `MistralIntegrationTest.test_model_7b_logits`: fails on the main branch on T4 and A100
- `Mask4DTestHard`-related tests: go OOM on a single T4, from #30348
- the torch version the CI uses is < 2.3, which does not trigger the Static Cache related tests
Yes. For CI reports, you can check the internal channel https://huggingface.slack.com/archives/C06LR9PQA00/p1715686275921759; we have some failing tests, including the one you mentioned.
> Static Cache related tests

These will be run in another workflow file, https://github.com/huggingface/transformers/actions/workflows/push-important-models.yml, where we will (likely) switch to torch 2.3 tomorrow.
If, aside from the 4 mentioned failing tests, all others are passing with this PR, and `test_compile_static_cache` is passing on an A10 with torch 2.3, it's OK from my side (in terms of CI).
> If, aside from the 4 mentioned failing tests, all others are passing with this PR, and `test_compile_static_cache` is passing on an A10 with torch 2.3, it's OK from my side (in terms of CI).

I don't see it running on the A10-based workflow; is it because it does not trigger on PRs?
It's because the torch version there should use torch 2.3 but it is currently torch 2.2. I will update the docker files to use torch 2.3.
Make sure you rebase on `main`!
~~@ArthurZucker Rebase won't fix the issue (those are failing on `main` too)~~

(sorry, I am talking about https://github.com/huggingface/transformers/pull/30642#issuecomment-2111106008)
> unsupported operand type(s) for |: '_GenericAlias' and 'NoneType'

Python 3.8 is probably not fine with some typing that was added.
I think the A10 workflow only triggers on push to main, not on PRs? I mean, my commits on this PR are not triggering the test. Do I need to work directly on the transformers repo instead of my fork to run the test?
```yaml
name: Slow tests on important models (on Push - A10)

on:
  push:
    branches: [ main ]
```
> I think the A10 workflow only triggers on push to main, not on PRs?

Yes.
You could run the workflow manually: https://github.com/huggingface/transformers/actions/workflows/ssh-runner.yml. See DM on Slack.
OK, now the problem is that both Llama and Mistral are failing the compile static cache tests on the A10 because of the same error:
```
============================= FAILURES SHORT STACK =============================
________________ LlamaIntegrationTest.test_compile_static_cache ________________
msg = 'hasattr ConstDictVariable to'

    def unimplemented(msg: str) -> NoReturn:
        assert msg != os.environ.get("BREAK", False)
>       raise Unsupported(msg)
E       torch._dynamo.exc.Unsupported: hasattr ConstDictVariable to
E
E       from user code:
E          File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
E            return fn(*args, **kwargs)
E          File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 161, in new_forward
E            args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
E          File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 356, in pre_forward
E            return send_to_device(args, self.execution_device), send_to_device(
E          File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 148, in send_to_device
E            if is_torch_tensor(tensor) or hasattr(tensor, "to"):
E
E       Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E       You can suppress this exception and fall back to eager by setting:
E           import torch._dynamo
E           torch._dynamo.config.suppress_errors = True
```
They work fine on my A100 dev machine and in a Colab T4 environment, but this error seems unrelated to the GPU; it's more likely something in the software stack. I can't reproduce the error, even on an A10 AWS machine with python==3.8.
The interesting fact is that `test_compile_static_cache` passes when run on its own in CI; however, if it runs after `test_compile_sliding_window_cache`, it fails because some memory-related strategies in accelerate make dynamo unhappy. @ArthurZucker @ydshieh
Yep, accelerate does not support compile yet.
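One way to keep accelerate's device-placement hooks out of the compiled graph (a sketch of a possible workaround, not necessarily what the CI setup does) is to load the model without `device_map` and move it to the GPU manually, so no `_hf_hook.pre_forward` call ends up inside the dynamo trace:

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint

# device_map="auto" attaches accelerate hooks (module._hf_hook.pre_forward), which show up
# in the traced user code above and trigger the dynamo Unsupported error.
# Loading without device_map and calling .to("cuda") avoids attaching those hooks.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
```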
I merged from the current main and ran the slow tests again on my dev machine and an AWS A10 machine; I believe this PR is good to merge now. @ArthurZucker @ydshieh
Please merge this when you are available, @ArthurZucker, since I don't have write access to the library.