Offloaded KV Cache
What does this PR do?
Fixes #30704
This PR introduces `OffloadedCache`, a KV cache implementation that reduces GPU memory usage in exchange for more CPU memory usage and a small increase in generation time. During the forward passes in `generate`, it keeps only two layers of KV cache on the device: the current layer and the next layer. All other layers reside on the CPU and are prefetched/evicted as necessary.
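To illustrate the idea, here is a minimal sketch of the prefetch/evict pattern. The class and method names are illustrative, not the PR's actual API; it assumes a CUDA device and uses a side stream so transfers can overlap with compute:

```python
import torch


class OffloadSketch:
    """Sketch: keep only the current and next layer's KV tensors on the GPU."""

    def __init__(self):
        # One (key, value) tensor pair per decoder layer; most live on the CPU.
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        # Side stream so host<->device copies overlap with attention compute.
        self.prefetch_stream = torch.cuda.Stream()

    def prefetch_layer(self, layer_idx: int) -> None:
        # Start copying the next layer's KV to the GPU while the current
        # layer's attention is still running on the default stream.
        if layer_idx < len(self.key_cache):
            with torch.cuda.stream(self.prefetch_stream):
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to("cuda", non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to("cuda", non_blocking=True)

    def evict_previous_layer(self, layer_idx: int) -> None:
        # Once a layer is done, move its KV back to CPU memory to free VRAM.
        if len(self.key_cache) > 0:
            prev = (layer_idx - 1) % len(self.key_cache)
            self.key_cache[prev] = self.key_cache[prev].to("cpu", non_blocking=True)
            self.value_cache[prev] = self.value_cache[prev].to("cpu", non_blocking=True)
```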
It can be used by passing `cache_implementation="offloaded"` in the `GenerationConfig`, like this:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Any causal LM works; gpt2 is used here only for illustration.
# Offloading moves KV tensors between GPU and CPU, so a CUDA device is assumed.
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, world", return_tensors="pt").to("cuda")

gen_config = GenerationConfig(
    cache_implementation="offloaded",
    # other generation options, such as diverse beam search:
    num_beams=4,
    num_beam_groups=2,
    num_return_sequences=4,
    diversity_penalty=1.0,
    max_new_tokens=50,
    early_stopping=True,
)
outputs = model.generate(
    inputs["input_ids"],
    generation_config=gen_config,
)
```
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. Issue #30704
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @ArthurZucker @gante