
Print initial validation loss + final training and validation loss

Open · rasbt opened this pull request 1 year ago · 1 comment

This PR does two things:

  1. Users were confused by the initial n/a shown for the validation loss. One idea: reuse the initial sanity-check validation we already run as an estimate here. This adds no computational cost but already gives useful information at the start of training.

  2. It's often cumbersome to coordinate max_steps and the evaluation interval so that you always get a final validation loss. In fact, a user complained about not getting any validation loss at all; it turned out their dataset was too small. In my opinion, we should ALWAYS report the validation loss of the final model, and the final training loss is useful for comparison (a sketch of both changes follows).
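Here is a minimal, self-contained sketch of both changes using a toy PyTorch model. The helper names and the loop structure are stand-ins chosen for illustration, not litgpt's actual functions:

import torch

def validate(model, data, max_iters):
    # Average loss over up to max_iters validation batches.
    model.eval()
    losses = []
    with torch.no_grad():
        for i, (x, y) in enumerate(data):
            if i >= max_iters:
                break
            losses.append(torch.nn.functional.mse_loss(model(x), y).item())
    model.train()
    return sum(losses) / len(losses)

# Toy data and model, just to make the sketch runnable.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(6)]
val_data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(4)]
eval_interval, eval_max_iters = 2, 4

# (1) Seed the logged value with the sanity-check validation instead of "n/a";
#     this reuses work we do anyway, so it costs nothing extra.
val_loss = validate(model, val_data, eval_max_iters)

for iter_num, (x, y) in enumerate(train_data, start=1):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    if iter_num % eval_interval == 0:
        val_loss = validate(model, val_data, eval_max_iters)
    print(f"iter {iter_num} | loss train: {loss.item():.3f}, val: {val_loss:.3f}")

# (2) Always evaluate the final model, regardless of whether the step count
#     happens to align with the eval interval or the dataset is tiny.
with torch.no_grad():
    final_train = sum(
        torch.nn.functional.mse_loss(model(x), y).item() for x, y in train_data
    ) / len(train_data)
final_val = validate(model, val_data, eval_max_iters)
print(f"Final train loss: {final_train:.3f} | final val loss: {final_val:.3f}")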

After this PR, the printed log looks as follows:

 print-loss ~/litgpt litgpt finetune lora --checkpoint_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T/ --train.max_steps 2 --train.micro_batch_size 1 --lora_r 2
{'checkpoint_dir': PosixPath('checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T'),
 'data': None,
 'devices': 1,
 'eval': EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
 'logger_name': 'csv',
 'lora_alpha': 16,
 'lora_dropout': 0.05,
 'lora_head': False,
 'lora_key': False,
 'lora_mlp': False,
 'lora_projection': False,
 'lora_query': True,
 'lora_r': 2,
 'lora_value': True,
 'out_dir': PosixPath('out/finetune/lora'),
 'precision': None,
 'quantize': None,
 'seed': 1337,
 'train': TrainArgs(save_interval=1000,
                    log_interval=1,
                    global_batch_size=128,
                    micro_batch_size=1,
                    lr_warmup_steps=100,
                    epochs=5,
                    max_tokens=None,
                    max_steps=2,
                    max_seq_length=None,
                    tie_embeddings=None,
                    learning_rate=0.0003,
                    weight_decay=0.02,
                    beta1=0.9,
                    beta2=0.95,
                    max_norm=None,
                    min_lr=6e-05)}
Using bfloat16 Automatic Mixed Precision (AMP)
Seed set to 1337
Number of trainable parameters: 281,600
Number of non-trainable parameters: 1,100,048,384
The longest sequence length in the train data is 1305, the model's maximum sequence length is 1305 and context length is 2048
Validating ...
Recommend a movie for me to watch during the weekend and explain the reason.
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie for me to watch during the weekend and explain the reason.

### Response:
I would really like to watch Batman Begins. It is a must see for anyone who grew up in the 80s. It is a great mix of action and drama. If you like action movies, you should watch this one.

### Instruction:
Recommend a book that was written 100 years ago to read this weekend.

### Response:
The Rise of Skywalker by <NAME> is a great
Epoch 1 | iter 1 step 0 | loss train: 2.057, val: 2.464 | iter time: 270.94 ms
Epoch 1 | iter 2 step 0 | loss train: 2.032, val: 2.464 | iter time: 108.13 ms
Epoch 1 | iter 3 step 0 | loss train: 2.097, val: 2.464 | iter time: 104.56 ms
Epoch 1 | iter 4 step 0 | loss train: 2.221, val: 2.464 | iter time: 126.57 ms
Epoch 1 | iter 5 step 0 | loss train: 2.152, val: 2.464 | iter time: 99.59 ms
Epoch 1 | iter 6 step 0 | loss train: 2.058, val: 2.464 | iter time: 101.55 ms
Epoch 1 | iter 7 step 0 | loss train: 2.042, val: 2.464 | iter time: 99.69 ms
Epoch 1 | iter 8 step 0 | loss train: 2.069, val: 2.464 | iter time: 101.18 ms
Epoch 1 | iter 9 step 0 | loss train: 2.026, val: 2.464 | iter time: 104.00 ms
Epoch 1 | iter 10 step 0 | loss train: 2.045, val: 2.464 | iter time: 104.55 ms
Epoch 1 | iter 11 step 0 | loss train: 2.072, val: 2.464 | iter time: 102.42 ms
Epoch 1 | iter 12 step 0 | loss train: 2.157, val: 2.464 | iter time: 101.26 ms
Epoch 1 | iter 13 step 0 | loss train: 2.172, val: 2.464 | iter time: 100.42 ms
Epoch 1 | iter 14 step 0 | loss train: 2.155, val: 2.464 | iter time: 101.57 ms
Epoch 1 | iter 15 step 0 | loss train: 2.183, val: 2.464 | iter time: 100.80 ms
Epoch 1 | iter 16 step 0 | loss train: 2.219, val: 2.464 | iter time: 99.80 ms
Epoch 1 | iter 17 step 0 | loss train: 2.182, val: 2.464 | iter time: 99.58 ms
Epoch 1 | iter 18 step 0 | loss train: 2.171, val: 2.464 | iter time: 89.90 ms
Epoch 1 | iter 19 step 0 | loss train: 2.182, val: 2.464 | iter time: 104.30 ms
Epoch 1 | iter 20 step 0 | loss train: 2.182, val: 2.464 | iter time: 103.20 ms
Epoch 1 | iter 21 step 0 | loss train: 2.196, val: 2.464 | iter time: 101.88 ms
Epoch 1 | iter 22 step 0 | loss train: 2.223, val: 2.464 | iter time: 99.99 ms
Epoch 1 | iter 23 step 0 | loss train: 2.209, val: 2.464 | iter time: 90.51 ms
Epoch 1 | iter 24 step 0 | loss train: 2.214, val: 2.464 | iter time: 101.98 ms
Epoch 1 | iter 25 step 0 | loss train: 2.193, val: 2.464 | iter time: 101.25 ms
Epoch 1 | iter 26 step 0 | loss train: 2.192, val: 2.464 | iter time: 101.93 ms
Epoch 1 | iter 27 step 0 | loss train: 2.170, val: 2.464 | iter time: 109.82 ms
Epoch 1 | iter 28 step 0 | loss train: 2.179, val: 2.464 | iter time: 100.85 ms
Epoch 1 | iter 29 step 0 | loss train: 2.188, val: 2.464 | iter time: 102.88 ms
Epoch 1 | iter 30 step 0 | loss train: 2.186, val: 2.464 | iter time: 101.77 ms
Epoch 1 | iter 31 step 0 | loss train: 2.194, val: 2.464 | iter time: 89.56 ms
Epoch 1 | iter 32 step 0 | loss train: 2.214, val: 2.464 | iter time: 105.83 ms
Epoch 1 | iter 33 step 0 | loss train: 2.224, val: 2.464 | iter time: 101.59 ms
Epoch 1 | iter 34 step 0 | loss train: 2.238, val: 2.464 | iter time: 89.51 ms
Epoch 1 | iter 35 step 0 | loss train: 2.221, val: 2.464 | iter time: 102.50 ms
Epoch 1 | iter 36 step 0 | loss train: 2.217, val: 2.464 | iter time: 99.60 ms
Epoch 1 | iter 37 step 0 | loss train: 2.222, val: 2.464 | iter time: 90.40 ms
Epoch 1 | iter 38 step 0 | loss train: 2.219, val: 2.464 | iter time: 102.78 ms
Epoch 1 | iter 39 step 0 | loss train: 2.224, val: 2.464 | iter time: 91.61 ms
Epoch 1 | iter 40 step 0 | loss train: 2.235, val: 2.464 | iter time: 90.30 ms
Epoch 1 | iter 41 step 0 | loss train: 2.239, val: 2.464 | iter time: 101.80 ms
Epoch 1 | iter 42 step 0 | loss train: 2.248, val: 2.464 | iter time: 100.67 ms
Epoch 1 | iter 43 step 0 | loss train: 2.253, val: 2.464 | iter time: 89.85 ms
Epoch 1 | iter 44 step 0 | loss train: 2.242, val: 2.464 | iter time: 101.50 ms
Epoch 1 | iter 45 step 0 | loss train: 2.259, val: 2.464 | iter time: 100.87 ms
Epoch 1 | iter 46 step 0 | loss train: 2.239, val: 2.464 | iter time: 103.15 ms
Epoch 1 | iter 47 step 0 | loss train: 2.257, val: 2.464 | iter time: 100.38 ms
Epoch 1 | iter 48 step 0 | loss train: 2.263, val: 2.464 | iter time: 101.91 ms
Epoch 1 | iter 49 step 0 | loss train: 2.279, val: 2.464 | iter time: 103.66 ms
Epoch 1 | iter 50 step 0 | loss train: 2.286, val: 2.464 | iter time: 90.12 ms
Epoch 1 | iter 51 step 0 | loss train: 2.289, val: 2.464 | iter time: 101.39 ms
Epoch 1 | iter 52 step 0 | loss train: 2.279, val: 2.464 | iter time: 89.96 ms
Epoch 1 | iter 53 step 0 | loss train: 2.281, val: 2.464 | iter time: 101.95 ms
Epoch 1 | iter 54 step 0 | loss train: 2.295, val: 2.464 | iter time: 90.28 ms
Epoch 1 | iter 55 step 0 | loss train: 2.297, val: 2.464 | iter time: 89.81 ms
Epoch 1 | iter 56 step 0 | loss train: 2.302, val: 2.464 | iter time: 97.59 ms
Epoch 1 | iter 57 step 0 | loss train: 2.311, val: 2.464 | iter time: 100.48 ms
Epoch 1 | iter 58 step 0 | loss train: 2.303, val: 2.464 | iter time: 104.22 ms
Epoch 1 | iter 59 step 0 | loss train: 2.292, val: 2.464 | iter time: 120.04 ms
Epoch 1 | iter 60 step 0 | loss train: 2.294, val: 2.464 | iter time: 91.77 ms
Epoch 1 | iter 61 step 0 | loss train: 2.294, val: 2.464 | iter time: 102.26 ms
Epoch 1 | iter 62 step 0 | loss train: 2.287, val: 2.464 | iter time: 90.75 ms
Epoch 1 | iter 63 step 0 | loss train: 2.289, val: 2.464 | iter time: 102.56 ms
Epoch 1 | iter 64 step 0 | loss train: 2.296, val: 2.464 | iter time: 102.94 ms
Epoch 1 | iter 65 step 0 | loss train: 2.301, val: 2.464 | iter time: 103.33 ms
Epoch 1 | iter 66 step 0 | loss train: 2.300, val: 2.464 | iter time: 102.78 ms
Epoch 1 | iter 67 step 0 | loss train: 2.297, val: 2.464 | iter time: 103.30 ms
Epoch 1 | iter 68 step 0 | loss train: 2.298, val: 2.464 | iter time: 102.43 ms
Epoch 1 | iter 69 step 0 | loss train: 2.300, val: 2.464 | iter time: 102.99 ms
Epoch 1 | iter 70 step 0 | loss train: 2.303, val: 2.464 | iter time: 89.78 ms
Epoch 1 | iter 71 step 0 | loss train: 2.303, val: 2.464 | iter time: 89.94 ms
Epoch 1 | iter 72 step 0 | loss train: 2.307, val: 2.464 | iter time: 90.08 ms
Epoch 1 | iter 73 step 0 | loss train: 2.314, val: 2.464 | iter time: 101.53 ms
Epoch 1 | iter 74 step 0 | loss train: 2.307, val: 2.464 | iter time: 105.84 ms
Epoch 1 | iter 75 step 0 | loss train: 2.300, val: 2.464 | iter time: 103.55 ms
Epoch 1 | iter 76 step 0 | loss train: 2.294, val: 2.464 | iter time: 102.02 ms
Epoch 1 | iter 77 step 0 | loss train: 2.294, val: 2.464 | iter time: 90.55 ms
Epoch 1 | iter 78 step 0 | loss train: 2.297, val: 2.464 | iter time: 102.09 ms
Epoch 1 | iter 79 step 0 | loss train: 2.293, val: 2.464 | iter time: 104.77 ms
Epoch 1 | iter 80 step 0 | loss train: 2.303, val: 2.464 | iter time: 90.36 ms
Epoch 1 | iter 81 step 0 | loss train: 2.305, val: 2.464 | iter time: 89.72 ms
Epoch 1 | iter 82 step 0 | loss train: 2.299, val: 2.464 | iter time: 105.92 ms
Epoch 1 | iter 83 step 0 | loss train: 2.298, val: 2.464 | iter time: 104.86 ms
Epoch 1 | iter 84 step 0 | loss train: 2.296, val: 2.464 | iter time: 102.27 ms
Epoch 1 | iter 85 step 0 | loss train: 2.306, val: 2.464 | iter time: 90.54 ms
Epoch 1 | iter 86 step 0 | loss train: 2.310, val: 2.464 | iter time: 91.34 ms
Epoch 1 | iter 87 step 0 | loss train: 2.308, val: 2.464 | iter time: 105.79 ms
Epoch 1 | iter 88 step 0 | loss train: 2.316, val: 2.464 | iter time: 103.05 ms
Epoch 1 | iter 89 step 0 | loss train: 2.321, val: 2.464 | iter time: 104.71 ms
Epoch 1 | iter 90 step 0 | loss train: 2.321, val: 2.464 | iter time: 104.86 ms
Epoch 1 | iter 91 step 0 | loss train: 2.326, val: 2.464 | iter time: 91.43 ms
Epoch 1 | iter 92 step 0 | loss train: 2.328, val: 2.464 | iter time: 101.99 ms
Epoch 1 | iter 93 step 0 | loss train: 2.328, val: 2.464 | iter time: 98.61 ms
Epoch 1 | iter 94 step 0 | loss train: 2.330, val: 2.464 | iter time: 101.18 ms
Epoch 1 | iter 95 step 0 | loss train: 2.336, val: 2.464 | iter time: 91.27 ms
Epoch 1 | iter 96 step 0 | loss train: 2.328, val: 2.464 | iter time: 103.63 ms
Epoch 1 | iter 97 step 0 | loss train: 2.331, val: 2.464 | iter time: 90.52 ms
Epoch 1 | iter 98 step 0 | loss train: 2.334, val: 2.464 | iter time: 90.54 ms
Epoch 1 | iter 99 step 0 | loss train: 2.335, val: 2.464 | iter time: 104.65 ms
Epoch 1 | iter 100 step 0 | loss train: 2.335, val: 2.464 | iter time: 104.30 ms
Epoch 1 | iter 101 step 0 | loss train: 2.335, val: 2.464 | iter time: 103.40 ms
Epoch 1 | iter 102 step 0 | loss train: 2.336, val: 2.464 | iter time: 90.42 ms
Epoch 1 | iter 103 step 0 | loss train: 2.329, val: 2.464 | iter time: 102.13 ms
Epoch 1 | iter 104 step 0 | loss train: 2.332, val: 2.464 | iter time: 91.84 ms
Epoch 1 | iter 105 step 0 | loss train: 2.330, val: 2.464 | iter time: 90.89 ms
Epoch 1 | iter 106 step 0 | loss train: 2.330, val: 2.464 | iter time: 91.04 ms
Epoch 1 | iter 107 step 0 | loss train: 2.326, val: 2.464 | iter time: 106.07 ms
Epoch 1 | iter 108 step 0 | loss train: 2.318, val: 2.464 | iter time: 104.12 ms
Epoch 1 | iter 109 step 0 | loss train: 2.322, val: 2.464 | iter time: 92.15 ms
Epoch 1 | iter 110 step 0 | loss train: 2.319, val: 2.464 | iter time: 92.47 ms
Epoch 1 | iter 111 step 0 | loss train: 2.321, val: 2.464 | iter time: 91.57 ms
Epoch 1 | iter 112 step 0 | loss train: 2.320, val: 2.464 | iter time: 103.02 ms
Epoch 1 | iter 113 step 0 | loss train: 2.314, val: 2.464 | iter time: 107.78 ms
Epoch 1 | iter 114 step 0 | loss train: 2.318, val: 2.464 | iter time: 104.18 ms
Epoch 1 | iter 115 step 0 | loss train: 2.316, val: 2.464 | iter time: 102.16 ms
Epoch 1 | iter 116 step 0 | loss train: 2.313, val: 2.464 | iter time: 92.56 ms
Epoch 1 | iter 117 step 0 | loss train: 2.309, val: 2.464 | iter time: 91.32 ms
Epoch 1 | iter 118 step 0 | loss train: 2.306, val: 2.464 | iter time: 89.53 ms
Epoch 1 | iter 119 step 0 | loss train: 2.306, val: 2.464 | iter time: 90.91 ms
Epoch 1 | iter 120 step 0 | loss train: 2.310, val: 2.464 | iter time: 92.55 ms
Epoch 1 | iter 121 step 0 | loss train: 2.308, val: 2.464 | iter time: 91.95 ms
Epoch 1 | iter 122 step 0 | loss train: 2.306, val: 2.464 | iter time: 89.67 ms
Epoch 1 | iter 123 step 0 | loss train: 2.304, val: 2.464 | iter time: 91.57 ms
Epoch 1 | iter 124 step 0 | loss train: 2.310, val: 2.464 | iter time: 102.71 ms
Epoch 1 | iter 125 step 0 | loss train: 2.307, val: 2.464 | iter time: 90.92 ms
Epoch 1 | iter 126 step 0 | loss train: 2.308, val: 2.464 | iter time: 90.76 ms
Epoch 1 | iter 127 step 0 | loss train: 2.310, val: 2.464 | iter time: 102.19 ms
Epoch 1 | iter 128 step 1 | loss train: 2.306, val: 2.464 | iter time: 143.68 ms (step)
Epoch 1 | iter 129 step 1 | loss train: 2.309, val: 2.464 | iter time: 102.03 ms
Epoch 1 | iter 130 step 1 | loss train: 2.310, val: 2.464 | iter time: 93.11 ms
Epoch 1 | iter 131 step 1 | loss train: 2.309, val: 2.464 | iter time: 104.92 ms
Epoch 1 | iter 132 step 1 | loss train: 2.308, val: 2.464 | iter time: 92.36 ms
Epoch 1 | iter 133 step 1 | loss train: 2.311, val: 2.464 | iter time: 94.21 ms
Epoch 1 | iter 134 step 1 | loss train: 2.315, val: 2.464 | iter time: 90.95 ms
Epoch 1 | iter 135 step 1 | loss train: 2.319, val: 2.464 | iter time: 103.61 ms
Epoch 1 | iter 136 step 1 | loss train: 2.322, val: 2.464 | iter time: 91.31 ms
Epoch 1 | iter 137 step 1 | loss train: 2.328, val: 2.464 | iter time: 101.83 ms
Epoch 1 | iter 138 step 1 | loss train: 2.332, val: 2.464 | iter time: 102.48 ms
Epoch 1 | iter 139 step 1 | loss train: 2.331, val: 2.464 | iter time: 101.45 ms
Epoch 1 | iter 140 step 1 | loss train: 2.321, val: 2.464 | iter time: 105.39 ms
Epoch 1 | iter 141 step 1 | loss train: 2.312, val: 2.464 | iter time: 103.86 ms
Epoch 1 | iter 142 step 1 | loss train: 2.315, val: 2.464 | iter time: 91.18 ms
Epoch 1 | iter 143 step 1 | loss train: 2.310, val: 2.464 | iter time: 91.04 ms
Epoch 1 | iter 144 step 1 | loss train: 2.307, val: 2.464 | iter time: 91.14 ms
Epoch 1 | iter 145 step 1 | loss train: 2.311, val: 2.464 | iter time: 103.97 ms
Epoch 1 | iter 146 step 1 | loss train: 2.306, val: 2.464 | iter time: 92.95 ms
Epoch 1 | iter 147 step 1 | loss train: 2.302, val: 2.464 | iter time: 107.66 ms
Epoch 1 | iter 148 step 1 | loss train: 2.301, val: 2.464 | iter time: 91.53 ms
Epoch 1 | iter 149 step 1 | loss train: 2.297, val: 2.464 | iter time: 104.43 ms
Epoch 1 | iter 150 step 1 | loss train: 2.298, val: 2.464 | iter time: 92.85 ms
Epoch 1 | iter 151 step 1 | loss train: 2.293, val: 2.464 | iter time: 104.36 ms
Epoch 1 | iter 152 step 1 | loss train: 2.292, val: 2.464 | iter time: 103.11 ms
Epoch 1 | iter 153 step 1 | loss train: 2.301, val: 2.464 | iter time: 91.51 ms
Epoch 1 | iter 154 step 1 | loss train: 2.300, val: 2.464 | iter time: 90.85 ms
Epoch 1 | iter 155 step 1 | loss train: 2.304, val: 2.464 | iter time: 104.08 ms
Epoch 1 | iter 156 step 1 | loss train: 2.304, val: 2.464 | iter time: 102.58 ms
Epoch 1 | iter 157 step 1 | loss train: 2.300, val: 2.464 | iter time: 89.68 ms
Epoch 1 | iter 158 step 1 | loss train: 2.302, val: 2.464 | iter time: 90.89 ms
Epoch 1 | iter 159 step 1 | loss train: 2.302, val: 2.464 | iter time: 91.06 ms
Epoch 1 | iter 160 step 1 | loss train: 2.300, val: 2.464 | iter time: 92.33 ms
Epoch 1 | iter 161 step 1 | loss train: 2.297, val: 2.464 | iter time: 91.73 ms
Epoch 1 | iter 162 step 1 | loss train: 2.294, val: 2.464 | iter time: 91.01 ms
Epoch 1 | iter 163 step 1 | loss train: 2.294, val: 2.464 | iter time: 103.67 ms
Epoch 1 | iter 164 step 1 | loss train: 2.301, val: 2.464 | iter time: 102.94 ms
Epoch 1 | iter 165 step 1 | loss train: 2.304, val: 2.464 | iter time: 91.37 ms
Epoch 1 | iter 166 step 1 | loss train: 2.308, val: 2.464 | iter time: 90.74 ms
Epoch 1 | iter 167 step 1 | loss train: 2.309, val: 2.464 | iter time: 90.86 ms
Epoch 1 | iter 168 step 1 | loss train: 2.307, val: 2.464 | iter time: 91.90 ms
Epoch 1 | iter 169 step 1 | loss train: 2.309, val: 2.464 | iter time: 90.00 ms
Epoch 1 | iter 170 step 1 | loss train: 2.311, val: 2.464 | iter time: 90.57 ms
Epoch 1 | iter 171 step 1 | loss train: 2.308, val: 2.464 | iter time: 90.85 ms
Epoch 1 | iter 172 step 1 | loss train: 2.316, val: 2.464 | iter time: 102.28 ms
Epoch 1 | iter 173 step 1 | loss train: 2.316, val: 2.464 | iter time: 88.96 ms
Epoch 1 | iter 174 step 1 | loss train: 2.327, val: 2.464 | iter time: 89.02 ms
Epoch 1 | iter 175 step 1 | loss train: 2.323, val: 2.464 | iter time: 89.12 ms
Epoch 1 | iter 176 step 1 | loss train: 2.326, val: 2.464 | iter time: 88.99 ms
Epoch 1 | iter 177 step 1 | loss train: 2.317, val: 2.464 | iter time: 102.23 ms
Epoch 1 | iter 178 step 1 | loss train: 2.311, val: 2.464 | iter time: 103.32 ms
Epoch 1 | iter 179 step 1 | loss train: 2.308, val: 2.464 | iter time: 103.10 ms
Epoch 1 | iter 180 step 1 | loss train: 2.314, val: 2.464 | iter time: 88.74 ms
Epoch 1 | iter 181 step 1 | loss train: 2.312, val: 2.464 | iter time: 102.72 ms
Epoch 1 | iter 182 step 1 | loss train: 2.306, val: 2.464 | iter time: 88.60 ms
Epoch 1 | iter 183 step 1 | loss train: 2.305, val: 2.464 | iter time: 89.91 ms
Epoch 1 | iter 184 step 1 | loss train: 2.305, val: 2.464 | iter time: 96.02 ms
Epoch 1 | iter 185 step 1 | loss train: 2.305, val: 2.464 | iter time: 90.44 ms
Epoch 1 | iter 186 step 1 | loss train: 2.306, val: 2.464 | iter time: 90.07 ms
Epoch 1 | iter 187 step 1 | loss train: 2.315, val: 2.464 | iter time: 90.20 ms
Epoch 1 | iter 188 step 1 | loss train: 2.308, val: 2.464 | iter time: 89.53 ms
Epoch 1 | iter 189 step 1 | loss train: 2.310, val: 2.464 | iter time: 101.04 ms
Epoch 1 | iter 190 step 1 | loss train: 2.311, val: 2.464 | iter time: 89.72 ms
Epoch 1 | iter 191 step 1 | loss train: 2.308, val: 2.464 | iter time: 89.71 ms
Epoch 1 | iter 192 step 1 | loss train: 2.303, val: 2.464 | iter time: 92.52 ms
Epoch 1 | iter 193 step 1 | loss train: 2.300, val: 2.464 | iter time: 89.33 ms
Epoch 1 | iter 194 step 1 | loss train: 2.300, val: 2.464 | iter time: 89.90 ms
Epoch 1 | iter 195 step 1 | loss train: 2.303, val: 2.464 | iter time: 90.44 ms
Epoch 1 | iter 196 step 1 | loss train: 2.306, val: 2.464 | iter time: 90.29 ms
Epoch 1 | iter 197 step 1 | loss train: 2.305, val: 2.464 | iter time: 88.91 ms
Epoch 1 | iter 198 step 1 | loss train: 2.308, val: 2.464 | iter time: 90.03 ms
Epoch 1 | iter 199 step 1 | loss train: 2.305, val: 2.464 | iter time: 90.99 ms
Epoch 1 | iter 200 step 1 | loss train: 2.304, val: 2.464 | iter time: 89.13 ms
Epoch 1 | iter 201 step 1 | loss train: 2.297, val: 2.464 | iter time: 91.02 ms
Epoch 1 | iter 202 step 1 | loss train: 2.296, val: 2.464 | iter time: 104.24 ms
Epoch 1 | iter 203 step 1 | loss train: 2.300, val: 2.464 | iter time: 91.85 ms
Epoch 1 | iter 204 step 1 | loss train: 2.301, val: 2.464 | iter time: 103.15 ms
Epoch 1 | iter 205 step 1 | loss train: 2.300, val: 2.464 | iter time: 89.47 ms
Epoch 1 | iter 206 step 1 | loss train: 2.299, val: 2.464 | iter time: 89.13 ms
Epoch 1 | iter 207 step 1 | loss train: 2.301, val: 2.464 | iter time: 88.84 ms
Epoch 1 | iter 208 step 1 | loss train: 2.290, val: 2.464 | iter time: 89.50 ms
Epoch 1 | iter 209 step 1 | loss train: 2.296, val: 2.464 | iter time: 88.86 ms
Epoch 1 | iter 210 step 1 | loss train: 2.300, val: 2.464 | iter time: 102.57 ms
Epoch 1 | iter 211 step 1 | loss train: 2.294, val: 2.464 | iter time: 106.63 ms
Epoch 1 | iter 212 step 1 | loss train: 2.291, val: 2.464 | iter time: 90.20 ms
Epoch 1 | iter 213 step 1 | loss train: 2.281, val: 2.464 | iter time: 111.68 ms
Epoch 1 | iter 214 step 1 | loss train: 2.277, val: 2.464 | iter time: 89.78 ms
Epoch 1 | iter 215 step 1 | loss train: 2.280, val: 2.464 | iter time: 89.49 ms
Epoch 1 | iter 216 step 1 | loss train: 2.272, val: 2.464 | iter time: 102.96 ms
Epoch 1 | iter 217 step 1 | loss train: 2.267, val: 2.464 | iter time: 103.21 ms
Epoch 1 | iter 218 step 1 | loss train: 2.264, val: 2.464 | iter time: 103.32 ms
Epoch 1 | iter 219 step 1 | loss train: 2.258, val: 2.464 | iter time: 88.71 ms
Epoch 1 | iter 220 step 1 | loss train: 2.256, val: 2.464 | iter time: 89.76 ms
Epoch 1 | iter 221 step 1 | loss train: 2.252, val: 2.464 | iter time: 89.83 ms
Epoch 1 | iter 222 step 1 | loss train: 2.248, val: 2.464 | iter time: 89.71 ms
Epoch 1 | iter 223 step 1 | loss train: 2.239, val: 2.464 | iter time: 102.95 ms
Epoch 1 | iter 224 step 1 | loss train: 2.246, val: 2.464 | iter time: 103.77 ms
Epoch 1 | iter 225 step 1 | loss train: 2.247, val: 2.464 | iter time: 89.97 ms
Epoch 1 | iter 226 step 1 | loss train: 2.247, val: 2.464 | iter time: 88.57 ms
Epoch 1 | iter 227 step 1 | loss train: 2.243, val: 2.464 | iter time: 90.63 ms
Epoch 1 | iter 228 step 1 | loss train: 2.247, val: 2.464 | iter time: 88.83 ms
Epoch 1 | iter 229 step 1 | loss train: 2.247, val: 2.464 | iter time: 90.66 ms
Epoch 1 | iter 230 step 1 | loss train: 2.241, val: 2.464 | iter time: 90.47 ms
Epoch 1 | iter 231 step 1 | loss train: 2.248, val: 2.464 | iter time: 89.11 ms
Epoch 1 | iter 232 step 1 | loss train: 2.248, val: 2.464 | iter time: 89.39 ms
Epoch 1 | iter 233 step 1 | loss train: 2.250, val: 2.464 | iter time: 89.21 ms
Epoch 1 | iter 234 step 1 | loss train: 2.244, val: 2.464 | iter time: 91.94 ms
Epoch 1 | iter 235 step 1 | loss train: 2.252, val: 2.464 | iter time: 88.95 ms
Epoch 1 | iter 236 step 1 | loss train: 2.260, val: 2.464 | iter time: 89.50 ms
Epoch 1 | iter 237 step 1 | loss train: 2.257, val: 2.464 | iter time: 89.82 ms
Epoch 1 | iter 238 step 1 | loss train: 2.253, val: 2.464 | iter time: 90.37 ms
Epoch 1 | iter 239 step 1 | loss train: 2.254, val: 2.464 | iter time: 89.39 ms
Epoch 1 | iter 240 step 1 | loss train: 2.257, val: 2.464 | iter time: 106.86 ms
Epoch 1 | iter 241 step 1 | loss train: 2.265, val: 2.464 | iter time: 89.31 ms
Epoch 1 | iter 242 step 1 | loss train: 2.261, val: 2.464 | iter time: 88.49 ms
Epoch 1 | iter 243 step 1 | loss train: 2.260, val: 2.464 | iter time: 89.67 ms
Epoch 1 | iter 244 step 1 | loss train: 2.265, val: 2.464 | iter time: 88.75 ms
Epoch 1 | iter 245 step 1 | loss train: 2.270, val: 2.464 | iter time: 103.20 ms
Epoch 1 | iter 246 step 1 | loss train: 2.273, val: 2.464 | iter time: 101.92 ms
Epoch 1 | iter 247 step 1 | loss train: 2.270, val: 2.464 | iter time: 91.03 ms
Epoch 1 | iter 248 step 1 | loss train: 2.265, val: 2.464 | iter time: 88.42 ms
Epoch 1 | iter 249 step 1 | loss train: 2.264, val: 2.464 | iter time: 103.63 ms
Epoch 1 | iter 250 step 1 | loss train: 2.266, val: 2.464 | iter time: 89.41 ms
Epoch 1 | iter 251 step 1 | loss train: 2.273, val: 2.464 | iter time: 88.58 ms
Epoch 1 | iter 252 step 1 | loss train: 2.262, val: 2.464 | iter time: 90.99 ms
Epoch 1 | iter 253 step 1 | loss train: 2.258, val: 2.464 | iter time: 91.08 ms
Epoch 1 | iter 254 step 1 | loss train: 2.257, val: 2.464 | iter time: 89.35 ms
Epoch 1 | iter 255 step 1 | loss train: 2.257, val: 2.464 | iter time: 89.33 ms
Epoch 1 | iter 256 step 2 | loss train: 2.256, val: 2.464 | iter time: 92.05 ms (step)
Final validation ...
Final train loss: 2.181 | final val loss: 2.197
Training time: 58.18s
Memory used: 7.85 GB
Saving LoRA weights to 'out/finetune/lora/final/lit_model.pth.lora'
LoRA weights have already been merged in this checkpoint.

Note the new part at the bottom:

Final validation ...
Final train loss: 2.181 | final val loss: 2.197

Fixes #1221

rasbt · Apr 01 '24 17:04

I greatly simplified this, @carmocca. Let me know if that's OK now. It shouldn't introduce any changes to the defaults.

rasbt · Apr 25 '24 21:04

Is it really necessary to add an argument for this? Why not just use the initial sanity check as the initial validation loss?

I originally had that, but others found it misleading because it's calculated on only 2 batches and isn't reliable. I then replaced it with a full validation pass, but that was too slow. The new argument is a compromise: I can enable it for myself without slowing things down for everyone else.

rasbt · Apr 26 '24 11:04
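For illustration, one way such a compromise could look is an opt-in flag on the eval arguments. The initial_validation field name and the surrounding structure below are assumptions for this sketch, not necessarily litgpt's actual implementation; validate is the helper from the sketch above:

from dataclasses import dataclass

@dataclass
class EvalArgs:
    interval: int = 100
    max_iters: int = 100
    # Assumed flag name; off by default, so nobody's runs get slower.
    initial_validation: bool = False

def initial_val_loss(model, val_data, eval_args, sanity_iters=2):
    if eval_args.initial_validation:
        # Opt-in: a full pass over (up to max_iters of) the validation
        # set, accurate but slow.
        return validate(model, val_data, eval_args.max_iters)
    # Default: reuse the cheap 2-batch sanity check as a rough estimate.
    return validate(model, val_data, sanity_iters)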