Make StaticCache configurable at model construct time
## What does this PR do?
This PR addresses #32500, "Export to ExecuTorch".

It enables loading a model with options to statically configure it using `StaticCache`:
```python
model = AutoModelForCausalLM.from_pretrained(
    hf_model_repo,
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation=cache_implementation,
        max_length=max_cache_len,
        cache_config={
            "batch_size": batch_size,
            "max_cache_len": max_cache_len,
        },
    ),
)
```
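To illustrate why this matters for export, here is a minimal, self-contained sketch (not the actual `transformers` implementation; `TinyStaticCache` is a hypothetical class invented for this example) of the idea behind a static cache: key buffers are pre-allocated to a fixed `max_cache_len` at construction time, so tensor shapes stay constant across decoding steps, which is what ahead-of-time export backends such as ExecuTorch require.

```python
class TinyStaticCache:
    """Toy illustration of a fixed-size KV-cache buffer (plain lists stand in for tensors)."""

    def __init__(self, batch_size: int, max_cache_len: int, head_dim: int):
        # Pre-allocate fixed-size buffers up front; zeros mark empty slots.
        self.max_cache_len = max_cache_len
        self.keys = [
            [[0.0] * head_dim for _ in range(max_cache_len)]
            for _ in range(batch_size)
        ]
        self.pos = 0  # next write position, shared across the batch

    def update(self, batch_idx: int, key_vec: list) -> None:
        # Write in place at the current position instead of growing the cache,
        # so the buffer shape never changes during decoding.
        if self.pos >= self.max_cache_len:
            raise IndexError("static cache is full")
        self.keys[batch_idx][self.pos] = key_vec
        self.pos += 1


cache = TinyStaticCache(batch_size=1, max_cache_len=4, head_dim=2)
cache.update(0, [0.1, 0.2])
print(len(cache.keys[0]))  # buffer length is fixed at max_cache_len -> 4
```

In contrast, a dynamic cache grows with each generated token, producing shape changes that static export graphs cannot express.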
It also creates a new integration point for ExecuTorch at `transformers/integrations/executorch.py`, which hosts the wrapper module class and the `convert_and_export` util.
The test model `gemma-2b` is naturally exportable via `convert_and_export` with `StaticCache`.
The test model `gemma-2b` is also lowerable and runnable via ExecuTorch! Check out https://github.com/pytorch/executorch/pull/4723 in ExecuTorch to reproduce.
## Before submitting
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? #32500
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker @amyeroberts @gante