PTQ for `generate_v2`
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
This PR adds post-training quantization (PTQ) support for generate_v2 via torchao. It has only been tested on text models, specifically Llama2.
Why did you change the way quantization APIs are called?
Good catch. Instead of creating a Quantizer class and having it quantize the model, I opted to use torchao's quantize_ API with an instantiated quantization method. I did this for two reasons:
- Simplifies our recipe and codebase.
- It is more consistent with the usage that torchao seems to be pushing. We want the UX to be the same whether someone is quantizing a model here or directly with torchao APIs.
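For context on the config side: torchtune configs name a callable with `_component_` and pass the remaining keys as keyword arguments, so the `quantization_method` node in the PTQ logs resolves to something like `int4_weight_only(use_hqq=False)`, and the result is what gets handed to torchao's `quantize_`. The sketch below is a hypothetical, minimal re-implementation of that resolution step (`instantiate_component` is an illustrative name, not torchtune's API). It is exercised with standard-library callables so it runs without torchao or a GPU.

```python
import importlib
from typing import Any, Dict


def instantiate_component(node: Dict[str, Any]) -> Any:
    """Hypothetical stand-in for torchtune-style config instantiation.

    Imports the object named by node["_component_"] (a dotted path) and
    calls it with the remaining keys as keyword arguments.
    """
    kwargs = dict(node)  # copy so the caller's config is not mutated
    module_name, _, attr_name = kwargs.pop("_component_").rpartition(".")
    target = getattr(importlib.import_module(module_name), attr_name)
    return target(**kwargs)


# With torchao installed, the PTQ node from the config would look like:
#   method = instantiate_component({
#       "_component_": "torchao.quantization.quant_api.int4_weight_only",
#       "use_hqq": False,
#   })
#   quantize_(model, method)  # torchao.quantization.quantize_ (in-place)
# Runnable without torchao, using a stdlib callable instead:
delay = instantiate_component({"_component_": "datetime.timedelta", "minutes": 2})
print(delay.total_seconds())
```

The design point being illustrated is that the recipe only needs to know "call this, then pass it to quantize_", so any quantization method torchao exposes can be swapped in from YAML without a recipe change.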
Does this work for vision models? Technically, it runs, but we haven't fixed the torch.compile graph breaks in the Llama3.2V model, so it doesn't speed anything up. Therefore, I will not include this in the default config for Llama3.2V.
Why is it actually slower for the entire first run? My assumption is that compile is the culprit here. Once everything has run once, the compiled model is pulled from the compile cache and things are actually faster. Still, quantized generation like this typically pays off for longer responses, where the benefit is clearest. cc @andrewor14 to confirm whether my intuition is correct here.
This DOES NOT work for PTQ of a QAT model. This will be added in a follow-up.
Changelog
- Implement PTQ in generate_v2
- Clean up some of the variables in generate_v2 to make them public
- Add timing that separates the first token from the rest of the tokens
- Update llama2/generation_v2 to support quantization
- Add a GPU test for quantized generation :)
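The split timing separates two phases with different costs: the first token includes prefill (and, in a compiled run, compilation), while each later token is a single decode step. Below is a minimal, framework-agnostic sketch of that measurement over any token iterator; `timed_generation` and `decode_tokens_per_sec` are hypothetical helpers, not the recipe's code. On GPU you would also need to synchronize (for example with `torch.cuda.synchronize()`) before reading the clock.

```python
import time
from typing import Iterable, List, Tuple


def timed_generation(tokens: Iterable[int]) -> Tuple[List[int], float, float]:
    """Consume a token stream and time the first token separately.

    Returns (tokens, seconds_to_first_token, seconds_for_remaining_tokens).
    Raises StopIteration if the stream is empty.
    """
    it = iter(tokens)
    start = time.perf_counter()
    first = next(it)  # prefill (plus any first-call compilation) lands here
    first_token_sec = time.perf_counter() - start

    rest_start = time.perf_counter()
    rest = list(it)  # one decode step per remaining token
    rest_sec = time.perf_counter() - rest_start
    return [first, *rest], first_token_sec, rest_sec


def decode_tokens_per_sec(num_tokens: int, rest_sec: float) -> float:
    """Throughput over the tokens generated after the first one."""
    return (num_tokens - 1) / rest_sec if rest_sec > 0 else float("inf")


def fake_stream():
    time.sleep(0.05)  # stand-in for a slow prefill
    yield from range(1, 6)


toks, ttft, rest = timed_generation(fake_stream())
print(toks, f"first={ttft:.3f}s rest={rest:.4f}s")
```

Reporting the two numbers separately keeps a slow first token (as in the first PTQ run) from being averaged into the steady-state decode rate.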
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- [x] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
- [x] add unit tests for any new functionality
- [ ] update docstrings for any new or updated methods or classes
- [x] run unit tests via pytest tests
- [x] run recipe tests via pytest tests -m integration_test
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [x] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
All testing was done with torchao v0.6.1 and torch 2.5.1.
Recipe without PTQ:
(joe-torchtune-2) [~/projects/joe-torchtune (add-quantize-generate-v2)]$ tune run dev/generate_v2 --config llama2/generation_v2
Running InferenceRecipe with resolved config:
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-chat-hf
  checkpoint_files:
  - pytorch_model-00001-of-00002.bin
  - pytorch_model-00002-of-00002.bin
  model_type: LLAMA2
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 500
model:
  _component_: torchtune.models.llama2.llama2_7b
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
top_k: 300
Model was initialized with precision torch.bfloat16.
Time to generate first token: 0.45 sec
Oh, how delightful! *adjusts glasses* The capital of France is... *drumroll* Paris! 🇫🇷 Yes, the City of Light, the City of Love, the City of Art, and the City of Delicious Croissants. 🥐 Is there anything else I can help you with? 😊
Time for inference: 4.93 sec total, 17.04 tokens/sec
Bandwidth achieved: 235.60 GB/s
Max memory allocated: 13.95 GB
Recipe with PTQ (first run):
(joe-torchtune-2) [~/projects/joe-torchtune (add-quantize-generate-v2)]$ tune run dev/generate_v2 --config llama2/generation_v2
Running InferenceRecipe with resolved config:
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-chat-hf
  checkpoint_files:
  - pytorch_model-00001-of-00002.bin
  - pytorch_model-00002-of-00002.bin
  model_type: LLAMA2
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 500
model:
  _component_: torchtune.models.llama2.llama2_7b
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
quantization_method:
  _component_: torchao.quantization.quant_api.int4_weight_only
  use_hqq: false
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
top_k: 300
Model was initialized with precision torch.bfloat16.
Compiling model layers with torch.compile...
Time to generate first token: 18.98 sec
Ah, a question that is both simple and profound! *adjusts glasses* The capital of France, my dear human, is none other than the venerable city of Paris! 🇫🇷
But let me tell you more about this magnificent city, for it is a place of wonder and awe. Paris is home to some of the most iconic landmarks in the world, such as the Eiffel Tower, the Louvre Museum, and the Notre-Dame Cathedral. The city is also renowned for its exquisite cuisine, its vibrant art scene, and its unparalleled fashion.
And did you know that Paris is the City of Light? *winks* It is here that some of the greatest minds in history have come to seek inspiration and knowledge. From the likes of Victor Hugo to Emile Zola, and from Claude Monet to Pierre-Auguste Renoir, the City of Paris has been the birthplace of countless artistic masterpieces.
So there you have it, my dear human! The capital of France is none other than the enchanting city of Paris, a place that will capture your heart and imagination like no other. 💖
Time for inference: 27.66 sec total, 9.84 tokens/sec
Bandwidth achieved: 136.00 GB/s
Max memory allocated: 13.95 GB
Recipe with PTQ (second run):
(joe-torchtune-2) [~/projects/joe-torchtune (add-quantize-generate-v2)]$ tune run dev/generate_v2 --config llama2/generation_v2
Running InferenceRecipe with resolved config:
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-chat-hf
  checkpoint_files:
  - pytorch_model-00001-of-00002.bin
  - pytorch_model-00002-of-00002.bin
  model_type: LLAMA2
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 500
model:
  _component_: torchtune.models.llama2.llama2_7b
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
quantization_method:
  _component_: torchao.quantization.quant_api.int4_weight_only
  use_hqq: false
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
top_k: 300
Model was initialized with precision torch.bfloat16.
Compiling model layers with torch.compile...
Time to generate first token: 4.56 sec
Ah, a question that is both simple and profound! *adjusts glasses* The capital of France, my dear human, is none other than the venerable city of Paris! 🇫🇷
But let me tell you more about this magnificent city, for it is a place of wonder and awe. Paris is home to some of the most iconic landmarks in the world, such as the Eiffel Tower, the Louvre Museum, and the Notre-Dame Cathedral. The city is also renowned for its exquisite cuisine, its vibrant art scene, and its unparalleled fashion.
And did you know that Paris is the City of Light? *winks* It is here that some of the greatest minds in history have come to seek inspiration and knowledge. From the likes of Victor Hugo to Emile Zola, and from Claude Monet to Pierre-Auguste Renoir, the City of Paris has been the birthplace of countless artistic masterpieces.
So there you have it, my dear human! The capital of France is none other than the enchanting city of Paris, a place that will capture your heart and imagination like no other. 💖
Time for inference: 11.92 sec total, 22.82 tokens/sec
Bandwidth achieved: 315.49 GB/s
Max memory allocated: 13.95 GB
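One way to sanity-check the three logs is to divide each "Bandwidth achieved" figure by the matching tokens/sec. This assumes the bandwidth line is computed gpt-fast style as model size in bytes times tokens per second; the recipe code is not shown here, so treat that as an assumption. Under it, all three runs imply the same ~13.8 GB read per token, roughly the bf16 footprint of a ~7B-parameter model (about 2 bytes per parameter). If so, the int4 runs are credited with bf16-sized reads, and their bandwidth figures are better read as relative throughput than as bytes actually moved.

```python
# (bandwidth in GB/s, tokens/sec) copied from the three logs above.
logged = {
    "bf16, no PTQ": (235.60, 17.04),
    "int4 PTQ, first run": (136.00, 9.84),
    "int4 PTQ, second run": (315.49, 22.82),
}


def implied_size_gb(bandwidth_gb_s: float, tokens_per_sec: float) -> float:
    """Invert the assumed formula bandwidth = size * tokens/sec."""
    return bandwidth_gb_s / tokens_per_sec


sizes = {name: implied_size_gb(bw, tps) for name, (bw, tps) in logged.items()}
for name, size in sizes.items():
    print(f"{name}: ~{size:.2f} GB per token")
```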
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example
- [x] I did not change any public API
- [ ] I have added an example to docs or docstrings
To-do
Fix failing GPU test. It's passing locally, so I'm not sure how to make it work on the remote runners:
(joe-torchtune-2) [~/projects/joe-torchtune (add-quantize-generate-v2)]$ python -m pytest tests/recipes/dev/test_generate_v2.py::TestGenerateV2::test_llama2_generate_with_quantization --with-integration
Expected artifacts for test run are:
small-ckpt-tune-03082024.pt
small-ckpt-meta-03082024.pt
small-ckpt-hf-03082024.pt
small-ckpt-tune-llama3-05052024.pt
small-ckpt-hf-reward-07122024.pt
tokenizer.model
tokenizer_llama3.model
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-meta-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt
File already exists locally: /tmp/test-artifacts/tokenizer.model
File already exists locally: /tmp/test-artifacts/tokenizer_llama3.model
================================================================================================================ test session starts ================================================================================================================
platform linux -- Python 3.11.9, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/jrcummings/projects/joe-torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 1 item
tests/recipes/dev/test_generate_v2.py . [100%]
================================================================================================================ 1 passed in 42.70s =================================================================================================================
Helpful Links
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1866
1 New Failure, 2 Cancelled Jobs
As of commit 0e08081eebce48e37d3dd56cb0b57dc67c6cdfc9 with merge base 890deab3029eef65f94cedb37fda14479f65f129:
NEW FAILURE - The following job has failed:
- GPU tests / gpu_test (3.9, stable) (gh)
  tests/recipes/dev/test_generate_v2.py::TestGenerateV2::test_llama2_generate_with_quantization
CANCELLED JOBS - The following jobs were cancelled. Please retry:
- GPU tests / gpu_test (3.10, stable) (gh)
  ##[error]The operation was canceled.
- GPU tests / gpu_test (3.11, stable) (gh)
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is 25.92593% with 20 lines in your changes missing coverage. Please review.
Project coverage is 23.90%. Comparing base (97e857f) to head (0e08081). Report is 237 commits behind head on main.
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1866       +/-   ##
===========================================
+ Coverage    8.97%   23.90%   +14.92%
===========================================
  Files         305      357       +52
  Lines       18166    21129     +2963
===========================================
+ Hits         1631     5050     +3419
+ Misses      16535    16079      -456
@felipemello1 @ebsmothers Will this not pass on PyTorch 2.5 b/c of the issue with CUDNN? This test passes locally on PyTorch v2.5.1.
Do we know when the patch will be released?
This looks sensible overall, but I have a few outstanding questions:
- What implications does this have for how we expose quantization APIs?
- What is going on with compile?
- Why is memory usage identical for non-PTQ, and PTQ? I guess because we're still peaking when we load weights in bf16, and we're measuring global max memory usage?
- Why is it so slow? Even the second run of PTQ takes 12s vs 5s for non-PTQ - the bump in toks/s doesn't seem to offset whatever else is slowing it down.
- Noob q: given the above two points - max memory usage is identical and it takes longer... when would someone want to use this?
We probably don't need to answer all of these here but I think it'd help bring a lot of our quantization offerings in line if we can at least follow up on them.
- What implications does this have for how we expose quantization APIs?
I think the real question is whether we want to support PTQ APIs outside of torchao. If we do, we would want to opt for an approach like Hugging Face's, wherein a config for a specific backend can be initialized. I'd argue that we probably don't want to, b/c torchao already supports general quant, HQQ, and GPTQ (altho GPTQ is not available through the quantize_ API yet). Idk if this is too shortsighted though.
- What is going on with compile?
Not sure I understand the question. It's always slow during the warmup run.
- Why is memory usage identical for non-PTQ, and PTQ? I guess because we're still peaking when we load weights in bf16, and we're measuring global max memory usage?
Exactly.
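The effect is easy to reproduce in miniature: a high-water mark recorded while full-precision weights are resident is not lowered by shrinking them afterwards unless the peak counter is reset. The CPU-only sketch below uses Python's `tracemalloc` (Python 3.9+ for `reset_peak`) as a stand-in for the CUDA allocator; the 64 MB and 16 MB buffers are arbitrary placeholders for bf16 and int4 weights, not measurements. On GPU, the analogous calls would be `torch.cuda.reset_peak_memory_stats()` after quantization and `torch.cuda.max_memory_allocated()` after generation.

```python
import tracemalloc

MB = 1024 * 1024


def peak_after_quantizing(reset_after_load: bool) -> int:
    """Peak traced bytes after 'loading' 64 MB and replacing it with 16 MB.

    The buffers are placeholders for bf16 and int4 weights; only the
    bookkeeping of the peak counter is being illustrated.
    """
    tracemalloc.start()
    weights = bytearray(64 * MB)  # full-precision load sets the high-water mark
    weights = bytearray(16 * MB)  # smaller "quantized" copy replaces it
    if reset_after_load:
        tracemalloc.reset_peak()  # CUDA analogue: torch.cuda.reset_peak_memory_stats()
    scratch = bytearray(1 * MB)  # later work, e.g. generation
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    del weights, scratch
    return peak


print(peak_after_quantizing(False) // MB, peak_after_quantizing(True) // MB)
```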
- Why is it so slow? Even the second run of PTQ takes 12s vs 5s for non-PTQ - the bump in toks/s doesn't seem to offset whatever else is slowing it down
Not sure what is so slow, but I've reached out to the AO team to see if this is normal.
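Part of the gap may come from the outputs rather than the kernels. Multiplying each run's total time by its tokens/sec (assuming the logged rate is computed over that same total) suggests the bf16 run produced about 84 tokens while both PTQ runs produced about 272, consistent with the much longer PTQ responses in the logs. On that reading, the second PTQ run did roughly 3.2x the work in about 2.4x the time, and per token it is faster (about 44 ms vs 59 ms).

```python
# (total seconds, tokens/sec) copied from the three logs in the PR description.
runs = {
    "bf16, no PTQ": (4.93, 17.04),
    "int4 PTQ, first run": (27.66, 9.84),
    "int4 PTQ, second run": (11.92, 22.82),
}


def implied_tokens(total_sec: float, tokens_per_sec: float) -> int:
    """Tokens generated, assuming tokens/sec = tokens / total seconds."""
    return round(total_sec * tokens_per_sec)


for name, (total, tps) in runs.items():
    print(f"{name}: ~{implied_tokens(total, tps)} tokens, {1000 / tps:.1f} ms/token")
```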
- Noob q: given the above two points - max memory usage is identical and it takes longer... when would someone want to use this?
An excellent question. I don't imagine anyone would want to use this recipe out of the box with quantization. However, it's a great playground for showing how easy it is to set up quantization with our models. The real benefit comes from serving this model somewhere, so that you can compile + quant once and get continuous speed-ups for everything downstream. Also, if we end up having a super simple chat component, this would demonstrate gains too.
Are you seeing the slowdown for int4_weight_only specifically? That's surprising, since we have an efficient tinygemm CUDA kernel for that, and the model size should actually be 1/4 of the original bf16 model size (unlike int8_dynamic_activation_int4_weight). Also cc @jerryzh168 @HDCharles, who did some benchmarking on this from the AO side.
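For readers new to weight-only int4: each weight is stored as a 4-bit integer and reconstructed at matmul time from a per-group scale and offset, which is where the "1/4 of bf16" figure comes from (4 bits vs 16, before group metadata). The pure-Python sketch below shows a simplified asymmetric group-wise scheme; it is not torchao's implementation, which uses a packed tensor-core layout and the tinygemm kernel. The group size of 128 and the 16-bit scale and offset per group are assumptions used for the size estimate.

```python
import math
from typing import List, Tuple

Params = List[Tuple[float, float]]  # one (scale, offset) pair per group


def quantize_int4(weights: List[float], group_size: int = 128) -> Tuple[List[int], Params]:
    """Map each group of weights onto the 16 levels 0..15 (asymmetric)."""
    codes: List[int] = []
    params: Params = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / 15 or 1.0  # constant group: any nonzero scale works
        params.append((scale, lo))
        codes.extend(min(15, max(0, round((w - lo) / scale))) for w in group)
    return codes, params


def dequantize_int4(codes: List[int], params: Params, group_size: int = 128) -> List[float]:
    """Reconstruct approximate weights from codes and per-group parameters."""
    out: List[float] = []
    for g, (scale, lo) in enumerate(params):
        out.extend(lo + c * scale for c in codes[g * group_size:(g + 1) * group_size])
    return out


def bits_per_weight(group_size: int = 128, meta_bits: int = 16) -> float:
    """4-bit codes plus a scale and an offset of meta_bits each per group."""
    return 4 + 2 * meta_bits / group_size


w = [math.sin(0.1 * i) for i in range(300)]
codes, params = quantize_int4(w)
w_hat = dequantize_int4(codes, params)
max_err = max(abs(a - b) for a, b in zip(w, w_hat))
print(f"max error {max_err:.4f}, {bits_per_weight():.2f} bits/weight "
      f"= {bits_per_weight() / 16:.3f} of bf16")
```

With these assumptions the estimate is 4.25 bits per weight, slightly above a quarter of bf16, because of the per-group metadata.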
Are you seeing the slowdown for int4_weight_only specifically?
I tried both int4_weight_only and the dynamic activation version, and both had initial slowdowns for the entire first run but ran faster afterwards.
Being slower on the first run is expected, I feel, since compilation actually happens during the first run, when it sees the real inputs. Typically, when we benchmark, there are some warmup runs so that compilation actually happens, and we benchmark the runs that follow.
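The warmup-then-measure pattern can be sketched as a small timing helper; `benchmark` here is an illustrative name, not a torchtune or torchao utility. Discarding the first call(s) keeps one-time costs such as torch.compile tracing out of the reported number. On GPU you would also call `torch.cuda.synchronize()` before each clock read, since CUDA kernels launch asynchronously.

```python
import time
from statistics import median
from typing import Callable, List


def benchmark(fn: Callable[[], object], warmup: int = 1, iters: int = 5) -> float:
    """Median wall-clock seconds of fn() after discarding `warmup` calls."""
    for _ in range(warmup):
        fn()  # absorbs one-time costs (e.g. compilation on the first real input)
    times: List[float] = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return median(times)


# Toy workload whose first call is slow, mimicking a compile-on-first-call model.
calls = {"n": 0}


def toy_generate() -> None:
    calls["n"] += 1
    time.sleep(0.2 if calls["n"] == 1 else 0.001)


print(f"steady-state: {benchmark(toy_generate, warmup=1, iters=3):.4f} s")
```

Using the median rather than the mean keeps a single straggler iteration from skewing the steady-state figure.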
I know that compile happens at the first forward pass, but what I'm seeing is a slowdown for the entire first generation of outputs (see logs in the PR description). Is this expected?