
Unable to `finetune/lora.py` with `DDP`

Open v-dicicco opened this issue 2 years ago • 3 comments

Hello, I'm trying different distributed training strategies by changing Fabric's strategy argument to different values (as listed here).

As a sanity check, I'm verifying that training can overfit a small model (TinyLlama) in a really toy setting: a dataset of just 20 Alpaca samples, batch_size=6, micro_batch_size=2, and two devices (GPUs).
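To be concrete about what "changing Fabric's strategy argument" means here, a minimal sketch of the pattern (the precision value is illustrative, not what the script uses):

```python
import lightning as L

# swap this string between runs: "fsdp" is the script's default path, "ddp" is the failing case
strategy = "ddp"

fabric = L.Fabric(devices=2, strategy=strategy, precision="bf16-true")
fabric.launch()
```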

Using the default FSDP strategy, you can clearly see that the model is training, quickly reaching small loss values:

full.py with FSDP (default): train loss 0.0145 after 24 steps
```
iter 1 step 0: loss 3.5551, iter time: 832.90ms
iter 2 step 1: loss 3.8844, iter time: 495.84ms (optimizer.step)
iter 3 step 1: loss 3.7451, iter time: 343.40ms
iter 4 step 2: loss 3.3320, iter time: 592.14ms (optimizer.step)
iter 5 step 2: loss 1.1711, iter time: 286.96ms
iter 6 step 3: loss 1.4044, iter time: 450.69ms (optimizer.step)
iter 7 step 3: loss 0.4987, iter time: 282.64ms
iter 8 step 4: loss 0.3918, iter time: 454.12ms (optimizer.step)
iter 9 step 4: loss 0.6256, iter time: 282.26ms
iter 10 step 5: loss 0.5746, iter time: 454.61ms (optimizer.step)
iter 11 step 5: loss 0.3998, iter time: 286.11ms
iter 12 step 6: loss 0.5012, iter time: 459.76ms (optimizer.step)
iter 13 step 6: loss 0.1648, iter time: 283.92ms
iter 14 step 7: loss 0.1922, iter time: 454.73ms (optimizer.step)
iter 15 step 7: loss 0.1334, iter time: 272.63ms
iter 16 step 8: loss 0.1583, iter time: 449.95ms (optimizer.step)
iter 17 step 8: loss 0.1003, iter time: 285.66ms
iter 18 step 9: loss 0.1016, iter time: 452.09ms (optimizer.step)
iter 19 step 9: loss 0.0799, iter time: 292.25ms
iter 20 step 10: loss 0.0884, iter time: 443.14ms (optimizer.step)
iter 21 step 10: loss 0.0555, iter time: 283.50ms
iter 22 step 11: loss 0.0702, iter time: 448.30ms (optimizer.step)
iter 23 step 11: loss 0.0724, iter time: 281.98ms
iter 24 step 12: loss 0.0523, iter time: 438.57ms (optimizer.step)
iter 25 step 12: loss 0.0521, iter time: 286.40ms
iter 26 step 13: loss 0.0392, iter time: 471.96ms (optimizer.step)
iter 27 step 13: loss 0.0480, iter time: 278.65ms
iter 28 step 14: loss 0.0267, iter time: 444.58ms (optimizer.step)
iter 29 step 14: loss 0.0333, iter time: 359.07ms
iter 30 step 15: loss 0.0361, iter time: 456.96ms (optimizer.step)
iter 31 step 15: loss 0.0237, iter time: 281.72ms
iter 32 step 16: loss 0.0222, iter time: 447.11ms (optimizer.step)
iter 33 step 16: loss 0.0276, iter time: 277.73ms
iter 34 step 17: loss 0.0262, iter time: 447.63ms (optimizer.step)
iter 35 step 17: loss 0.0296, iter time: 292.94ms
iter 36 step 18: loss 0.0199, iter time: 551.45ms (optimizer.step)
iter 37 step 18: loss 0.0337, iter time: 279.76ms
iter 38 step 19: loss 0.0186, iter time: 436.00ms (optimizer.step)
iter 39 step 19: loss 0.0167, iter time: 280.70ms
iter 40 step 20: loss 0.0481, iter time: 466.32ms (optimizer.step)
iter 41 step 20: loss 0.0479, iter time: 274.86ms
iter 42 step 21: loss 0.0301, iter time: 450.52ms (optimizer.step)
iter 43 step 21: loss 0.0252, iter time: 258.99ms
iter 44 step 22: loss 0.0200, iter time: 449.49ms (optimizer.step)
iter 45 step 22: loss 0.0323, iter time: 283.17ms
iter 46 step 23: loss 0.0274, iter time: 449.52ms (optimizer.step)
iter 47 step 23: loss 0.0151, iter time: 283.70ms
iter 48 step 24: loss 0.0145, iter time: 451.78ms (optimizer.step)
```
lora.py with FSDP (default): train loss 0.1446 after 75 steps
```
iter 1 step 0: loss 3.5551, iter time: 1069.53ms
iter 2 step 1: loss 3.3320, iter time: 765.52ms (optimizer.step)
iter 3 step 1: loss 3.1906, iter time: 537.32ms
iter 4 step 2: loss 3.8844, iter time: 710.44ms (optimizer.step)
iter 5 step 2: loss 3.5345, iter time: 551.50ms
iter 6 step 3: loss 3.1701, iter time: 710.11ms (optimizer.step)
iter 7 step 3: loss 3.6471, iter time: 516.95ms
iter 8 step 4: loss 3.6926, iter time: 718.70ms (optimizer.step)
iter 9 step 4: loss 3.4486, iter time: 526.94ms
iter 10 step 5: loss 3.0839, iter time: 715.32ms (optimizer.step)
iter 11 step 5: loss 3.3866, iter time: 528.55ms
iter 12 step 6: loss 3.5413, iter time: 717.40ms (optimizer.step)
iter 13 step 6: loss 3.6900, iter time: 546.27ms
iter 14 step 7: loss 3.2861, iter time: 740.52ms (optimizer.step)
iter 15 step 7: loss 3.2394, iter time: 532.05ms
iter 16 step 8: loss 3.2060, iter time: 713.04ms (optimizer.step)
iter 17 step 8: loss 3.3468, iter time: 519.96ms
iter 18 step 9: loss 3.3155, iter time: 703.20ms (optimizer.step)
iter 19 step 9: loss 3.4159, iter time: 529.58ms
iter 20 step 10: loss 3.2537, iter time: 706.44ms (optimizer.step)
iter 21 step 10: loss 2.8014, iter time: 628.57ms
iter 22 step 11: loss 3.3453, iter time: 720.83ms (optimizer.step)
iter 23 step 11: loss 3.0813, iter time: 527.75ms
iter 24 step 12: loss 2.8985, iter time: 810.74ms (optimizer.step)
iter 25 step 12: loss 3.1401, iter time: 543.45ms
iter 26 step 13: loss 2.8464, iter time: 718.41ms (optimizer.step)
iter 27 step 13: loss 2.5838, iter time: 534.55ms
iter 28 step 14: loss 2.7362, iter time: 709.92ms (optimizer.step)
iter 29 step 14: loss 2.5003, iter time: 527.97ms
iter 30 step 15: loss 2.6871, iter time: 707.90ms (optimizer.step)
iter 31 step 15: loss 2.5647, iter time: 530.84ms
iter 32 step 16: loss 2.6898, iter time: 709.09ms (optimizer.step)
iter 33 step 16: loss 2.2476, iter time: 535.74ms
iter 34 step 17: loss 2.5858, iter time: 720.83ms (optimizer.step)
iter 35 step 17: loss 2.6341, iter time: 552.05ms
iter 36 step 18: loss 2.3383, iter time: 695.66ms (optimizer.step)
iter 37 step 18: loss 2.3607, iter time: 543.67ms
iter 38 step 19: loss 2.3153, iter time: 708.93ms (optimizer.step)
iter 39 step 19: loss 2.4616, iter time: 526.31ms
iter 40 step 20: loss 2.4130, iter time: 700.92ms (optimizer.step)
iter 41 step 20: loss 2.0636, iter time: 524.18ms
iter 42 step 21: loss 2.2969, iter time: 808.86ms (optimizer.step)
iter 43 step 21: loss 2.0440, iter time: 532.04ms
iter 44 step 22: loss 2.0349, iter time: 717.65ms (optimizer.step)
iter 45 step 22: loss 2.0741, iter time: 546.68ms
iter 46 step 23: loss 2.1482, iter time: 704.36ms (optimizer.step)
iter 47 step 23: loss 1.8215, iter time: 635.46ms
iter 48 step 24: loss 1.8571, iter time: 704.62ms (optimizer.step)
iter 49 step 24: loss 1.7538, iter time: 524.68ms
iter 50 step 25: loss 1.7118, iter time: 708.20ms (optimizer.step)
iter 51 step 25: loss 1.7031, iter time: 528.88ms
iter 52 step 26: loss 1.5699, iter time: 713.73ms (optimizer.step)
iter 53 step 26: loss 1.5775, iter time: 531.15ms
iter 54 step 27: loss 1.4807, iter time: 703.67ms (optimizer.step)
iter 55 step 27: loss 1.3922, iter time: 539.61ms
iter 56 step 28: loss 1.3600, iter time: 703.70ms (optimizer.step)
iter 57 step 28: loss 1.5089, iter time: 544.63ms
iter 58 step 29: loss 1.4053, iter time: 699.75ms (optimizer.step)
iter 59 step 29: loss 1.2835, iter time: 539.76ms
iter 60 step 30: loss 1.1636, iter time: 698.58ms (optimizer.step)
iter 61 step 30: loss 1.1035, iter time: 516.37ms
iter 62 step 31: loss 1.2111, iter time: 709.70ms (optimizer.step)
iter 63 step 31: loss 1.0795, iter time: 541.04ms
iter 64 step 32: loss 1.1503, iter time: 718.27ms (optimizer.step)
iter 65 step 32: loss 1.0205, iter time: 620.26ms
iter 66 step 33: loss 1.0197, iter time: 702.76ms (optimizer.step)
iter 67 step 33: loss 1.0356, iter time: 531.63ms
iter 68 step 34: loss 1.0187, iter time: 717.89ms (optimizer.step)
iter 69 step 34: loss 0.9217, iter time: 528.73ms
iter 70 step 35: loss 0.9199, iter time: 799.95ms (optimizer.step)
iter 71 step 35: loss 0.8928, iter time: 521.28ms
iter 72 step 36: loss 0.9298, iter time: 713.76ms (optimizer.step)
iter 73 step 36: loss 0.8614, iter time: 539.65ms
iter 74 step 37: loss 0.8321, iter time: 701.16ms (optimizer.step)
iter 75 step 37: loss 0.8147, iter time: 530.63ms
iter 76 step 38: loss 0.8113, iter time: 706.08ms (optimizer.step)
iter 77 step 38: loss 0.7607, iter time: 536.43ms
iter 78 step 39: loss 0.7604, iter time: 706.79ms (optimizer.step)
iter 79 step 39: loss 0.7147, iter time: 513.72ms
iter 80 step 40: loss 0.7155, iter time: 710.81ms (optimizer.step)
iter 81 step 40: loss 0.6791, iter time: 535.22ms
iter 82 step 41: loss 0.6573, iter time: 756.59ms (optimizer.step)
iter 83 step 41: loss 0.6545, iter time: 539.70ms
iter 84 step 42: loss 0.6304, iter time: 789.96ms (optimizer.step)
iter 85 step 42: loss 0.6202, iter time: 543.96ms
iter 86 step 43: loss 0.6079, iter time: 794.64ms (optimizer.step)
iter 87 step 43: loss 0.5303, iter time: 529.90ms
iter 88 step 44: loss 0.5536, iter time: 880.86ms (optimizer.step)
iter 89 step 44: loss 0.5423, iter time: 561.47ms
iter 90 step 45: loss 0.5421, iter time: 793.30ms (optimizer.step)
iter 91 step 45: loss 0.5060, iter time: 555.44ms
iter 92 step 46: loss 0.4453, iter time: 756.31ms (optimizer.step)
iter 93 step 46: loss 0.4920, iter time: 668.88ms
iter 94 step 47: loss 0.5148, iter time: 790.52ms (optimizer.step)
iter 95 step 47: loss 0.3797, iter time: 543.98ms
iter 96 step 48: loss 0.4327, iter time: 752.30ms (optimizer.step)
iter 97 step 48: loss 0.4047, iter time: 560.33ms
iter 98 step 49: loss 0.3276, iter time: 785.42ms (optimizer.step)
iter 99 step 49: loss 0.3923, iter time: 539.13ms
iter 100 step 50: loss 0.3058, iter time: 788.50ms (optimizer.step)
iter 101 step 50: loss 0.3848, iter time: 552.10ms
iter 102 step 51: loss 0.4689, iter time: 823.08ms (optimizer.step)
iter 103 step 51: loss 0.2933, iter time: 561.82ms
iter 104 step 52: loss 0.3563, iter time: 766.66ms (optimizer.step)
iter 105 step 52: loss 0.2653, iter time: 549.85ms
iter 106 step 53: loss 0.2773, iter time: 789.58ms (optimizer.step)
iter 107 step 53: loss 0.2461, iter time: 546.29ms
iter 108 step 54: loss 0.2620, iter time: 743.88ms (optimizer.step)
iter 109 step 54: loss 0.3332, iter time: 555.79ms
iter 110 step 55: loss 0.3736, iter time: 786.37ms (optimizer.step)
iter 111 step 55: loss 0.2499, iter time: 630.28ms
iter 112 step 56: loss 0.3041, iter time: 778.23ms (optimizer.step)
iter 113 step 56: loss 0.2778, iter time: 556.51ms
iter 114 step 57: loss 0.3119, iter time: 753.45ms (optimizer.step)
iter 115 step 57: loss 0.2798, iter time: 549.87ms
iter 116 step 58: loss 0.2519, iter time: 859.51ms (optimizer.step)
iter 117 step 58: loss 0.1992, iter time: 559.52ms
iter 118 step 59: loss 0.2476, iter time: 767.54ms (optimizer.step)
iter 119 step 59: loss 0.2308, iter time: 534.35ms
iter 120 step 60: loss 0.1909, iter time: 733.94ms (optimizer.step)
iter 121 step 60: loss 0.1964, iter time: 539.13ms
iter 122 step 61: loss 0.2217, iter time: 766.30ms (optimizer.step)
iter 123 step 61: loss 0.2216, iter time: 565.01ms
iter 124 step 62: loss 0.2126, iter time: 766.05ms (optimizer.step)
iter 125 step 62: loss 0.1686, iter time: 546.63ms
iter 126 step 63: loss 0.1686, iter time: 817.32ms (optimizer.step)
iter 127 step 63: loss 0.2076, iter time: 542.31ms
iter 128 step 64: loss 0.1669, iter time: 790.53ms (optimizer.step)
iter 129 step 64: loss 0.1638, iter time: 575.63ms
iter 130 step 65: loss 0.1885, iter time: 757.39ms (optimizer.step)
iter 131 step 65: loss 0.1843, iter time: 543.61ms
iter 132 step 66: loss 0.1722, iter time: 769.43ms (optimizer.step)
iter 133 step 66: loss 0.1902, iter time: 553.14ms
iter 134 step 67: loss 0.1478, iter time: 817.65ms (optimizer.step)
iter 135 step 67: loss 0.1494, iter time: 546.42ms
iter 136 step 68: loss 0.1705, iter time: 758.27ms (optimizer.step)
iter 137 step 68: loss 0.1625, iter time: 567.57ms
iter 138 step 69: loss 0.1420, iter time: 787.70ms (optimizer.step)
iter 139 step 69: loss 0.1636, iter time: 678.68ms
iter 140 step 70: loss 0.1320, iter time: 794.51ms (optimizer.step)
iter 141 step 70: loss 0.1359, iter time: 535.61ms
iter 142 step 71: loss 0.1478, iter time: 802.98ms (optimizer.step)
iter 143 step 71: loss 0.1515, iter time: 529.23ms
iter 144 step 72: loss 0.1419, iter time: 757.31ms (optimizer.step)
iter 145 step 72: loss 0.1396, iter time: 548.87ms
iter 146 step 73: loss 0.1358, iter time: 769.09ms (optimizer.step)
iter 147 step 73: loss 0.1397, iter time: 531.65ms
iter 148 step 74: loss 0.1359, iter time: 755.25ms (optimizer.step)
iter 149 step 74: loss 0.1334, iter time: 541.91ms
iter 150 step 75: loss 0.1446, iter time: 806.27ms (optimizer.step)
```

However, forcing strategy='ddp' with the same setting seems to work only for the full trainer (note how, with LoRA, the loss seems to be stuck around 3.6). I've tried playing with the hyperparameters without luck:

full.py with DDP: train loss 0.0522 after 24 steps
```
iter 1 step 0: loss 3.5551, iter time: 567.18ms
iter 2 step 1: loss 3.8844, iter time: 796.59ms (optimizer.step)
iter 3 step 1: loss 3.7451, iter time: 76.60ms
iter 4 step 2: loss 3.3320, iter time: 469.35ms (optimizer.step)
iter 5 step 2: loss 1.1714, iter time: 73.10ms
iter 6 step 3: loss 1.4070, iter time: 442.38ms (optimizer.step)
iter 7 step 3: loss 0.4997, iter time: 70.98ms
iter 8 step 4: loss 0.3922, iter time: 446.22ms (optimizer.step)
iter 9 step 4: loss 0.6224, iter time: 69.96ms
iter 10 step 5: loss 0.5729, iter time: 446.39ms (optimizer.step)
iter 11 step 5: loss 0.4016, iter time: 70.05ms
iter 12 step 6: loss 0.5031, iter time: 436.62ms (optimizer.step)
iter 13 step 6: loss 0.1636, iter time: 72.35ms
iter 14 step 7: loss 0.1915, iter time: 440.08ms (optimizer.step)
iter 15 step 7: loss 0.1349, iter time: 73.24ms
iter 16 step 8: loss 0.1595, iter time: 444.82ms (optimizer.step)
iter 17 step 8: loss 0.0983, iter time: 71.23ms
iter 18 step 9: loss 0.1043, iter time: 444.88ms (optimizer.step)
iter 19 step 9: loss 0.0808, iter time: 69.35ms
iter 20 step 10: loss 0.0890, iter time: 442.52ms (optimizer.step)
iter 21 step 10: loss 0.0552, iter time: 71.58ms
iter 22 step 11: loss 0.0709, iter time: 443.41ms (optimizer.step)
iter 23 step 11: loss 0.0729, iter time: 71.70ms
iter 24 step 12: loss 0.0522, iter time: 442.91ms (optimizer.step)
```
lora.py with DDP: train loss 3.7451 after 100 steps
```
iter 1 step 0: loss 3.5551, iter time: 167.67ms
iter 2 step 1: loss 3.3320, iter time: 5421.28ms (optimizer.step)
iter 3 step 1: loss 3.1906, iter time: 94.70ms
iter 4 step 2: loss 3.8844, iter time: 244.89ms (optimizer.step)
iter 5 step 2: loss 3.5551, iter time: 94.09ms
iter 6 step 3: loss 3.1906, iter time: 236.35ms (optimizer.step)
iter 7 step 3: loss 3.7000, iter time: 91.63ms
iter 8 step 4: loss 3.7451, iter time: 207.49ms (optimizer.step)
iter 9 step 4: loss 3.5551, iter time: 93.52ms
iter 10 step 5: loss 3.1906, iter time: 214.52ms (optimizer.step)
iter 11 step 5: loss 3.5551, iter time: 97.39ms
iter 12 step 6: loss 3.7000, iter time: 205.80ms (optimizer.step)
iter 13 step 6: loss 3.9202, iter time: 93.07ms
iter 14 step 7: loss 3.5086, iter time: 201.63ms (optimizer.step)
iter 15 step 7: loss 3.5551, iter time: 99.50ms
iter 16 step 8: loss 3.5086, iter time: 201.73ms (optimizer.step)
iter 17 step 8: loss 3.7451, iter time: 94.45ms
iter 18 step 9: loss 3.7000, iter time: 209.46ms (optimizer.step)
iter 19 step 9: loss 3.8844, iter time: 91.71ms
iter 20 step 10: loss 3.7451, iter time: 207.51ms (optimizer.step)
iter 21 step 10: loss 3.3320, iter time: 93.41ms
iter 22 step 11: loss 3.9202, iter time: 203.45ms (optimizer.step)
iter 23 step 11: loss 3.7451, iter time: 91.74ms
iter 24 step 12: loss 3.5073, iter time: 207.43ms (optimizer.step)
iter 25 step 12: loss 3.8844, iter time: 90.37ms
iter 26 step 13: loss 3.5551, iter time: 208.81ms (optimizer.step)
iter 27 step 13: loss 3.3320, iter time: 96.21ms
iter 28 step 14: loss 3.5073, iter time: 198.56ms (optimizer.step)
iter 29 step 14: loss 3.3320, iter time: 98.13ms
iter 30 step 15: loss 3.5551, iter time: 200.53ms (optimizer.step)
iter 31 step 15: loss 3.5073, iter time: 97.73ms
iter 32 step 16: loss 3.7000, iter time: 199.27ms (optimizer.step)
iter 33 step 16: loss 3.1906, iter time: 93.94ms
iter 34 step 17: loss 3.7000, iter time: 205.37ms (optimizer.step)
iter 35 step 17: loss 3.9202, iter time: 94.18ms
iter 36 step 18: loss 3.5086, iter time: 203.05ms (optimizer.step)
iter 37 step 18: loss 3.7000, iter time: 90.91ms
iter 38 step 19: loss 3.5551, iter time: 210.29ms (optimizer.step)
iter 39 step 19: loss 3.9545, iter time: 93.97ms
iter 40 step 20: loss 3.9202, iter time: 200.84ms (optimizer.step)
iter 41 step 20: loss 3.5073, iter time: 93.15ms
iter 42 step 21: loss 3.9202, iter time: 208.65ms (optimizer.step)
iter 43 step 21: loss 3.7000, iter time: 92.60ms
iter 44 step 22: loss 3.5551, iter time: 202.50ms (optimizer.step)
iter 45 step 22: loss 3.9202, iter time: 109.24ms
iter 46 step 23: loss 3.9545, iter time: 189.96ms (optimizer.step)
iter 47 step 23: loss 3.7000, iter time: 108.17ms
iter 48 step 24: loss 3.5551, iter time: 190.84ms (optimizer.step)
iter 49 step 24: loss 3.8844, iter time: 108.57ms
iter 50 step 25: loss 3.7000, iter time: 195.16ms (optimizer.step)
iter 51 step 25: loss 3.7451, iter time: 108.96ms
iter 52 step 26: loss 3.5073, iter time: 189.34ms (optimizer.step)
iter 53 step 26: loss 3.5551, iter time: 108.77ms
iter 54 step 27: loss 3.7000, iter time: 190.42ms (optimizer.step)
iter 55 step 27: loss 3.8844, iter time: 108.43ms
iter 56 step 28: loss 3.5073, iter time: 192.47ms (optimizer.step)
iter 57 step 28: loss 3.9545, iter time: 105.81ms
iter 58 step 29: loss 3.9202, iter time: 195.24ms (optimizer.step)
iter 59 step 29: loss 3.5551, iter time: 109.79ms
iter 60 step 30: loss 3.3320, iter time: 191.49ms (optimizer.step)
iter 61 step 30: loss 3.3320, iter time: 108.14ms
iter 62 step 31: loss 3.5551, iter time: 188.55ms (optimizer.step)
iter 63 step 31: loss 3.1906, iter time: 109.10ms
iter 64 step 32: loss 3.7451, iter time: 189.91ms (optimizer.step)
iter 65 step 32: loss 3.5073, iter time: 107.88ms
iter 66 step 33: loss 3.5086, iter time: 188.76ms (optimizer.step)
iter 67 step 33: loss 3.7451, iter time: 107.70ms
iter 68 step 34: loss 3.9202, iter time: 188.97ms (optimizer.step)
iter 69 step 34: loss 3.7000, iter time: 108.00ms
iter 70 step 35: loss 3.7000, iter time: 190.73ms (optimizer.step)
iter 71 step 35: loss 3.5086, iter time: 108.37ms
iter 72 step 36: loss 3.5551, iter time: 188.15ms (optimizer.step)
iter 73 step 36: loss 3.9202, iter time: 108.60ms
iter 74 step 37: loss 3.7000, iter time: 190.43ms (optimizer.step)
iter 75 step 37: loss 3.9202, iter time: 108.32ms
iter 76 step 38: loss 3.9202, iter time: 195.27ms (optimizer.step)
iter 77 step 38: loss 3.7451, iter time: 105.94ms
iter 78 step 39: loss 3.3320, iter time: 193.23ms (optimizer.step)
iter 79 step 39: loss 3.3320, iter time: 105.46ms
iter 80 step 40: loss 3.3320, iter time: 195.81ms (optimizer.step)
iter 81 step 40: loss 3.3320, iter time: 109.03ms
iter 82 step 41: loss 3.5086, iter time: 189.86ms (optimizer.step)
iter 83 step 41: loss 3.5551, iter time: 108.57ms
iter 84 step 42: loss 3.3320, iter time: 188.29ms (optimizer.step)
iter 85 step 42: loss 3.5551, iter time: 108.68ms
iter 86 step 43: loss 3.9202, iter time: 192.64ms (optimizer.step)
iter 87 step 43: loss 3.7000, iter time: 108.35ms
iter 88 step 44: loss 3.3320, iter time: 197.48ms (optimizer.step)
iter 89 step 44: loss 3.8844, iter time: 107.49ms
iter 90 step 45: loss 3.5551, iter time: 182.41ms (optimizer.step)
iter 91 step 45: loss 3.5551, iter time: 108.68ms
iter 92 step 46: loss 3.7451, iter time: 197.15ms (optimizer.step)
iter 93 step 46: loss 3.5073, iter time: 109.28ms
iter 94 step 47: loss 3.1906, iter time: 185.06ms (optimizer.step)
iter 95 step 47: loss 3.7451, iter time: 105.93ms
iter 96 step 48: loss 3.5551, iter time: 188.47ms (optimizer.step)
iter 97 step 48: loss 3.8844, iter time: 107.85ms
iter 98 step 49: loss 3.9545, iter time: 193.28ms (optimizer.step)
iter 99 step 49: loss 3.5551, iter time: 106.31ms
iter 100 step 50: loss 3.7451, iter time: 187.97ms (optimizer.step)
iter 101 step 50: loss 3.5551, iter time: 108.78ms
iter 102 step 51: loss 3.1906, iter time: 196.97ms (optimizer.step)
iter 103 step 51: loss 3.5086, iter time: 108.74ms
iter 104 step 52: loss 3.5551, iter time: 190.11ms (optimizer.step)
iter 105 step 52: loss 3.7451, iter time: 107.93ms
iter 106 step 53: loss 3.7000, iter time: 190.80ms (optimizer.step)
iter 107 step 53: loss 3.5086, iter time: 106.93ms
iter 108 step 54: loss 3.7000, iter time: 189.49ms (optimizer.step)
iter 109 step 54: loss 3.5073, iter time: 107.82ms
iter 110 step 55: loss 3.1906, iter time: 188.76ms (optimizer.step)
iter 111 step 55: loss 3.9202, iter time: 108.66ms
iter 112 step 56: loss 3.5073, iter time: 190.35ms (optimizer.step)
iter 113 step 56: loss 3.5073, iter time: 109.09ms
iter 114 step 57: loss 3.1906, iter time: 194.27ms (optimizer.step)
iter 115 step 57: loss 3.3320, iter time: 105.29ms
iter 116 step 58: loss 3.5551, iter time: 193.18ms (optimizer.step)
iter 117 step 58: loss 3.9202, iter time: 106.85ms
iter 118 step 59: loss 3.5073, iter time: 194.52ms (optimizer.step)
iter 119 step 59: loss 3.5073, iter time: 109.06ms
iter 120 step 60: loss 3.8844, iter time: 189.87ms (optimizer.step)
iter 121 step 60: loss 3.7451, iter time: 107.30ms
iter 122 step 61: loss 3.5551, iter time: 189.06ms (optimizer.step)
iter 123 step 61: loss 3.3320, iter time: 108.87ms
iter 124 step 62: loss 3.5551, iter time: 187.82ms (optimizer.step)
iter 125 step 62: loss 3.7451, iter time: 107.67ms
iter 126 step 63: loss 3.7451, iter time: 193.35ms (optimizer.step)
iter 127 step 63: loss 3.5551, iter time: 107.17ms
iter 128 step 64: loss 3.7000, iter time: 187.17ms (optimizer.step)
iter 129 step 64: loss 3.7000, iter time: 108.06ms
iter 130 step 65: loss 3.5551, iter time: 191.12ms (optimizer.step)
iter 131 step 65: loss 3.3320, iter time: 108.52ms
iter 132 step 66: loss 3.5551, iter time: 196.97ms (optimizer.step)
iter 133 step 66: loss 3.1906, iter time: 108.26ms
iter 134 step 67: loss 3.7451, iter time: 190.88ms (optimizer.step)
iter 135 step 67: loss 3.7451, iter time: 107.66ms
iter 136 step 68: loss 3.8844, iter time: 191.31ms (optimizer.step)
iter 137 step 68: loss 3.5073, iter time: 107.27ms
iter 138 step 69: loss 3.5551, iter time: 189.66ms (optimizer.step)
iter 139 step 69: loss 3.8844, iter time: 108.03ms
iter 140 step 70: loss 3.9545, iter time: 193.45ms (optimizer.step)
iter 141 step 70: loss 3.9545, iter time: 106.10ms
iter 142 step 71: loss 3.7451, iter time: 195.23ms (optimizer.step)
iter 143 step 71: loss 3.5086, iter time: 107.01ms
iter 144 step 72: loss 3.9202, iter time: 194.32ms (optimizer.step)
iter 145 step 72: loss 3.7000, iter time: 108.96ms
iter 146 step 73: loss 3.3320, iter time: 190.23ms (optimizer.step)
iter 147 step 73: loss 3.5551, iter time: 107.08ms
iter 148 step 74: loss 3.8844, iter time: 191.91ms (optimizer.step)
iter 149 step 74: loss 3.3320, iter time: 106.69ms
iter 150 step 75: loss 3.5551, iter time: 187.99ms (optimizer.step)
iter 151 step 75: loss 3.5086, iter time: 107.73ms
iter 152 step 76: loss 3.9545, iter time: 189.14ms (optimizer.step)
iter 153 step 76: loss 3.1906, iter time: 107.78ms
iter 154 step 77: loss 3.7451, iter time: 189.06ms (optimizer.step)
iter 155 step 77: loss 3.8844, iter time: 108.37ms
iter 156 step 78: loss 3.5073, iter time: 190.49ms (optimizer.step)
iter 157 step 78: loss 3.9202, iter time: 107.38ms
iter 158 step 79: loss 3.5073, iter time: 194.00ms (optimizer.step)
iter 159 step 79: loss 3.5551, iter time: 106.31ms
iter 160 step 80: loss 3.8844, iter time: 192.46ms (optimizer.step)
iter 161 step 80: loss 3.7451, iter time: 106.95ms
iter 162 step 81: loss 3.9202, iter time: 194.22ms (optimizer.step)
iter 163 step 81: loss 3.7451, iter time: 108.14ms
iter 164 step 82: loss 3.9545, iter time: 190.68ms (optimizer.step)
iter 165 step 82: loss 3.7451, iter time: 107.95ms
iter 166 step 83: loss 3.7451, iter time: 190.76ms (optimizer.step)
iter 167 step 83: loss 3.7451, iter time: 107.49ms
iter 168 step 84: loss 3.9545, iter time: 197.79ms (optimizer.step)
iter 169 step 84: loss 3.3320, iter time: 108.76ms
iter 170 step 85: loss 3.7451, iter time: 189.47ms (optimizer.step)
iter 171 step 85: loss 3.7000, iter time: 107.84ms
iter 172 step 86: loss 3.5073, iter time: 189.00ms (optimizer.step)
iter 173 step 86: loss 3.9545, iter time: 108.78ms
iter 174 step 87: loss 3.5086, iter time: 190.17ms (optimizer.step)
iter 175 step 87: loss 3.8844, iter time: 107.21ms
iter 176 step 88: loss 3.1906, iter time: 196.45ms (optimizer.step)
iter 177 step 88: loss 3.3320, iter time: 107.93ms
iter 178 step 89: loss 3.3320, iter time: 191.34ms (optimizer.step)
iter 179 step 89: loss 3.5551, iter time: 108.91ms
iter 180 step 90: loss 3.5073, iter time: 192.33ms (optimizer.step)
iter 181 step 90: loss 3.5073, iter time: 107.72ms
iter 182 step 91: loss 3.7000, iter time: 191.23ms (optimizer.step)
iter 183 step 91: loss 3.5073, iter time: 107.90ms
iter 184 step 92: loss 3.7000, iter time: 189.10ms (optimizer.step)
iter 185 step 92: loss 3.7451, iter time: 108.30ms
iter 186 step 93: loss 3.5073, iter time: 190.84ms (optimizer.step)
iter 187 step 93: loss 3.7451, iter time: 108.35ms
iter 188 step 94: loss 3.5086, iter time: 192.90ms (optimizer.step)
iter 189 step 94: loss 3.7451, iter time: 109.55ms
iter 190 step 95: loss 3.5073, iter time: 184.05ms (optimizer.step)
iter 191 step 95: loss 3.5086, iter time: 108.47ms
iter 192 step 96: loss 3.5551, iter time: 190.71ms (optimizer.step)
iter 193 step 96: loss 3.9202, iter time: 107.94ms
iter 194 step 97: loss 3.7000, iter time: 186.67ms (optimizer.step)
iter 195 step 97: loss 3.7451, iter time: 107.74ms
iter 196 step 98: loss 3.7000, iter time: 193.63ms (optimizer.step)
iter 197 step 98: loss 3.5073, iter time: 109.04ms
iter 198 step 99: loss 3.7451, iter time: 189.90ms (optimizer.step)
iter 199 step 99: loss 3.3320, iter time: 108.26ms
iter 200 step 100: loss 3.7451, iter time: 188.62ms (optimizer.step)
```

Is this expected behaviour (maybe related to the LoRA implementation)?
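For anyone reading along: LoRA fine-tuning freezes the base weights and trains only the adapter matrices (in lit-gpt this is done by `mark_only_lora_as_trainable(model)`), and PyTorch's DistributedDataParallel builds its gradient reducer from the set of parameters with requires_grad=True at wrap time, so changing requires_grad after wrapping is a classic source of silently broken gradient sync. A minimal sanity check, assuming a `model` set up as in finetune/lora.py (this snippet is a sketch, not part of the original report):

```python
# List what is actually trainable after the model has been set up for
# distributed training; every rank should report the same LoRA tensors.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable tensors; first few: {trainable[:3]}")
```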

v-dicicco avatar Dec 18 '23 17:12 v-dicicco

Can you share a branch from a fork or a diff with all the changes that you've made to the repository and the commands to repro your results?

carmocca avatar Dec 18 '23 19:12 carmocca

Sure, here are the steps to reproduce on a fresh lit-gpt clone:

# download & convert model
python scripts/download.py --repo_id TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T
python scripts/convert_hf_checkpoint.py --checkpoint checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/
# prepare (tiny) train dataset
python scripts/prepare_alpaca.py --checkpoint_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/ --test_split_fraction 0.0003865
cp data/alpaca/test.pt data/alpaca/train.pt
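To double-check that the copy produced the intended tiny training set, a quick sketch (this assumes prepare_alpaca.py saved the samples as a pickled list, which is how lit-gpt stores them):

```python
import torch

samples = torch.load("data/alpaca/train.pt")
print(len(samples))  # expected: the ~20 samples carved out by --test_split_fraction above
```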

Then force DDP in finetune/lora.py, use 2 devices (note: I'm using 2 A40s), and decrease batch_size (just to increase the frequency of optimizer updates given the small dataset; I've tried different values and observed the same behaviour):

```diff
diff --git a/finetune/lora.py b/finetune/lora.py
index 852cf01..6255c19 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -32,11 +32,11 @@ save_interval = 100
 eval_iters = 100
 eval_max_new_tokens = 100
 log_interval = 1
-devices = 1
+devices = 2
 
 # Hyperparameters
 learning_rate = 3e-4
-batch_size = 128
+batch_size = 8
 micro_batch_size = 4
 gradient_accumulation_iters = batch_size // micro_batch_size
 assert gradient_accumulation_iters > 0
@@ -90,6 +90,7 @@ def setup(
     else:
         strategy = "auto"
 
+    strategy="ddp"
     logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
     fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
     fabric.print(hparams)
```
```bash
# run LoRA training
CUDA_VISIBLE_DEVICES=1,2 python finetune/lora.py --checkpoint_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/
```
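For reference, with the values in the diff above the accumulation arithmetic works out as follows (a worked sketch of what the script computes; the two-device scaling is my own note):

```python
devices = 2
batch_size = 8
micro_batch_size = 4

# what finetune/lora.py computes
gradient_accumulation_iters = batch_size // micro_batch_size  # 2
assert gradient_accumulation_iters > 0

# so each optimizer.step() sees 8 samples per device, 16 across both DDP ranks
print(devices * micro_batch_size * gradient_accumulation_iters)  # 16
```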

Observe that the loss does not improve.
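As a possible next probe (hypothetical; I haven't run this), one could check whether the LoRA tensors actually receive synchronized gradients under DDP by dropping something like this into the training loop right after `fabric.backward(loss)`:

```python
# Every rank should print the same non-zero gradient norm for a given LoRA
# tensor if DDP is averaging gradients correctly across processes.
for name, p in model.named_parameters():
    if p.requires_grad and p.grad is not None and "lora" in name:
        print(f"rank {fabric.global_rank}: {name} grad norm = {p.grad.norm().item():.6f}")
        break
```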

On the other hand, if you do:

```diff
diff --git a/finetune/full.py b/finetune/full.py
index b88d19a..35d1b6e 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -31,12 +31,12 @@ save_interval = 1000
 eval_iters = 100
 eval_max_new_tokens = 100
 log_interval = 1
-devices = 1
+devices = 2
 
 # Hyperparameters
 learning_rate = 3e-3
-batch_size = 64 / devices
-micro_batch_size = 1
+batch_size = 8
+micro_batch_size = 4
 gradient_accumulation_iters = batch_size // micro_batch_size
 assert gradient_accumulation_iters > 0
 max_seq_length = None  # assign value to truncate
@@ -69,6 +69,7 @@ def setup(
     else:
         strategy = "auto"
 
+    strategy="ddp"
     logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
     fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
     fabric.print(hparams)
```
```bash
# run full training
CUDA_VISIBLE_DEVICES=1,2 python finetune/full.py --checkpoint_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/
```

The model will overfit as expected.
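In case someone wants to experiment further: instead of the bare strategy="ddp" string, Fabric also accepts an explicit strategy object whose extra kwargs are forwarded to torch's DistributedDataParallel wrapper, so options like find_unused_parameters can be tried. Whether any of them helps here is untested on my side; a sketch using the variables from the script:

```python
from lightning.fabric.strategies import DDPStrategy

# extra kwargs are passed through to DistributedDataParallel
strategy = DDPStrategy(find_unused_parameters=True)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger)
```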

v-dicicco avatar Dec 18 '23 20:12 v-dicicco

Hey @v-dicicco and @carmocca, do you have any update on how/if you were able to solve this? I have the same problem: the DDP strategy gets worse performance compared to training on a single GPU.

Srijith-rkr avatar Mar 06 '24 14:03 Srijith-rkr