
[WIP] LayerSkip

Open mostafaelhoushi opened this issue 8 months ago • 7 comments

Context

What is the purpose of this PR? Is it to

  • [X] add a new feature
  • [ ] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

Please link to any issues this PR addresses.

Remaining tasks for this PR:

  • [ ] Interface layer dropout and early exit args to command line
  • [ ] Optimize the early exit loss calculation: compute it in one shot instead of in a loop
  • [ ] Add documentation to functions and arguments
  • [ ] Add examples to docs

Usage

TBD, provide command-line args.

Temporary Usage
  1. Check out this branch.
  2. Install the code in dev mode (so that any modifications to the code base take effect when you run):
       cd torchtune
       pip install -e ".[dev]"
  3. To enable layer dropout, configure these 3 arguments here. For example, the LayerSkip Continual Llama2 7B results used: layer_dropout_prob = 0.1, layer_dropout_prob_layer_scale = "exp", layer_dropout_str = ":"
  4. To enable early exit loss, ensure output_hidden_states=True when calling the model [here](https://github.com/mostafaelhoushi/torchtune/blob/ae61c858950186a9956295404f1ae9ea91afb962/recipes/full_finetune_distributed.py#L519).
  5. You can then run a full-model finetuning experiment, and both layer dropout and early exit loss will kick in:
       tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf
       tune run --nproc_per_node 8 full_finetune_distributed --config llama2/7B_full batch_size=8
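
To give a rough idea of what the layer dropout arguments above might control, here is a minimal sketch in plain Python. It assumes that "exp" means the dropout probability grows exponentially with depth, from 0 at the first layer to layer_dropout_prob at the last, and that ":" selects all layers. The function names (`layer_dropout_probs`, `sample_kept_layers`) and the "linear" and "uniform" options are hypothetical and only for illustration; they are not the API of this branch.

```python
import random


def layer_dropout_probs(num_layers, prob_max, scale="exp"):
    """Return one dropout probability per layer (hypothetical helper).

    With scale="exp" the probability rises from 0 at the first layer
    to prob_max at the last layer, following 2 ** (l / (L - 1)) - 1.
    """
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    period = max(num_layers - 1, 1)  # avoid division by zero for a single layer
    probs = []
    for layer_id in range(num_layers):
        if scale == "exp":
            factor = 2 ** (layer_id / period) - 1
        elif scale == "linear":
            factor = layer_id / period
        elif scale == "uniform":
            factor = 1.0
        else:
            raise ValueError(f"unknown scale: {scale!r}")
        probs.append(prob_max * factor)
    return probs


def sample_kept_layers(probs, rng):
    """Sample which layers run in one forward pass (the rest are skipped)."""
    return [i for i, p in enumerate(probs) if rng.random() >= p]


if __name__ == "__main__":
    # Values from the example above: 32 layers is an assumption for a 7B model.
    probs = layer_dropout_probs(num_layers=32, prob_max=0.1, scale="exp")
    print([round(p, 4) for p in probs])
    print(sample_kept_layers(probs, random.Random(0)))
```

Under this assumed schedule, early layers are almost never skipped and the last layers are skipped most often. A real implementation would typically make the skip decision per sample within a batch rather than once per forward pass.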

Changelog

  • Modification to TransformerDecoder to support layer dropout and returning outputs at different layers
  • Modifications to training script to update training loss with early exit losses
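
As a hedged illustration of the second item, the sketch below shows one way per-layer early exit losses could be folded into a single training loss: a weighted average with weights normalized to sum to 1. The weighting used here (proportional to layer index + 1, so deeper exits count more) and the names `depth_weights` and `combine_early_exit_losses` are assumptions made for this example, not the scheme implemented in this PR.

```python
def depth_weights(exit_layers):
    """Hypothetical weighting: deeper exit layers get larger weights."""
    return {layer: float(layer + 1) for layer in exit_layers}


def combine_early_exit_losses(layer_losses, layer_weights):
    """Weighted average of per-layer losses.

    layer_losses:  dict mapping layer index -> scalar loss at that exit
    layer_weights: dict mapping layer index -> non-negative weight
    """
    if set(layer_losses) != set(layer_weights):
        raise ValueError("losses and weights must cover the same layers")
    total_weight = sum(layer_weights.values())
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    weighted = sum(layer_weights[l] * layer_losses[l] for l in layer_losses)
    return weighted / total_weight


if __name__ == "__main__":
    # Made-up loss values for exits after layers 0, 1 and 3.
    losses = {0: 4.0, 1: 3.0, 3: 2.0}
    print(combine_early_exit_losses(losses, depth_weights(losses)))
```

In a training script the per-layer losses would come from applying the shared output head to each returned hidden state; plain floats are used here only to keep the example self-contained.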

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • [ ] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [ ] add unit tests for any new functionality
  • [ ] update docstrings for any new or updated methods or classes
  • [ ] run unit tests via pytest tests
  • [ ] run recipe tests via pytest tests -m integration_test
  • [ ] manually run any new or modified recipes with sufficient proof of correctness
    • [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

mostafaelhoushi, Jun 09 '24 04:06