
Add torch.compile for Mistral

zhenglongjiepheonix opened this pull request 9 months ago • 8 comments

As suggested by the title, this PR attempts to add torch.compile support for Mistral. It is not yet ready to merge; it tries to replicate for Mistral what has already been done for Llama, given the similar architecture.

  • [ ] Add Sliding Window Cache
  • [ ] Use Static Cache/Sliding Window Cache for torch.compile
  • [ ] Move attention-mask-related logic inside _update_causal_mask
  • [ ] Modify _prepare_input_for_generation to make _generate work

I'm still not sure about the 4D mask logic inside _update_causal_mask; please take a look at it.
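For context, a minimal usage sketch of what this would enable, assuming the API ends up mirroring what Llama already supports (the checkpoint name and generation settings below are only illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoint; any Mistral causal LM should behave the same way.
    model_id = "mistralai/Mistral-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    # A static (fixed-shape) cache gives torch.compile stable shapes to specialize on,
    # which is the pattern Llama already supports.
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    inputs = tokenizer("My favourite condiment is", return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))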

zhenglongjiepheonix avatar May 03 '24 16:05 zhenglongjiepheonix

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Also, the list of models that support static cache in the docs probably needs an update.

ArthurZucker avatar May 06 '24 07:05 ArthurZucker

BTW Mistral will need a SlidingWindowCache based on the implementation in RecurrentGemma!

ArthurZucker avatar May 06 '24 15:05 ArthurZucker

BTW Mistral will need a SlidingWindowCache based on the implementation in RecurrentGemma!

From my understanding, SlidingWindowCache is for memory efficiency. Actually, I wonder how SlidingWindowCache would address the issue with get_seq_length, where the current solution relies on performing a non-zero check on all previous tokens. Another solution is to use cache_position so that we don't have to call past_key_values.get_seq_length(), but that would require cache_position to always be passed to model.forward, right?

zhenglongjiepheonix avatar May 07 '24 04:05 zhenglongjiepheonix

Yes, we need to rely on cache positions that have to be passed to the model's forward. They can be initialized like in Llama. Note that it's not breaking for the dynamic cache because it will ignore the extra kwargs.

    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """
        torch.compile compatible sliding window.
        Computes the `indices` based on `cache_position >= self.config.attention_window_size - 1`.
        The `to_shift` is only true once we are above attention_window_size. Thus with `attention_window_size==64`:

        indices = (slicing + to_shift[-1].int()-1) % self.config.attention_window_size
        tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
                37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
                55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

        We overwrite the cache using these, then we always write at cache_position (clamped to `attention_window_size`)
        """
        cache_position = cache_kwargs.get("cache_position")
        if cache_position.shape[0] > self.config.attention_window_size:
            # int indexing -> device sync? in compile, use tensor
            k_out = key_states[:, :, -self.config.attention_window_size :, :]
            v_out = value_states[:, :, -self.config.attention_window_size :, :]
        else:
            slicing = torch.ones(
                self.config.attention_window_size, dtype=torch.long, device=value_states.device
            ).cumsum(0)
            cache_position = cache_position.clamp(0, self.config.attention_window_size - 1)
            to_shift = cache_position >= self.config.attention_window_size - 1
            indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size

            k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
            k_out = k_out[:, :, indices]
            v_out = v_out[:, :, indices]

            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

Updating the cache should look something like the snippet above.
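And for initializing cache_position like in Llama, here is a minimal sketch (a hypothetical helper written for illustration; the actual modeling code inlines this logic, so names may differ):

    from typing import Optional

    import torch

    # Sketch only: mirrors the Llama pattern for deriving cache_position when it is not
    # passed to forward. `past_key_values` is any cache object exposing get_seq_length().
    def init_cache_position(past_key_values, inputs_embeds: torch.Tensor,
                            cache_position: Optional[torch.Tensor] = None) -> torch.Tensor:
        if cache_position is not None:
            return cache_position
        # Number of tokens already stored in the cache (0 when no cache is passed).
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        # Absolute positions of the new tokens within the full sequence.
        return torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )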

ArthurZucker avatar May 07 '24 09:05 ArthurZucker

Yeah, it's hard to debug. Merge from main and try to add a print to check where it's failing! Most probably a Copied from that is not well placed.

ArthurZucker avatar May 07 '24 09:05 ArthurZucker

Hi Arthur, I have added support for the Sliding Window Cache; please take a look at its implementation and also at the _update_causal_mask implementation. I have added my thoughts as comments.

zhenglongjiepheonix avatar May 09 '24 04:05 zhenglongjiepheonix

Good work

  • for all the Copied from statements that were removed, we need to use one of the models as the new base (Mixtral, for example)
  • as @gante said, let's add Phi to the list of slow-tested models
  • let's revert some styling changes on setup.py and the run object detection example

Hi Arthur, I made some modifications, but I still can't quite work out what we should do to solve this copy-consistency issue. I have changed the base to Mixtral for most models that referred to Mistral, but there are still some cases where it doesn't quite fit, because Mixtral has its own MoE-related logic. I also see this Ignore copy band-aid everywhere, and it becomes very unclear what exactly I need to do: should I add Ignore copy, or should I modify the code so the sources match, and which source should be the standard if they disagree? make fix-copies does not help much either, because you can't just use it without being really sure that you actually want exactly the same thing; otherwise it just breaks stuff!
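For reference, a minimal illustration of the convention being discussed (the class and the Mixtral source named here are only examples): a Copied from comment pins a definition to a source that make fix-copies keeps in sync, while an Ignore copy marker exempts the block below it from that check.

    import torch
    import torch.nn as nn

    # Illustrative only -- the general shape of the marker that `make fix-copies` looks for;
    # the referenced Mixtral class is just an example source.
    # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm with Mixtral->Mistral
    class MistralRMSNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        # Ignore copy
        def forward(self, hidden_states):
            # A block marked "Ignore copy" is allowed to diverge from the copied source.
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)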

Could you please explain more about the second point, about the Phi model? I don't quite get it.

zhenglongjiepheonix avatar May 10 '24 02:05 zhenglongjiepheonix

Good work on adding support for compile! 🔥 All that is left to do is benchmark to get the expected number of tokens per second on A100s! 🤗 cc @ydshieh this will be merged today / tomorrow!

Cool! We have 3 models to run the benchmark on!

ydshieh avatar May 14 '24 16:05 ydshieh

@ArthurZucker Don't forget our run-slow feature 🙏

@zhenglongjiepheonix Could you push an empty commit with message [run-slow] mistral? Thank you 🤗

ydshieh avatar May 14 '24 16:05 ydshieh

@ArthurZucker Don't forget our run-slow feature 🙏

@zhenglongjiepheonix Could you push an empty commit with message [run-slow] mistral? Thank you 🤗

The slow CI is failing on this; I think it's some annotation syntax not supported in Python 3.8? But it seems unrelated to my changes @ArthurZucker @ydshieh

/usr/local/lib/python3.8/dist-packages/_pytest/config/__init__.py:331: PluggyTeardownRaisedWarning: A plugin raised an exception during an old-style hookwrapper teardown.
Plugin: helpconfig, Hook: pytest_cmdline_parse
ConftestImportFailure: RuntimeError: Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
unsupported operand type(s) for |: '_GenericAlias' and 'NoneType' (from /transformers/conftest.py)
For more information see https://pluggy.readthedocs.io/en/stable/api_reference.html#pluggy.PluggyTeardownRaisedWarning
  config = pluginmanager.hook.pytest_cmdline_parse(
ImportError while loading conftest '/transformers/conftest.py'.
conftest.py:26: in <module>
    from transformers.testing_utils import HfDoctestModule, HfDocTestParser
src/transformers/testing_utils.py:45: in <module>
    from .integrations import (
src/transformers/utils/import_utils.py:1520: in __getattr__
    module = self._get_module(self._class_to_module[name])
src/transformers/utils/import_utils.py:1532: in _get_module
    raise RuntimeError(
E   RuntimeError: Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
E   Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
E   Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
E   unsupported operand type(s) for |: '_GenericAlias' and 'NoneType'

Update: it works fine after I modified the annotations in SlidingWindowCache.
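For reference, a minimal sketch of the kind of annotation change involved (the function below is only illustrative, not the actual SlidingWindowCache signature): under Python 3.8 the PEP 604 union syntax is not available when annotations are evaluated at runtime, so typing.Optional has to be used instead.

    from typing import Any, Dict, Optional

    # Illustrative only. Writing `cache_kwargs: Dict[str, Any] | None = None` fails at import
    # time on Python 3.8 with "unsupported operand type(s) for |: '_GenericAlias' and 'NoneType'",
    # so the Optional[...] spelling (or `from __future__ import annotations`) is needed.
    def update(cache_kwargs: Optional[Dict[str, Any]] = None) -> None:
        ...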

zhenglongjiepheonix avatar May 14 '24 18:05 zhenglongjiepheonix

Yeah, but the whole CI is not run. We will check with @ArthurZucker tomorrow to see what happens.

ydshieh avatar May 14 '24 18:05 ydshieh

Currently there are some issues related to the Mistral tests. Since my dev machine is A100-based, I ran these tests on a Colab T4 using the current main branch, @ArthurZucker @ydshieh:

  1. MistralIntegrationTest.test_model_7b_generation: this test fails on the main branch on T4
  2. MistralIntegrationTest.test_speculative_generation: this test fails on the main branch on T4 and A100, related to #30778
  3. MistralIntegrationTest.test_model_7b_logits: this test fails on the main branch on T4 and A100
  4. Mask4DTestHard-related tests: these go OOM on a single T4, from #30348
  5. the torch version the CI uses is < 2.3, which does not trigger the Static Cache related tests

zhenglongjiepheonix avatar May 14 '24 20:05 zhenglongjiepheonix

Yes. For CI reports, you can check the internal channel

https://huggingface.slack.com/archives/C06LR9PQA00/p1715686275921759

we have some failing tests, including the ones you mentioned.

Static Cache related tests

These will be run in another workflow file:

https://github.com/huggingface/transformers/actions/workflows/push-important-models.yml

where we will switch to torch 2.3 tomorrow (likely)

ydshieh avatar May 14 '24 21:05 ydshieh

If, aside from the 4 failing tests mentioned, everything else passes with this PR, and test_compile_static_cache passes on an A10 with torch 2.3, it's OK from my side (in terms of CI).

ydshieh avatar May 14 '24 21:05 ydshieh

If, aside from the 4 failing tests mentioned, everything else passes with this PR, and test_compile_static_cache passes on an A10 with torch 2.3, it's OK from my side (in terms of CI).

I don't see it running on the A10-based workflow; is that because it does not trigger on PRs?

zhenglongjiepheonix avatar May 14 '24 21:05 zhenglongjiepheonix

If, aside from the 4 failing tests mentioned, everything else passes with this PR, and test_compile_static_cache passes on an A10 with torch 2.3, it's OK from my side (in terms of CI).

I don't see it running on the A10-based workflow; is that because it does not trigger on PRs?

It's because the torch version there should be torch 2.3, but it is currently torch 2.2. I will update the Docker files to use torch 2.3.

ydshieh avatar May 15 '24 05:05 ydshieh

make sure you rebase on

~~@ArthurZucker Rebase won't fix the issue (those are failing on main too)~~

(sorry, I am talking about https://github.com/huggingface/transformers/pull/30642#issuecomment-2111106008)

ydshieh avatar May 15 '24 06:05 ydshieh

unsupported operand type(s) for |: '_GenericAlias' and 'NoneType' is Python 3.8 not being happy with some typing that was added, probably.

ArthurZucker avatar May 15 '24 07:05 ArthurZucker

If, aside from the 4 failing tests mentioned, everything else passes with this PR, and test_compile_static_cache passes on an A10 with torch 2.3, it's OK from my side (in terms of CI).

I don't see it running on the A10-based workflow; is that because it does not trigger on PRs?

It's because the torch version there should be torch 2.3, but it is currently torch 2.2. I will update the Docker files to use torch 2.3.

I think the A10 workflow only triggers on pushes to main, not on PRs? I mean, my commits on this PR are not triggering the test. Do I need to work directly on the transformers repo instead of my fork to run the test?

name: Slow tests on important models (on Push - A10)

on:
  push:
    branches: [ main ]

zhenglongjiepheonix avatar May 15 '24 21:05 zhenglongjiepheonix

I think the A10 workflow only triggers on pushes to main, not on PRs? Yes.

You could run the workflow manually:

https://github.com/huggingface/transformers/actions/workflows/ssh-runner.yml

See DM on Slack.

ydshieh avatar May 16 '24 05:05 ydshieh

Ok, now the problem is that both Llama and Mistral are failing the compile static cache tests on the A10 because of the same error:

============================= FAILURES SHORT STACK =============================
________________ LlamaIntegrationTest.test_compile_static_cache ________________
msg = 'hasattr ConstDictVariable to'
    def unimplemented(msg: str) -> NoReturn:
        assert msg != os.environ.get("BREAK", False)
>       raise Unsupported(msg)
E       torch._dynamo.exc.Unsupported: hasattr ConstDictVariable to
E       
E       from user code:
E          File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
E           return fn(*args, **kwargs)
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 161, in new_forward
E           args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 356, in pre_forward
E           return send_to_device(args, self.execution_device), send_to_device(
E         File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 148, in send_to_device
E           if is_torch_tensor(tensor) or hasattr(tensor, "to"):
E       
E       Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E       
E       
E       You can suppress this exception and fall back to eager by setting:
E           import torch._dynamo
E           torch._dynamo.config.suppress_errors = True

They work fine on my A100 dev machine and in a Colab T4 environment, but this error seems unrelated to the GPU; rather, it's something software-related. I can't reproduce the error, even on an A10 AWS machine with python==3.8 (screenshot attached: Screen Shot 2024-05-16 at 17 25 19).

zhenglongjiepheonix avatar May 16 '24 18:05 zhenglongjiepheonix

Ok, now the problem is that both Llama and Mistral are failing the compile static cache tests on the A10 because of the same error (traceback above).

The interesting fact is that test_compile_static_cache passes when run by itself in CI; however, if run after test_compile_sliding_window_cache, it fails because some memory-related strategies in accelerate make dynamo unhappy @ArthurZucker @ydshieh

zhenglongjiepheonix avatar May 16 '24 23:05 zhenglongjiepheonix

Yep, accelerate does not support compile yet.

ArthurZucker avatar May 17 '24 14:05 ArthurZucker

I merged from the current main and ran the slow tests again on my dev machine and an AWS A10 machine; I believe this PR is good to merge now @ArthurZucker @ydshieh

zhenglongjiepheonix avatar May 17 '24 16:05 zhenglongjiepheonix

Please merge this when you get the chance @ArthurZucker, as I don't have write access to the library.

zhenglongjiepheonix avatar May 20 '24 14:05 zhenglongjiepheonix