
Default to llama3-8b-instruct

Open ebsmothers opened this issue 2 years ago • 4 comments

There are some gotchas around special tokens when fine-tuning from the base Llama3 model. While we should smooth these out and make base-model fine-tuning (e.g. with LoRA) easy, most of our fine-tuning recipes target instruction or chat tasks, so I think it makes sense to have the instruct-tuned version as our default. (We already have the proper formatting for it in our tokenizer anyway.)
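For context, the instruct path is already handled on the tokenizer side. A minimal sketch of what that looks like, assuming torchtune's `llama3_tokenizer` and `Message` APIs (the tokenizer path below is a placeholder):

```python
from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

# Placeholder path: point this at the downloaded Llama3 tokenizer.model.
tokenizer = llama3_tokenizer(path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")

messages = [
    Message(role="user", content="What is the capital of France?"),
    Message(role="assistant", content="Paris."),
]

# tokenize_messages wraps each turn in the llama3 special tokens
# (header / end-of-turn markers) and returns token ids plus a loss mask.
tokens, mask = tokenizer.tokenize_messages(messages)
print(len(tokens), tokens[:8])
```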

Test plan:

Run each of the commands in the recipes manually to make sure the instruct model works as expected:

tune run lora_finetune_single_device --config llama3/8B_lora_single_device
...
1|74|Loss: 2.483799695968628:   0%|▎ 
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
...
1|9|Loss: 1.8418922424316406:   0%|  
tune run full_finetune_single_device --config llama3/8B_full_single_device
...
1|10|Loss: 1.280301809310913:   0%
tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full
...
1|43|Loss: 1.0445821285247803:   0%|▎
tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora
...

The final command gives the following loss curve:

[Screenshot: loss curve for the distributed LoRA run (2024-04-23)]

Running eval on the checkpoint from this run with the truthfulqa_mc2 task gives 53.6% accuracy, compared to 51.7% for the original Llama3-8B-Instruct checkpoint.
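(If you want to reproduce the comparison yourself, one option is EleutherAI's lm-eval Python API. Rough sketch below with a placeholder checkpoint path; this is illustrative, not the exact harness invocation from this run.)

```python
import lm_eval

# Sketch using lm-eval's (v0.4) Python API; the pretrained path is a
# placeholder for wherever the fine-tuned HF-format checkpoint was saved.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=/tmp/llama3_lora_checkpoint,dtype=bfloat16",
    tasks=["truthfulqa_mc2"],
    batch_size=8,
)
print(results["results"]["truthfulqa_mc2"])
```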

ebsmothers avatar Apr 23 '24 21:04 ebsmothers

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/851

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit 525464056d674bf1978b1ff84ccbe70975836978 with merge base dd99f379df95a0fe5b7ff875a3116efe304e7b8a: :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Apr 23 '24 21:04 pytorch-bot[bot]

Are the hyperparams similar to those used for the Llama-2 instruct model's training? If not, maybe we should also change some of the default hyperparams, such as the LR. I see it's set to 2e-5 for now, but I'm not sure whether that was just copy-pasted from the Llama2 config.

musabgultekin avatar Apr 24 '24 06:04 musabgultekin

> Are the hyperparams similar to those used for the Llama-2 instruct model's training? If not, maybe we should also change some of the default hyperparams, such as the LR. I see it's set to 2e-5 for now, but I'm not sure whether that was just copy-pasted from the Llama2 config.

Anecdotally, LR should be set to something a little lower.

joecummings avatar Apr 24 '24 12:04 joecummings

@musabgultekin good question. I haven't run comprehensive experiments to find the best LR values for all these configs yet, but I have some experiments for the distributed LoRA recipe in flight now. If you'd like to help out with this or any other experiments, we'd be happy to tune our configs based on your findings.

ebsmothers avatar Apr 24 '24 13:04 ebsmothers

Quick update on this: I ran our default Llama-3-8B-Instruct LoRA config with a few different learning rates. Pasting the loss curves below (training is only about 80% done right now):

[Screenshot: loss curves for the different learning rates (2024-04-24)]

So there may still be some tuning left to do, but based on this I claim that 3e-4 is at least a reasonable choice of LR for this config. But again, if you run other experiments and find otherwise, please let me know!
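For anyone who wants to repeat this, the comparison is easy to script since torchtune accepts key=value config overrides on the command line. Treat the loop below as a sketch (the output_dir values are placeholders, and the LR grid is just illustrative):

```python
import subprocess

# Launch the distributed LoRA recipe once per candidate learning rate.
# optimizer.lr and output_dir are standard torchtune config keys overridden
# via the CLI; separate output dirs keep the runs from clobbering each other.
for lr in (3e-5, 1e-4, 3e-4, 1e-3):
    subprocess.run(
        [
            "tune", "run", "--nproc_per_node", "2",
            "lora_finetune_distributed", "--config", "llama3/8B_lora",
            f"optimizer.lr={lr}",
            f"output_dir=/tmp/llama3_lora_lr_{lr}",
        ],
        check=True,
    )
```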

ebsmothers avatar Apr 24 '24 16:04 ebsmothers

I could probably do a W&B sweep. I'll let you know for sure when I have some results.
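Something like this rough sketch, wrapping the same command from your loop above in a W&B agent (assuming wandb's standard sweep API; the project name is a placeholder):

```python
import subprocess
import wandb

# Grid sweep over the LoRA learning rate; values are illustrative.
sweep_config = {
    "method": "grid",
    "parameters": {"lr": {"values": [3e-5, 1e-4, 3e-4, 1e-3]}},
}

def train():
    with wandb.init() as run:
        # Same torchtune CLI override as above, driven by the sweep's LR.
        subprocess.run(
            [
                "tune", "run", "lora_finetune_single_device",
                "--config", "llama3/8B_lora_single_device",
                f"optimizer.lr={run.config.lr}",
            ],
            check=True,
        )

sweep_id = wandb.sweep(sweep_config, project="torchtune-llama3-lr")
wandb.agent(sweep_id, function=train)
```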

musabgultekin avatar Apr 24 '24 16:04 musabgultekin