
When to stop Training?

Open nikhiljaiswal opened this issue 2 years ago • 2 comments

I am trying to fine-tune with LoRA. I am unable to understand when training stops and which checkpoint is the best one. I can see that in the hyper-parameters we are setting max_iters = 50000. Does that mean training will always run for 50,000 iterations? How will I select the best checkpoint? How can I be sure that training is completed?

nikhiljaiswal avatar Jun 22 '23 14:06 nikhiljaiswal

It will save a checkpoint every save_interval iterations, and if you keep eval_interval = save_interval in sync, you can select the checkpoint with the lowest validation loss:

eval_interval = 1000
save_interval = 1000

It would be really nice to enable reporting to wandb to make this a bit easier, but you can also add some simple logic to rename the lowest-loss checkpoint to "best" or something similar, as sketched below.
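
For instance, a minimal sketch of that logic (the checkpoint path and the val_loss value are placeholders you would wire into your own training loop, not actual lora.py variables):

import shutil
from pathlib import Path

best_val_loss = float("inf")  # lowest validation loss seen so far

def maybe_mark_best(val_loss: float, checkpoint_path: Path) -> None:
    # Call this right after each eval/save step with the loss returned by your
    # validation pass and the checkpoint you just wrote; it keeps a copy of the
    # best one under "best.pth" so it is easy to pick afterwards.
    global best_val_loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        shutil.copy(checkpoint_path, checkpoint_path.parent / "best.pth")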

griff4692 avatar Jun 22 '23 21:06 griff4692

(Reposting my answer from Discord)

50k iterations with a micro_batch_size of 4 (like in the lora.py script) means that the training will see 200k samples.

Stanford Alpaca contains 52k instruction pairs, so that means you'll go over your examples 4 times (aka 4 epochs). The original Stanford Alpaca authors ran 3 epochs for LLaMA 7B and 5 epochs for LLaMA 13B, so that is a rather safe choice.

As a rule of thumb, you want to leverage your dataset by doing multiple passes, but avoid overfitting it, so running for a few epochs is reasonable.
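
As a rough sketch, here is the arithmetic behind that (the numbers are the ones quoted above; max_iters_for is just an illustrative helper, not part of lora.py):

max_iters = 50_000        # iterations set in the hyper-parameters
micro_batch_size = 4      # as in the lora.py script
dataset_size = 52_000     # Stanford Alpaca instruction pairs

samples_seen = max_iters * micro_batch_size   # 200,000 samples
epochs = samples_seen / dataset_size          # ~3.85, i.e. roughly 4 passes

# Conversely, to target a given number of epochs on your own dataset:
def max_iters_for(target_epochs: int, dataset_size: int, micro_batch_size: int) -> int:
    return target_epochs * dataset_size // micro_batch_size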

At the end of the day though, you will want to run evaluations as in https://github.com/EleutherAI/lm-evaluation-harness and see how your model performs on one or more benchmarks of interest. That's the best way to assess when you should stop.

To run the harness on lit-llama, you can use this branch: https://github.com/Lightning-AI/lm-evaluation-harness/tree/lit-llama

I have a better solution in the works that doesn't force you to use this fork, but I haven't had the time to wrap it up yet.

lantiga avatar Jun 23 '23 08:06 lantiga

How should we scale down the number of iterations if we have 300 data points, or, to generalize, x data points?

RidhiChhajer avatar Jul 15 '24 04:07 RidhiChhajer