Add batched inference
Context
This PR adds batched inference capabilities to our generation utility.
Why is this needed? For our current generation recipe, it mostly isn't: we expect users to input a single prompt and get a quick response so they can test or experiment with their newly finetuned model. However, comprehensive evaluation with the EleutherAI Eval Harness needs batched inference. This work is a step toward letting users easily run the entire OpenLLM Leaderboard suite on their finetuned model from within torchtune.
Why are you messing with the kv-cache? The kv cache is only used for inference. Until now we had only ever used it for inference with a single batch, and only once. If we want to reuse the same model + cache + batch size, we need a way to reset the cache before using it again. And if the batch size changes, we need to call setup_caches again.
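To make the reset-versus-rebuild rule concrete, here is a minimal sketch in plain Python. It is not the torchtune implementation: `ToyKVCache` and `get_cache` are hypothetical names, nested lists stand in for tensors, and the only point is the decision logic (same batch size: reset and reuse; different batch size: build a new cache, which is what calling setup_caches again would do).

```python
from typing import List, Optional

Vector = List[float]


class ToyKVCache:
    """Illustrative fixed-size key/value buffer for one attention layer (hypothetical)."""

    def __init__(self, batch_size: int, max_seq_len: int, head_dim: int) -> None:
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.head_dim = head_dim
        self.k = self._zeros()
        self.v = self._zeros()
        self.size = 0  # number of positions written so far

    def _zeros(self) -> List[List[Vector]]:
        # Shape: [batch_size][max_seq_len][head_dim]
        return [
            [[0.0] * self.head_dim for _ in range(self.max_seq_len)]
            for _ in range(self.batch_size)
        ]

    def update(self, k_new: List[Vector], v_new: List[Vector]) -> None:
        """Write one new position for every sequence in the batch."""
        if len(k_new) != self.batch_size or len(v_new) != self.batch_size:
            raise ValueError("batch size changed; the cache must be rebuilt")
        if self.size >= self.max_seq_len:
            raise ValueError("cache is full")
        for b in range(self.batch_size):
            self.k[b][self.size] = list(k_new[b])
            self.v[b][self.size] = list(v_new[b])
        self.size += 1

    def reset(self) -> None:
        """Clear the buffers so the same cache can serve a new generation call."""
        self.k = self._zeros()
        self.v = self._zeros()
        self.size = 0


def get_cache(
    cache: Optional[ToyKVCache], batch_size: int, max_seq_len: int, head_dim: int
) -> ToyKVCache:
    """Reuse the cache when the batch size matches; otherwise build a new one."""
    if cache is None or cache.batch_size != batch_size:
        return ToyKVCache(batch_size, max_seq_len, head_dim)
    cache.reset()
    return cache
```

Without the reset, a second generation call with the same batch size would start from the stale entries left by the first call.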
Assumptions!!
- A tokenizer pad_id of 0. We can update this if we really see a problem, but it covers all our current use cases.
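For reference, this is roughly what the assumption amounts to when prompts of different lengths are put into one batch. It is a sketch only: `pad_batch` is a hypothetical helper, and right-padding is an assumption made for illustration, not necessarily what the generation util does.

```python
from typing import List, Tuple


def pad_batch(
    prompts: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Right-pad token id lists to a common length (hypothetical helper).

    Returns the padded batch and a mask that is True for real tokens.
    The mask is built from the original lengths rather than by comparing
    against pad_id, so it stays correct even if pad_id is also a real token id.
    """
    if not prompts:
        return [], []
    max_len = max(len(p) for p in prompts)
    padded = [p + [pad_id] * (max_len - len(p)) for p in prompts]
    mask = [[True] * len(p) + [False] * (max_len - len(p)) for p in prompts]
    return padded, mask
```

Exposing `pad_id` as a parameter with a default of 0 keeps the current behavior while leaving room for tokenizers that use a different pad token.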
Changelog
- Move generation to proper test location under utils/
- Update generation util to support batched inference
- Add new tests for generation util
- Update SPM and TikToken tokenizers to use a list instead of set for stop_tokens. (Sorry @ebsmothers, I know you love Sets)
- Added reset capabilities to Transformer
- Added test for reset capabilities
- Updated max_batch_size to batch_size for kv cache & updated all corresponding places for this
- Added docstrings to kv cache (they were missing)
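As a rough picture of what batched generation with stop tokens involves, here is a toy decoding loop in plain Python. It is a sketch under assumptions, not the code in this PR: `generate_batch` and the `next_tokens` callback are hypothetical, and padding rows that have already stopped with `pad_id` is one possible choice for keeping the batch aligned.

```python
from typing import Callable, List


def generate_batch(
    next_tokens: Callable[[List[List[int]]], List[int]],
    prompts: List[List[int]],
    stop_tokens: List[int],
    max_new_tokens: int,
    pad_id: int = 0,
) -> List[List[int]]:
    """Toy batched decoding loop (hypothetical; no real model involved).

    next_tokens maps the current sequences to one proposed token per row.
    A row is finished once it emits a stop token; finished rows receive
    pad_id on later steps so every row grows by one token per step.
    """
    sequences = [list(p) for p in prompts]
    finished = [False] * len(sequences)
    for _ in range(max_new_tokens):
        if all(finished):
            break  # every row has hit a stop token
        proposed = next_tokens(sequences)
        for i, tok in enumerate(proposed):
            if finished[i]:
                sequences[i].append(pad_id)
                continue
            sequences[i].append(tok)
            if tok in stop_tokens:
                finished[i] = True
    return sequences
```

Here `stop_tokens` is a list, matching the tokenizer change above; the membership check works the same way for a list as it did for a set.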
Test plan
(joe-torchtune) [[email protected] ~/projects/joe-torchtune (add-batched-inference)]$ pytest tests/torchtune/utils/test_generation.py
================================================================ test session starts ================================================================
platform linux -- Python 3.11.9, pytest-8.2.0, pluggy-1.5.0
rootdir: /home/jrcummings/projects/joe-torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 6 items
tests/torchtune/utils/test_generation.py ...... [100%]
================================================================= 6 passed in 0.72s =================================================================
(joe-torchtune) [[email protected] ~/projects/joe-torchtune (add-batched-inference)]$ pytest tests/torchtune/modules/test_transformer_decoder.py
================================================================ test session starts ================================================================
platform linux -- Python 3.11.9, pytest-8.2.0, pluggy-1.5.0
rootdir: /home/jrcummings/projects/joe-torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 7 items
tests/torchtune/modules/test_transformer_decoder.py ....... [100%]
================================================================= 7 passed in 4.39s =================================================================
(joe-torchtune) [[email protected] ~/projects/joe-torchtune (add-batched-inference)]$ tune run generate --config generation prompt="Tell me the capital of Wyoming."
2024-05-07:14:42:54,193 INFO [_utils.py:34] Running InferenceRecipe with resolved config:
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: ./phi3
checkpoint_files:
- model-00001-of-00002.safetensors
- model-00002-of-00002.safetensors
model_type: PHI3_MINI
output_dir: ./
device: cuda
dtype: bf16
max_new_tokens: 300
model:
_component_: torchtune.models.phi3.phi3_mini
prompt: Tell me the capital of Wyoming.
quantizer: null
seed: 1234
temperature: 0.6
tokenizer:
_component_: torchtune.models.phi3.phi3_mini_tokenizer
path: ./phi3/tokenizer.model
top_k: 300
2024-05-07:14:43:01,381 DEBUG [seed.py:59] Setting manual seed to local seed 1234. Local seed is seed + rank = 1234 + 0
2024-05-07:14:43:05,153 INFO [generate.py:76] Model is initialized with precision torch.bfloat16.
2024-05-07:14:43:19,467 INFO [generate.py:123] Tell me the capital of Wyoming.
The capital of Wyoming is Cheyenne.
Here's a little more information about Cheyenne:
Cheyenne is the most populous city in Wyoming and is located in the southeastern part of the state. It is situated on the South Platte River and sits at an elevation of 5,707 feet above sea level. The city has a rich history and served as the capital of the Wyoming Territory from 1869 to 1886.
Cheyenne is known for its outdoor recreational opportunities, such as hiking, fishing, and hunting. The city is also home to the Cheyenne Frontier Days, one of the largest and oldest rodeos in the world. Additionally, Cheyenne has a vibrant arts scene, with numerous galleries, museums, and cultural events.
With a population of approximately 65,000 people, Cheyenne is the economic, cultural, and transportation hub of southeastern Wyoming. The city's economy is diverse, with industries such as education, healthcare, energy, and tourism. The Cheyenne Regional Airport provides transportation links to other parts of the United States and the world. Great! Here's the information you requested: The capital of Wyoming is Cheyenne. Located in the southeastern
2024-05-07:14:43:19,469 INFO [generate.py:136] Time for inference: 13.92 sec total, 21.55 tokens/sec
2024-05-07:14:43:19,469 INFO [generate.py:139] Bandwidth achieved: 201.53 GB/s
2024-05-07:14:43:19,470 INFO [generate.py:140] Memory used: 9.39 GB
A tokenizer pad_id of 0. We can update this if we really see a problem, but it covers all our current use cases.
Is this true? For example, phi3 doesn't satisfy this assumption, does it?
Fair; however, the Phi3 tokenizer doesn't have ANY special tokens in its vocabulary, so technically when we pass in a 0, it interprets it as a pad_id (tokenizers are dumb). I'll admit it's not the best approach, and I'll amend this PR to include a pad_id parameter and a test, but I do want to point out that this DOES work for our current models.