QLoRA with Llama 3.1 405B
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
This PR adds a Llama 3.1 405B QLoRA config for the lora_finetune_fsdp2 recipe. The config requires 8 x A100s to run and successfully finetunes Llama 3.1 405B on the Alpaca dataset as an example. There are several caveats to fitting this large model on a single node:
- Checkpointing the full model causes NCCL timeout errors, so we only save the adapter weights, which can be merged into the base model afterward (see the sketch below)
- Training is slow, at around 10 minutes for 16 gradient steps (this should improve with compile support)
- Only context lengths under 4k fit in memory
This PR was jointly authored by me and @joecummings.
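Since the full model is not checkpointed, the saved adapter weights have to be folded into the base weights offline. Below is a minimal sketch of that merge, assuming the base weights have already been dequantized to regular tensors and that adapter keys follow a `...lora_a.weight` / `...lora_b.weight` naming paired with a base `...weight` key; the key layout and the helper itself are illustrative, not the actual torchtune merge path. The standard LoRA update is W' = W + (alpha / rank) * (B @ A).

```python
# Sketch only: fold saved LoRA adapter weights into base weights offline.
# Key names and the rank/alpha scaling convention are assumptions for illustration.
import torch


def merge_lora_weights(
    base_sd: dict, adapter_sd: dict, rank: int, alpha: float
) -> dict:
    """Return a copy of base_sd with LoRA updates folded into the base weights."""
    merged = dict(base_sd)
    scale = alpha / rank
    for key, lora_a in adapter_sd.items():
        if not key.endswith("lora_a.weight"):
            continue
        lora_b = adapter_sd[key.replace("lora_a", "lora_b")]
        base_key = key.replace(".lora_a.weight", ".weight")
        # Standard LoRA merge: W' = W + (alpha / rank) * (B @ A), done in fp32 for accuracy
        merged[base_key] = (
            base_sd[base_key].float() + scale * (lora_b.float() @ lora_a.float())
        ).to(base_sd[base_key].dtype)
    return merged


# Usage (hypothetical paths):
# merged = merge_lora_weights(torch.load("base.pt"), torch.load("adapter_weights.pt"), rank=16, alpha=32)
```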
Changelog
- New config
- New model builders
- Updated docs for the new model builders
- Updated the fsdp2 recipe's save_checkpoint to support saving only the adapter weights (see the sketch below)
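To give a sense of what "saving only the adapter weights" means, here is a minimal sketch of the filtering step. This is not the actual recipe code; the `lora_a` / `lora_b` key substrings and the helper name are assumptions about how adapter parameters are named.

```python
# Sketch only: filter a state dict down to LoRA adapter tensors before saving,
# so the full 405B weights never need to be gathered and serialized.
from typing import Dict

import torch


def get_adapter_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Keep only LoRA adapter tensors, moved to CPU so the checkpoint stays small."""
    adapter_keys = ("lora_a", "lora_b")  # assumed adapter parameter name substrings
    return {
        k: v.detach().cpu()
        for k, v in state_dict.items()
        if any(name in k for name in adapter_keys)
    }


# Usage (sketch): torch.save(get_adapter_state_dict(model.state_dict()), "adapter_weights.pt")
```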
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help.)
- [x] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [ ] add unit tests for any new functionality
- [x] update docstrings for any new or updated methods or classes
- [x] run unit tests via `pytest tests`
- [x] run recipe tests via `pytest tests -m integration_test`
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [x] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
Test Memory + Checkpointing
`tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora log_peak_memory_stats=True dataset.packed=True dataset.max_seq_len=2048 max_steps_per_epoch=2 epochs=2`
Test Model Loss
`tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora`