Integrate INT8 mixed-precision from torchao 0.7
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
Recent INT8 mixed-precision work in torchao shows very promising results.
- Single device A100 -> ~40% speedup
- Single device 4090 -> ~70% speedup (consumer 4000-series GPUs see unusually large speedups, which is nice)
- Works with FSDP2
Known major limitations
- Requires `torch.compile()` to enjoy the speedup (to codegen efficient dynamic quantization code)
- Input sizes should not vary too much, since varying shapes will trigger autotune for the triton INT8 matmul kernel -> this only works well for PackedDataset w/ FlexAttention, since `seq_len` is static.
- ~Does not work with `training.load_from_full_model_state_dict()` -> cannot integrate with distributed recipes atm.~ -> solved by using module-swap UX instead. Pending pytorch/ao#1179
See https://github.com/pytorch/ao/tree/v0.5.0/torchao/prototype/quantized_training#int8-mixed-precision for more details.
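For reference, the basic usage pattern from that torchao README looks roughly like the sketch below (prototype tensor-subclass API as documented for torchao 0.5; the module-swap path from pytorch/ao#1179 may look different):

```python
# Rough sketch based on the linked torchao README (prototype API, subject to change).
import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
quantize_(model, int8_mixed_precision_training())  # swap weights to the INT8 mixed-precision subclass
model = torch.compile(model)  # torch.compile() is required to actually get the speedup
```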
For now, I only added the code to show the necessary changes. I'm open to suggestions on how to expose this in torchtune. One idea of mine:
- Add a global config flag `int8_mixed_precision` (similar to the `compile` flag). This will be a boolean.
- Handle it inside `_setup_model()` -> repeated code for each recipe (a rough sketch follows this list) -> UPDATE: from previous feedback, add a new flag `mixed_precision`
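A minimal sketch of what the recipe-side handling could look like, assuming a `mixed_precision` config block like the one shown in the UX section below and a quantizer exposing a `prepare()` method (mirroring torchtune's existing QAT quantizer UX); the helper name and exact wiring are illustrative, not the final API:

```python
# Hypothetical recipe-side hook. Names and config layout are illustrative only.
from torchtune import config

def _maybe_apply_mixed_precision(model, cfg_mixed_precision=None):
    # cfg_mixed_precision is assumed to carry `_component_` (the quantizer class)
    # plus an `enabled` flag, as in the CLI example in the UX section.
    if cfg_mixed_precision is None or not cfg_mixed_precision.pop("enabled", False):
        return model
    # Drop `enabled` first so only quantizer kwargs are forwarded to the constructor.
    quantizer = config.instantiate(cfg_mixed_precision)
    return quantizer.prepare(model)  # module-swap: replaces eligible nn.Linear layers
```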
Some concerns:
- It's possible to customize INT8 mixed-precision via `Int8MixedPrecisionTrainingConfig` (see doc). Should we expose it to torchtune's users? From my testing, the default config works well. There might be more knobs to customize in the future too.
  - UPDATE: expose all options via `Int8MixedPrecisionTrainingQuantizer`
- Ability to extend to other torchao subclasses? e.g. Float8 and NF4 (right now they don't use the `quantize_()` API, though they can be re-implemented to do so).
  - UPDATE: the better question is how to compose this with QLoRA (i.e. NF4). `LoRALinear` will always call `F.linear()` on the NF4 weight. If we make the base weight in `LoRALinear` a separate `nn.Linear` module (instead of a plain `nn.Parameter()`), then we can swap the linear module to change the outer op (a rough sketch follows this list).
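To make that concrete, here is a minimal sketch (not the current torchtune implementation) of a LoRA linear whose base weight lives in a child `nn.Linear`, so the base op can be changed via module swap:

```python
# Hypothetical variant of LoRALinear: the frozen base weight is held in a child
# nn.Linear module instead of a bare nn.Parameter, so the outer op can be changed
# by swapping self.base (e.g. with an INT8 mixed-precision or NF4-aware linear).
import torch
import torch.nn as nn

class LoRALinearWithBaseModule(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)  # swappable base projection
        self.base.weight.requires_grad_(False)               # frozen pretrained weight
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.zeros_(self.lora_b.weight)                    # usual LoRA init: adapter starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The LoRA branch is untouched; only self.base's implementation varies.
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
```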
I think these concerns can be addressed in the future, once torchao's training subclasses become more mature/stable.
Note: I can't test with 4090 since FlexAttention errors out on 4090 with `triton.runtime.errors.OutOfResources: out of resource: shared memory`. It's pretty strange since it works fine for another repo of mine 🤔.
Changelog
What are the changes made in this PR?
Integrate INT8 mixed-precision from torchao ~0.5~ 0.7
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- [ ] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [ ] add unit tests for any new functionality
- [ ] update docstrings for any new or updated methods or classes
- [ ] run unit tests via `pytest tests`
- [ ] run recipe tests via `pytest tests -m integration_test`
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Example of docstring: https://github.com/pytorch/torchtune/blob/6a7951f1cdd0b56a9746ef5935106989415f50e3/torchtune/modules/vision_transformer.py#L285 Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models
```bash
mixed_precision._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer mixed_precision.enabled=True
```
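For completeness, a hedged sketch of what using the quantizer directly from Python might look like; the keyword names mirror torchao's `Int8MixedPrecisionTrainingConfig` fields and may not match the final torchtune signature exactly:

```python
# Hedged sketch: direct usage of the new quantizer. Kwargs mirror torchao's
# Int8MixedPrecisionTrainingConfig (output / grad_input / grad_weight); the
# exact torchtune signature may differ.
import torch.nn as nn
from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

model = nn.Sequential(nn.Linear(256, 512), nn.SiLU(), nn.Linear(512, 256))
quantizer = Int8MixedPrecisionTrainingQuantizer(
    output=True,       # quantize the forward (output) matmul
    grad_input=True,   # quantize the grad-input matmul in the backward pass
    grad_weight=True,  # quantize the grad-weight matmul in the backward pass
)
model = quantizer.prepare(model)  # module-swap UX: replaces eligible nn.Linear layers
```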
- [ ] I did not change any public API;
- [ ] I have added an example to docs or docstrings;
Llama3.1-8B single device A100 40% speedup. torch=2.5.0.dev20240911, torchao=0.5.0
```bash
tune run full_finetune_single_device --config llama3_1/8B_full_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True
```
Llama3.1-8B FSDP2 2x A100 24% speedup. torch=2.5.1, pytorch/ao#1179
```bash
tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full tokenizer.max_seq_len=8192 dataset.packed=True optimizer.fused=True compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True
```
Llama3.1-8B single device A100 LoRA 50% speedup. torch==2.6.0.dev20240914, torchao=0.5.0
```bash
tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device tokenizer.max_seq_len=8192 dataset.packed=True compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True gradient_accumulation_steps=1
```
Llama3.2-1B single device 4070Ti SUPER QLoRA 60% speedup. torch==2.6.0.dev20241102+cu124, pytorch/ao#1179. Proof-of-concept only since it requires quite significant changes to the `LoRALinear` class. See https://github.com/pytorch/torchtune/compare/main...gau-nernst:qlora
```bash
tune run lora_finetune_single_device --config llama3_2/1B_qlora_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True batch_size=1 enable_activation_checkpointing=True
```