
QLoRA with Llama 3.1 405B

Open pbontrager opened this issue 7 months ago • 7 comments

Context

What is the purpose of this PR? Is it to

  • [x] add a new feature
  • [ ] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

This PR adds a Llama 3.1 405B QLoRA config for the lora_finetune_fsdp2 recipe. The config requires 8 x A100s to run, and as an example it successfully finetunes Llama 405B on the Alpaca dataset. There are several caveats to getting this large model to fit on a single node:

  • Model checkpointing causes NCCL timeout errors, so we save only the adapter weights, which can be merged into the base model afterward
  • Training is slow, at around 10 minutes for 16 gradient steps (this should improve with compile support)
  • Only context lengths under 4k fit in memory

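Since the recipe saves only the adapter weights, merging them back into the base model is a separate post-hoc step. Below is a minimal sketch of that merge in NumPy; the shapes and names are illustrative (real Llama 405B layers are far larger), and this is not torchtune's actual merge utility.

```python
import numpy as np

# Hypothetical small shapes for illustration.
d_out, d_in, rank, alpha = 8, 8, 2, 4

rng = np.random.default_rng(0)
W = rng.standard_normal((d_out, d_in))   # frozen base weight
A = rng.standard_normal((rank, d_in))    # LoRA "down" projection
B = np.zeros((d_out, rank))              # LoRA "up" projection (zero-initialized)

def merge_lora(W, A, B, rank, alpha):
    """Fold the low-rank adapter into the base weight: W' = W + (alpha / rank) * B @ A."""
    return W + (alpha / rank) * (B @ A)

W_merged = merge_lora(W, A, B, rank, alpha)
# With B still zero-initialized, the merge is a no-op; after training it
# bakes the learned low-rank update into the dense weight.
assert np.allclose(W_merged, W)
```

The merged weight has the same shape as the original, so the adapted model can be served with no extra inference-time cost.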
This PR was jointly authored with @joecummings.

Changelog

  • New config
  • New model builders
  • Updated docs for the new model builders
  • Updated the fsdp2 recipe's save_checkpoint to support saving only the adapter weights
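
The adapter-only checkpointing change can be sketched as filtering the state dict down to adapter tensors before saving. The module and parameter names below are illustrative, not torchtune's actual naming, and this is a toy stand-in for the recipe's real save_checkpoint logic.

```python
import os
import tempfile

import torch
from torch import nn

# Toy module standing in for a LoRA-wrapped linear layer.
class LoRALinear(nn.Module):
    def __init__(self, d_in, d_out, rank):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_out, d_in), requires_grad=False)
        self.lora_a = nn.Parameter(torch.randn(rank, d_in))
        self.lora_b = nn.Parameter(torch.zeros(d_out, rank))

model = LoRALinear(16, 16, rank=4)

# Keep only the adapter tensors; skipping the frozen base weights keeps the
# checkpoint small and avoids gathering the full 405B weights across ranks,
# which is what triggers the NCCL timeouts described above.
adapter_state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
path = os.path.join(tempfile.gettempdir(), "adapter_weights.pt")
torch.save(adapter_state, path)
```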

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help.)

  • [x] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [ ] add unit tests for any new functionality
  • [x] update docstrings for any new or updated methods or classes
  • [x] run unit tests via pytest tests
  • [x] run recipe tests via pytest tests -m integration_test
  • [x] manually run any new or modified recipes with sufficient proof of correctness
    • [x] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Test memory + checkpointing:

tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora log_peak_memory_stats=True dataset.packed=True dataset.max_seq_len=2048 max_steps_per_epoch=2 epochs=2

Test model loss:

tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora
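
The key=value arguments in the commands above are config overrides folded into the recipe's nested config. torchtune handles this via OmegaConf; the parse_overrides function below is a hypothetical, simplified stand-in that shows how dotted keys map to nested dicts.

```python
def parse_overrides(pairs):
    """Fold 'a.b=c'-style CLI overrides into a nested dict (simplified sketch)."""
    cfg = {}
    for pair in pairs:
        key, raw = pair.split("=", 1)
        # Rough literal coercion covering the examples in this PR.
        if raw in ("True", "False"):
            value = raw == "True"
        else:
            try:
                value = int(raw)
            except ValueError:
                value = raw
        node = cfg
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg

cfg = parse_overrides([
    "log_peak_memory_stats=True",
    "dataset.packed=True",
    "dataset.max_seq_len=2048",
    "max_steps_per_epoch=2",
    "epochs=2",
])
# cfg == {'log_peak_memory_stats': True,
#         'dataset': {'packed': True, 'max_seq_len': 2048},
#         'max_steps_per_epoch': 2,
#         'epochs': 2}
```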

pbontrager · Jul 26 '24 14:07