Unable to run `finetune/lora.py` with `DDP`
Hello, I'm trying different distributed training strategies by changing Fabric's strategy argument to different values (as listed here).
As a sanity check, I'm verifying that training can overfit a small model (TinyLlama) in a toy setting: a dataset of just 20 Alpaca samples, batch_size=6, micro_batch_size=2 and two devices (GPUs).
With the default FSDP strategy, the model is clearly training and quickly reaches small loss values:
full.py with FSDP (default): train loss 0.0145 after 24 steps
iter 1 step 0: loss 3.5551, iter time: 832.90ms
iter 2 step 1: loss 3.8844, iter time: 495.84ms (optimizer.step)
iter 3 step 1: loss 3.7451, iter time: 343.40ms
iter 4 step 2: loss 3.3320, iter time: 592.14ms (optimizer.step)
iter 5 step 2: loss 1.1711, iter time: 286.96ms
iter 6 step 3: loss 1.4044, iter time: 450.69ms (optimizer.step)
iter 7 step 3: loss 0.4987, iter time: 282.64ms
iter 8 step 4: loss 0.3918, iter time: 454.12ms (optimizer.step)
iter 9 step 4: loss 0.6256, iter time: 282.26ms
iter 10 step 5: loss 0.5746, iter time: 454.61ms (optimizer.step)
iter 11 step 5: loss 0.3998, iter time: 286.11ms
iter 12 step 6: loss 0.5012, iter time: 459.76ms (optimizer.step)
iter 13 step 6: loss 0.1648, iter time: 283.92ms
iter 14 step 7: loss 0.1922, iter time: 454.73ms (optimizer.step)
iter 15 step 7: loss 0.1334, iter time: 272.63ms
iter 16 step 8: loss 0.1583, iter time: 449.95ms (optimizer.step)
iter 17 step 8: loss 0.1003, iter time: 285.66ms
iter 18 step 9: loss 0.1016, iter time: 452.09ms (optimizer.step)
iter 19 step 9: loss 0.0799, iter time: 292.25ms
iter 20 step 10: loss 0.0884, iter time: 443.14ms (optimizer.step)
iter 21 step 10: loss 0.0555, iter time: 283.50ms
iter 22 step 11: loss 0.0702, iter time: 448.30ms (optimizer.step)
iter 23 step 11: loss 0.0724, iter time: 281.98ms
iter 24 step 12: loss 0.0523, iter time: 438.57ms (optimizer.step)
iter 25 step 12: loss 0.0521, iter time: 286.40ms
iter 26 step 13: loss 0.0392, iter time: 471.96ms (optimizer.step)
iter 27 step 13: loss 0.0480, iter time: 278.65ms
iter 28 step 14: loss 0.0267, iter time: 444.58ms (optimizer.step)
iter 29 step 14: loss 0.0333, iter time: 359.07ms
iter 30 step 15: loss 0.0361, iter time: 456.96ms (optimizer.step)
iter 31 step 15: loss 0.0237, iter time: 281.72ms
iter 32 step 16: loss 0.0222, iter time: 447.11ms (optimizer.step)
iter 33 step 16: loss 0.0276, iter time: 277.73ms
iter 34 step 17: loss 0.0262, iter time: 447.63ms (optimizer.step)
iter 35 step 17: loss 0.0296, iter time: 292.94ms
iter 36 step 18: loss 0.0199, iter time: 551.45ms (optimizer.step)
iter 37 step 18: loss 0.0337, iter time: 279.76ms
iter 38 step 19: loss 0.0186, iter time: 436.00ms (optimizer.step)
iter 39 step 19: loss 0.0167, iter time: 280.70ms
iter 40 step 20: loss 0.0481, iter time: 466.32ms (optimizer.step)
iter 41 step 20: loss 0.0479, iter time: 274.86ms
iter 42 step 21: loss 0.0301, iter time: 450.52ms (optimizer.step)
iter 43 step 21: loss 0.0252, iter time: 258.99ms
iter 44 step 22: loss 0.0200, iter time: 449.49ms (optimizer.step)
iter 45 step 22: loss 0.0323, iter time: 283.17ms
iter 46 step 23: loss 0.0274, iter time: 449.52ms (optimizer.step)
iter 47 step 23: loss 0.0151, iter time: 283.70ms
iter 48 step 24: loss 0.0145, iter time: 451.78ms (optimizer.step)
lora.py with FSDP (default): train loss 0.1446 after 75 steps
iter 1 step 0: loss 3.5551, iter time: 1069.53ms
iter 2 step 1: loss 3.3320, iter time: 765.52ms (optimizer.step)
iter 3 step 1: loss 3.1906, iter time: 537.32ms
iter 4 step 2: loss 3.8844, iter time: 710.44ms (optimizer.step)
iter 5 step 2: loss 3.5345, iter time: 551.50ms
iter 6 step 3: loss 3.1701, iter time: 710.11ms (optimizer.step)
iter 7 step 3: loss 3.6471, iter time: 516.95ms
iter 8 step 4: loss 3.6926, iter time: 718.70ms (optimizer.step)
iter 9 step 4: loss 3.4486, iter time: 526.94ms
iter 10 step 5: loss 3.0839, iter time: 715.32ms (optimizer.step)
iter 11 step 5: loss 3.3866, iter time: 528.55ms
iter 12 step 6: loss 3.5413, iter time: 717.40ms (optimizer.step)
iter 13 step 6: loss 3.6900, iter time: 546.27ms
iter 14 step 7: loss 3.2861, iter time: 740.52ms (optimizer.step)
iter 15 step 7: loss 3.2394, iter time: 532.05ms
iter 16 step 8: loss 3.2060, iter time: 713.04ms (optimizer.step)
iter 17 step 8: loss 3.3468, iter time: 519.96ms
iter 18 step 9: loss 3.3155, iter time: 703.20ms (optimizer.step)
iter 19 step 9: loss 3.4159, iter time: 529.58ms
iter 20 step 10: loss 3.2537, iter time: 706.44ms (optimizer.step)
iter 21 step 10: loss 2.8014, iter time: 628.57ms
iter 22 step 11: loss 3.3453, iter time: 720.83ms (optimizer.step)
iter 23 step 11: loss 3.0813, iter time: 527.75ms
iter 24 step 12: loss 2.8985, iter time: 810.74ms (optimizer.step)
iter 25 step 12: loss 3.1401, iter time: 543.45ms
iter 26 step 13: loss 2.8464, iter time: 718.41ms (optimizer.step)
iter 27 step 13: loss 2.5838, iter time: 534.55ms
iter 28 step 14: loss 2.7362, iter time: 709.92ms (optimizer.step)
iter 29 step 14: loss 2.5003, iter time: 527.97ms
iter 30 step 15: loss 2.6871, iter time: 707.90ms (optimizer.step)
iter 31 step 15: loss 2.5647, iter time: 530.84ms
iter 32 step 16: loss 2.6898, iter time: 709.09ms (optimizer.step)
iter 33 step 16: loss 2.2476, iter time: 535.74ms
iter 34 step 17: loss 2.5858, iter time: 720.83ms (optimizer.step)
iter 35 step 17: loss 2.6341, iter time: 552.05ms
iter 36 step 18: loss 2.3383, iter time: 695.66ms (optimizer.step)
iter 37 step 18: loss 2.3607, iter time: 543.67ms
iter 38 step 19: loss 2.3153, iter time: 708.93ms (optimizer.step)
iter 39 step 19: loss 2.4616, iter time: 526.31ms
iter 40 step 20: loss 2.4130, iter time: 700.92ms (optimizer.step)
iter 41 step 20: loss 2.0636, iter time: 524.18ms
iter 42 step 21: loss 2.2969, iter time: 808.86ms (optimizer.step)
iter 43 step 21: loss 2.0440, iter time: 532.04ms
iter 44 step 22: loss 2.0349, iter time: 717.65ms (optimizer.step)
iter 45 step 22: loss 2.0741, iter time: 546.68ms
iter 46 step 23: loss 2.1482, iter time: 704.36ms (optimizer.step)
iter 47 step 23: loss 1.8215, iter time: 635.46ms
iter 48 step 24: loss 1.8571, iter time: 704.62ms (optimizer.step)
iter 49 step 24: loss 1.7538, iter time: 524.68ms
iter 50 step 25: loss 1.7118, iter time: 708.20ms (optimizer.step)
iter 51 step 25: loss 1.7031, iter time: 528.88ms
iter 52 step 26: loss 1.5699, iter time: 713.73ms (optimizer.step)
iter 53 step 26: loss 1.5775, iter time: 531.15ms
iter 54 step 27: loss 1.4807, iter time: 703.67ms (optimizer.step)
iter 55 step 27: loss 1.3922, iter time: 539.61ms
iter 56 step 28: loss 1.3600, iter time: 703.70ms (optimizer.step)
iter 57 step 28: loss 1.5089, iter time: 544.63ms
iter 58 step 29: loss 1.4053, iter time: 699.75ms (optimizer.step)
iter 59 step 29: loss 1.2835, iter time: 539.76ms
iter 60 step 30: loss 1.1636, iter time: 698.58ms (optimizer.step)
iter 61 step 30: loss 1.1035, iter time: 516.37ms
iter 62 step 31: loss 1.2111, iter time: 709.70ms (optimizer.step)
iter 63 step 31: loss 1.0795, iter time: 541.04ms
iter 64 step 32: loss 1.1503, iter time: 718.27ms (optimizer.step)
iter 65 step 32: loss 1.0205, iter time: 620.26ms
iter 66 step 33: loss 1.0197, iter time: 702.76ms (optimizer.step)
iter 67 step 33: loss 1.0356, iter time: 531.63ms
iter 68 step 34: loss 1.0187, iter time: 717.89ms (optimizer.step)
iter 69 step 34: loss 0.9217, iter time: 528.73ms
iter 70 step 35: loss 0.9199, iter time: 799.95ms (optimizer.step)
iter 71 step 35: loss 0.8928, iter time: 521.28ms
iter 72 step 36: loss 0.9298, iter time: 713.76ms (optimizer.step)
iter 73 step 36: loss 0.8614, iter time: 539.65ms
iter 74 step 37: loss 0.8321, iter time: 701.16ms (optimizer.step)
iter 75 step 37: loss 0.8147, iter time: 530.63ms
iter 76 step 38: loss 0.8113, iter time: 706.08ms (optimizer.step)
iter 77 step 38: loss 0.7607, iter time: 536.43ms
iter 78 step 39: loss 0.7604, iter time: 706.79ms (optimizer.step)
iter 79 step 39: loss 0.7147, iter time: 513.72ms
iter 80 step 40: loss 0.7155, iter time: 710.81ms (optimizer.step)
iter 81 step 40: loss 0.6791, iter time: 535.22ms
iter 82 step 41: loss 0.6573, iter time: 756.59ms (optimizer.step)
iter 83 step 41: loss 0.6545, iter time: 539.70ms
iter 84 step 42: loss 0.6304, iter time: 789.96ms (optimizer.step)
iter 85 step 42: loss 0.6202, iter time: 543.96ms
iter 86 step 43: loss 0.6079, iter time: 794.64ms (optimizer.step)
iter 87 step 43: loss 0.5303, iter time: 529.90ms
iter 88 step 44: loss 0.5536, iter time: 880.86ms (optimizer.step)
iter 89 step 44: loss 0.5423, iter time: 561.47ms
iter 90 step 45: loss 0.5421, iter time: 793.30ms (optimizer.step)
iter 91 step 45: loss 0.5060, iter time: 555.44ms
iter 92 step 46: loss 0.4453, iter time: 756.31ms (optimizer.step)
iter 93 step 46: loss 0.4920, iter time: 668.88ms
iter 94 step 47: loss 0.5148, iter time: 790.52ms (optimizer.step)
iter 95 step 47: loss 0.3797, iter time: 543.98ms
iter 96 step 48: loss 0.4327, iter time: 752.30ms (optimizer.step)
iter 97 step 48: loss 0.4047, iter time: 560.33ms
iter 98 step 49: loss 0.3276, iter time: 785.42ms (optimizer.step)
iter 99 step 49: loss 0.3923, iter time: 539.13ms
iter 100 step 50: loss 0.3058, iter time: 788.50ms (optimizer.step)
iter 101 step 50: loss 0.3848, iter time: 552.10ms
iter 102 step 51: loss 0.4689, iter time: 823.08ms (optimizer.step)
iter 103 step 51: loss 0.2933, iter time: 561.82ms
iter 104 step 52: loss 0.3563, iter time: 766.66ms (optimizer.step)
iter 105 step 52: loss 0.2653, iter time: 549.85ms
iter 106 step 53: loss 0.2773, iter time: 789.58ms (optimizer.step)
iter 107 step 53: loss 0.2461, iter time: 546.29ms
iter 108 step 54: loss 0.2620, iter time: 743.88ms (optimizer.step)
iter 109 step 54: loss 0.3332, iter time: 555.79ms
iter 110 step 55: loss 0.3736, iter time: 786.37ms (optimizer.step)
iter 111 step 55: loss 0.2499, iter time: 630.28ms
iter 112 step 56: loss 0.3041, iter time: 778.23ms (optimizer.step)
iter 113 step 56: loss 0.2778, iter time: 556.51ms
iter 114 step 57: loss 0.3119, iter time: 753.45ms (optimizer.step)
iter 115 step 57: loss 0.2798, iter time: 549.87ms
iter 116 step 58: loss 0.2519, iter time: 859.51ms (optimizer.step)
iter 117 step 58: loss 0.1992, iter time: 559.52ms
iter 118 step 59: loss 0.2476, iter time: 767.54ms (optimizer.step)
iter 119 step 59: loss 0.2308, iter time: 534.35ms
iter 120 step 60: loss 0.1909, iter time: 733.94ms (optimizer.step)
iter 121 step 60: loss 0.1964, iter time: 539.13ms
iter 122 step 61: loss 0.2217, iter time: 766.30ms (optimizer.step)
iter 123 step 61: loss 0.2216, iter time: 565.01ms
iter 124 step 62: loss 0.2126, iter time: 766.05ms (optimizer.step)
iter 125 step 62: loss 0.1686, iter time: 546.63ms
iter 126 step 63: loss 0.1686, iter time: 817.32ms (optimizer.step)
iter 127 step 63: loss 0.2076, iter time: 542.31ms
iter 128 step 64: loss 0.1669, iter time: 790.53ms (optimizer.step)
iter 129 step 64: loss 0.1638, iter time: 575.63ms
iter 130 step 65: loss 0.1885, iter time: 757.39ms (optimizer.step)
iter 131 step 65: loss 0.1843, iter time: 543.61ms
iter 132 step 66: loss 0.1722, iter time: 769.43ms (optimizer.step)
iter 133 step 66: loss 0.1902, iter time: 553.14ms
iter 134 step 67: loss 0.1478, iter time: 817.65ms (optimizer.step)
iter 135 step 67: loss 0.1494, iter time: 546.42ms
iter 136 step 68: loss 0.1705, iter time: 758.27ms (optimizer.step)
iter 137 step 68: loss 0.1625, iter time: 567.57ms
iter 138 step 69: loss 0.1420, iter time: 787.70ms (optimizer.step)
iter 139 step 69: loss 0.1636, iter time: 678.68ms
iter 140 step 70: loss 0.1320, iter time: 794.51ms (optimizer.step)
iter 141 step 70: loss 0.1359, iter time: 535.61ms
iter 142 step 71: loss 0.1478, iter time: 802.98ms (optimizer.step)
iter 143 step 71: loss 0.1515, iter time: 529.23ms
iter 144 step 72: loss 0.1419, iter time: 757.31ms (optimizer.step)
iter 145 step 72: loss 0.1396, iter time: 548.87ms
iter 146 step 73: loss 0.1358, iter time: 769.09ms (optimizer.step)
iter 147 step 73: loss 0.1397, iter time: 531.65ms
iter 148 step 74: loss 0.1359, iter time: 755.25ms (optimizer.step)
iter 149 step 74: loss 0.1334, iter time: 541.91ms
iter 150 step 75: loss 0.1446, iter time: 806.27ms (optimizer.step)
However, forcing strategy='ddp' with the same setting seems to work only for the full trainer (note how with LoRA the loss seems to be stuck around 3.6). I've tried adjusting the hyperparameters without luck:
full.py with DDP: train loss 0.0522 after 12 steps (24 iterations)
iter 1 step 0: loss 3.5551, iter time: 567.18ms
iter 2 step 1: loss 3.8844, iter time: 796.59ms (optimizer.step)
iter 3 step 1: loss 3.7451, iter time: 76.60ms
iter 4 step 2: loss 3.3320, iter time: 469.35ms (optimizer.step)
iter 5 step 2: loss 1.1714, iter time: 73.10ms
iter 6 step 3: loss 1.4070, iter time: 442.38ms (optimizer.step)
iter 7 step 3: loss 0.4997, iter time: 70.98ms
iter 8 step 4: loss 0.3922, iter time: 446.22ms (optimizer.step)
iter 9 step 4: loss 0.6224, iter time: 69.96ms
iter 10 step 5: loss 0.5729, iter time: 446.39ms (optimizer.step)
iter 11 step 5: loss 0.4016, iter time: 70.05ms
iter 12 step 6: loss 0.5031, iter time: 436.62ms (optimizer.step)
iter 13 step 6: loss 0.1636, iter time: 72.35ms
iter 14 step 7: loss 0.1915, iter time: 440.08ms (optimizer.step)
iter 15 step 7: loss 0.1349, iter time: 73.24ms
iter 16 step 8: loss 0.1595, iter time: 444.82ms (optimizer.step)
iter 17 step 8: loss 0.0983, iter time: 71.23ms
iter 18 step 9: loss 0.1043, iter time: 444.88ms (optimizer.step)
iter 19 step 9: loss 0.0808, iter time: 69.35ms
iter 20 step 10: loss 0.0890, iter time: 442.52ms (optimizer.step)
iter 21 step 10: loss 0.0552, iter time: 71.58ms
iter 22 step 11: loss 0.0709, iter time: 443.41ms (optimizer.step)
iter 23 step 11: loss 0.0729, iter time: 71.70ms
iter 24 step 12: loss 0.0522, iter time: 442.91ms (optimizer.step)
lora.py with DDP: train loss 3.7451 after 100 steps
iter 1 step 0: loss 3.5551, iter time: 167.67ms
iter 2 step 1: loss 3.3320, iter time: 5421.28ms (optimizer.step)
iter 3 step 1: loss 3.1906, iter time: 94.70ms
iter 4 step 2: loss 3.8844, iter time: 244.89ms (optimizer.step)
iter 5 step 2: loss 3.5551, iter time: 94.09ms
iter 6 step 3: loss 3.1906, iter time: 236.35ms (optimizer.step)
iter 7 step 3: loss 3.7000, iter time: 91.63ms
iter 8 step 4: loss 3.7451, iter time: 207.49ms (optimizer.step)
iter 9 step 4: loss 3.5551, iter time: 93.52ms
iter 10 step 5: loss 3.1906, iter time: 214.52ms (optimizer.step)
iter 11 step 5: loss 3.5551, iter time: 97.39ms
iter 12 step 6: loss 3.7000, iter time: 205.80ms (optimizer.step)
iter 13 step 6: loss 3.9202, iter time: 93.07ms
iter 14 step 7: loss 3.5086, iter time: 201.63ms (optimizer.step)
iter 15 step 7: loss 3.5551, iter time: 99.50ms
iter 16 step 8: loss 3.5086, iter time: 201.73ms (optimizer.step)
iter 17 step 8: loss 3.7451, iter time: 94.45ms
iter 18 step 9: loss 3.7000, iter time: 209.46ms (optimizer.step)
iter 19 step 9: loss 3.8844, iter time: 91.71ms
iter 20 step 10: loss 3.7451, iter time: 207.51ms (optimizer.step)
iter 21 step 10: loss 3.3320, iter time: 93.41ms
iter 22 step 11: loss 3.9202, iter time: 203.45ms (optimizer.step)
iter 23 step 11: loss 3.7451, iter time: 91.74ms
iter 24 step 12: loss 3.5073, iter time: 207.43ms (optimizer.step)
iter 25 step 12: loss 3.8844, iter time: 90.37ms
iter 26 step 13: loss 3.5551, iter time: 208.81ms (optimizer.step)
iter 27 step 13: loss 3.3320, iter time: 96.21ms
iter 28 step 14: loss 3.5073, iter time: 198.56ms (optimizer.step)
iter 29 step 14: loss 3.3320, iter time: 98.13ms
iter 30 step 15: loss 3.5551, iter time: 200.53ms (optimizer.step)
iter 31 step 15: loss 3.5073, iter time: 97.73ms
iter 32 step 16: loss 3.7000, iter time: 199.27ms (optimizer.step)
iter 33 step 16: loss 3.1906, iter time: 93.94ms
iter 34 step 17: loss 3.7000, iter time: 205.37ms (optimizer.step)
iter 35 step 17: loss 3.9202, iter time: 94.18ms
iter 36 step 18: loss 3.5086, iter time: 203.05ms (optimizer.step)
iter 37 step 18: loss 3.7000, iter time: 90.91ms
iter 38 step 19: loss 3.5551, iter time: 210.29ms (optimizer.step)
iter 39 step 19: loss 3.9545, iter time: 93.97ms
iter 40 step 20: loss 3.9202, iter time: 200.84ms (optimizer.step)
iter 41 step 20: loss 3.5073, iter time: 93.15ms
iter 42 step 21: loss 3.9202, iter time: 208.65ms (optimizer.step)
iter 43 step 21: loss 3.7000, iter time: 92.60ms
iter 44 step 22: loss 3.5551, iter time: 202.50ms (optimizer.step)
iter 45 step 22: loss 3.9202, iter time: 109.24ms
iter 46 step 23: loss 3.9545, iter time: 189.96ms (optimizer.step)
iter 47 step 23: loss 3.7000, iter time: 108.17ms
iter 48 step 24: loss 3.5551, iter time: 190.84ms (optimizer.step)
iter 49 step 24: loss 3.8844, iter time: 108.57ms
iter 50 step 25: loss 3.7000, iter time: 195.16ms (optimizer.step)
iter 51 step 25: loss 3.7451, iter time: 108.96ms
iter 52 step 26: loss 3.5073, iter time: 189.34ms (optimizer.step)
iter 53 step 26: loss 3.5551, iter time: 108.77ms
iter 54 step 27: loss 3.7000, iter time: 190.42ms (optimizer.step)
iter 55 step 27: loss 3.8844, iter time: 108.43ms
iter 56 step 28: loss 3.5073, iter time: 192.47ms (optimizer.step)
iter 57 step 28: loss 3.9545, iter time: 105.81ms
iter 58 step 29: loss 3.9202, iter time: 195.24ms (optimizer.step)
iter 59 step 29: loss 3.5551, iter time: 109.79ms
iter 60 step 30: loss 3.3320, iter time: 191.49ms (optimizer.step)
iter 61 step 30: loss 3.3320, iter time: 108.14ms
iter 62 step 31: loss 3.5551, iter time: 188.55ms (optimizer.step)
iter 63 step 31: loss 3.1906, iter time: 109.10ms
iter 64 step 32: loss 3.7451, iter time: 189.91ms (optimizer.step)
iter 65 step 32: loss 3.5073, iter time: 107.88ms
iter 66 step 33: loss 3.5086, iter time: 188.76ms (optimizer.step)
iter 67 step 33: loss 3.7451, iter time: 107.70ms
iter 68 step 34: loss 3.9202, iter time: 188.97ms (optimizer.step)
iter 69 step 34: loss 3.7000, iter time: 108.00ms
iter 70 step 35: loss 3.7000, iter time: 190.73ms (optimizer.step)
iter 71 step 35: loss 3.5086, iter time: 108.37ms
iter 72 step 36: loss 3.5551, iter time: 188.15ms (optimizer.step)
iter 73 step 36: loss 3.9202, iter time: 108.60ms
iter 74 step 37: loss 3.7000, iter time: 190.43ms (optimizer.step)
iter 75 step 37: loss 3.9202, iter time: 108.32ms
iter 76 step 38: loss 3.9202, iter time: 195.27ms (optimizer.step)
iter 77 step 38: loss 3.7451, iter time: 105.94ms
iter 78 step 39: loss 3.3320, iter time: 193.23ms (optimizer.step)
iter 79 step 39: loss 3.3320, iter time: 105.46ms
iter 80 step 40: loss 3.3320, iter time: 195.81ms (optimizer.step)
iter 81 step 40: loss 3.3320, iter time: 109.03ms
iter 82 step 41: loss 3.5086, iter time: 189.86ms (optimizer.step)
iter 83 step 41: loss 3.5551, iter time: 108.57ms
iter 84 step 42: loss 3.3320, iter time: 188.29ms (optimizer.step)
iter 85 step 42: loss 3.5551, iter time: 108.68ms
iter 86 step 43: loss 3.9202, iter time: 192.64ms (optimizer.step)
iter 87 step 43: loss 3.7000, iter time: 108.35ms
iter 88 step 44: loss 3.3320, iter time: 197.48ms (optimizer.step)
iter 89 step 44: loss 3.8844, iter time: 107.49ms
iter 90 step 45: loss 3.5551, iter time: 182.41ms (optimizer.step)
iter 91 step 45: loss 3.5551, iter time: 108.68ms
iter 92 step 46: loss 3.7451, iter time: 197.15ms (optimizer.step)
iter 93 step 46: loss 3.5073, iter time: 109.28ms
iter 94 step 47: loss 3.1906, iter time: 185.06ms (optimizer.step)
iter 95 step 47: loss 3.7451, iter time: 105.93ms
iter 96 step 48: loss 3.5551, iter time: 188.47ms (optimizer.step)
iter 97 step 48: loss 3.8844, iter time: 107.85ms
iter 98 step 49: loss 3.9545, iter time: 193.28ms (optimizer.step)
iter 99 step 49: loss 3.5551, iter time: 106.31ms
iter 100 step 50: loss 3.7451, iter time: 187.97ms (optimizer.step)
iter 101 step 50: loss 3.5551, iter time: 108.78ms
iter 102 step 51: loss 3.1906, iter time: 196.97ms (optimizer.step)
iter 103 step 51: loss 3.5086, iter time: 108.74ms
iter 104 step 52: loss 3.5551, iter time: 190.11ms (optimizer.step)
iter 105 step 52: loss 3.7451, iter time: 107.93ms
iter 106 step 53: loss 3.7000, iter time: 190.80ms (optimizer.step)
iter 107 step 53: loss 3.5086, iter time: 106.93ms
iter 108 step 54: loss 3.7000, iter time: 189.49ms (optimizer.step)
iter 109 step 54: loss 3.5073, iter time: 107.82ms
iter 110 step 55: loss 3.1906, iter time: 188.76ms (optimizer.step)
iter 111 step 55: loss 3.9202, iter time: 108.66ms
iter 112 step 56: loss 3.5073, iter time: 190.35ms (optimizer.step)
iter 113 step 56: loss 3.5073, iter time: 109.09ms
iter 114 step 57: loss 3.1906, iter time: 194.27ms (optimizer.step)
iter 115 step 57: loss 3.3320, iter time: 105.29ms
iter 116 step 58: loss 3.5551, iter time: 193.18ms (optimizer.step)
iter 117 step 58: loss 3.9202, iter time: 106.85ms
iter 118 step 59: loss 3.5073, iter time: 194.52ms (optimizer.step)
iter 119 step 59: loss 3.5073, iter time: 109.06ms
iter 120 step 60: loss 3.8844, iter time: 189.87ms (optimizer.step)
iter 121 step 60: loss 3.7451, iter time: 107.30ms
iter 122 step 61: loss 3.5551, iter time: 189.06ms (optimizer.step)
iter 123 step 61: loss 3.3320, iter time: 108.87ms
iter 124 step 62: loss 3.5551, iter time: 187.82ms (optimizer.step)
iter 125 step 62: loss 3.7451, iter time: 107.67ms
iter 126 step 63: loss 3.7451, iter time: 193.35ms (optimizer.step)
iter 127 step 63: loss 3.5551, iter time: 107.17ms
iter 128 step 64: loss 3.7000, iter time: 187.17ms (optimizer.step)
iter 129 step 64: loss 3.7000, iter time: 108.06ms
iter 130 step 65: loss 3.5551, iter time: 191.12ms (optimizer.step)
iter 131 step 65: loss 3.3320, iter time: 108.52ms
iter 132 step 66: loss 3.5551, iter time: 196.97ms (optimizer.step)
iter 133 step 66: loss 3.1906, iter time: 108.26ms
iter 134 step 67: loss 3.7451, iter time: 190.88ms (optimizer.step)
iter 135 step 67: loss 3.7451, iter time: 107.66ms
iter 136 step 68: loss 3.8844, iter time: 191.31ms (optimizer.step)
iter 137 step 68: loss 3.5073, iter time: 107.27ms
iter 138 step 69: loss 3.5551, iter time: 189.66ms (optimizer.step)
iter 139 step 69: loss 3.8844, iter time: 108.03ms
iter 140 step 70: loss 3.9545, iter time: 193.45ms (optimizer.step)
iter 141 step 70: loss 3.9545, iter time: 106.10ms
iter 142 step 71: loss 3.7451, iter time: 195.23ms (optimizer.step)
iter 143 step 71: loss 3.5086, iter time: 107.01ms
iter 144 step 72: loss 3.9202, iter time: 194.32ms (optimizer.step)
iter 145 step 72: loss 3.7000, iter time: 108.96ms
iter 146 step 73: loss 3.3320, iter time: 190.23ms (optimizer.step)
iter 147 step 73: loss 3.5551, iter time: 107.08ms
iter 148 step 74: loss 3.8844, iter time: 191.91ms (optimizer.step)
iter 149 step 74: loss 3.3320, iter time: 106.69ms
iter 150 step 75: loss 3.5551, iter time: 187.99ms (optimizer.step)
iter 151 step 75: loss 3.5086, iter time: 107.73ms
iter 152 step 76: loss 3.9545, iter time: 189.14ms (optimizer.step)
iter 153 step 76: loss 3.1906, iter time: 107.78ms
iter 154 step 77: loss 3.7451, iter time: 189.06ms (optimizer.step)
iter 155 step 77: loss 3.8844, iter time: 108.37ms
iter 156 step 78: loss 3.5073, iter time: 190.49ms (optimizer.step)
iter 157 step 78: loss 3.9202, iter time: 107.38ms
iter 158 step 79: loss 3.5073, iter time: 194.00ms (optimizer.step)
iter 159 step 79: loss 3.5551, iter time: 106.31ms
iter 160 step 80: loss 3.8844, iter time: 192.46ms (optimizer.step)
iter 161 step 80: loss 3.7451, iter time: 106.95ms
iter 162 step 81: loss 3.9202, iter time: 194.22ms (optimizer.step)
iter 163 step 81: loss 3.7451, iter time: 108.14ms
iter 164 step 82: loss 3.9545, iter time: 190.68ms (optimizer.step)
iter 165 step 82: loss 3.7451, iter time: 107.95ms
iter 166 step 83: loss 3.7451, iter time: 190.76ms (optimizer.step)
iter 167 step 83: loss 3.7451, iter time: 107.49ms
iter 168 step 84: loss 3.9545, iter time: 197.79ms (optimizer.step)
iter 169 step 84: loss 3.3320, iter time: 108.76ms
iter 170 step 85: loss 3.7451, iter time: 189.47ms (optimizer.step)
iter 171 step 85: loss 3.7000, iter time: 107.84ms
iter 172 step 86: loss 3.5073, iter time: 189.00ms (optimizer.step)
iter 173 step 86: loss 3.9545, iter time: 108.78ms
iter 174 step 87: loss 3.5086, iter time: 190.17ms (optimizer.step)
iter 175 step 87: loss 3.8844, iter time: 107.21ms
iter 176 step 88: loss 3.1906, iter time: 196.45ms (optimizer.step)
iter 177 step 88: loss 3.3320, iter time: 107.93ms
iter 178 step 89: loss 3.3320, iter time: 191.34ms (optimizer.step)
iter 179 step 89: loss 3.5551, iter time: 108.91ms
iter 180 step 90: loss 3.5073, iter time: 192.33ms (optimizer.step)
iter 181 step 90: loss 3.5073, iter time: 107.72ms
iter 182 step 91: loss 3.7000, iter time: 191.23ms (optimizer.step)
iter 183 step 91: loss 3.5073, iter time: 107.90ms
iter 184 step 92: loss 3.7000, iter time: 189.10ms (optimizer.step)
iter 185 step 92: loss 3.7451, iter time: 108.30ms
iter 186 step 93: loss 3.5073, iter time: 190.84ms (optimizer.step)
iter 187 step 93: loss 3.7451, iter time: 108.35ms
iter 188 step 94: loss 3.5086, iter time: 192.90ms (optimizer.step)
iter 189 step 94: loss 3.7451, iter time: 109.55ms
iter 190 step 95: loss 3.5073, iter time: 184.05ms (optimizer.step)
iter 191 step 95: loss 3.5086, iter time: 108.47ms
iter 192 step 96: loss 3.5551, iter time: 190.71ms (optimizer.step)
iter 193 step 96: loss 3.9202, iter time: 107.94ms
iter 194 step 97: loss 3.7000, iter time: 186.67ms (optimizer.step)
iter 195 step 97: loss 3.7451, iter time: 107.74ms
iter 196 step 98: loss 3.7000, iter time: 193.63ms (optimizer.step)
iter 197 step 98: loss 3.5073, iter time: 109.04ms
iter 198 step 99: loss 3.7451, iter time: 189.90ms (optimizer.step)
iter 199 step 99: loss 3.3320, iter time: 108.26ms
iter 200 step 100: loss 3.7451, iter time: 188.62ms (optimizer.step)
Is this expected behaviour (maybe related to the LoRA implementation)?
Can you share a branch from a fork or a diff with all the changes that you've made to the repository and the commands to repro your results?
Sure, here are the steps to reproduce on a fresh lit-gpt clone:
# download & convert model
python scripts/download.py --repo_id TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T
python scripts/convert_hf_checkpoint.py --checkpoint checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/
# prepare (tiny) train dataset
python scripts/prepare_alpaca.py --checkpoint_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/ --test_split_fraction 0.0003865
cp data/alpaca/test.pt data/alpaca/train.pt
Force DDP in finetune/lora.py, use 2 devices (note: I'm using 2 A40s), and decrease batch_size (just to increase the frequency of updates given the small dataset; I've tried different values and got the same behaviour):
diff --git a/finetune/lora.py b/finetune/lora.py
index 852cf01..6255c19 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -32,11 +32,11 @@ save_interval = 100
eval_iters = 100
eval_max_new_tokens = 100
log_interval = 1
-devices = 1
+devices = 2
# Hyperparameters
learning_rate = 3e-4
-batch_size = 128
+batch_size = 8
micro_batch_size = 4
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
@@ -90,6 +90,7 @@ def setup(
else:
strategy = "auto"
+ strategy="ddp"
logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
fabric.print(hparams)
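For reference, the update schedule implied by the hyperparameters in this diff can be worked out by hand. The sketch below is only illustrative: `step_schedule` is a hypothetical helper, not part of lit-gpt, and the samples-per-step figure assumes each device processes its own micro-batch, as in plain data-parallel training.

```python
def step_schedule(batch_size: int, micro_batch_size: int, devices: int, num_iters: int):
    """Return (accumulation iters, iterations that trigger optimizer.step, samples per step)."""
    # Same formula as in the finetune scripts shown in the diff.
    gradient_accumulation_iters = batch_size // micro_batch_size
    assert gradient_accumulation_iters > 0
    # An optimizer step happens on every iteration that completes an accumulation window.
    step_iters = [i for i in range(1, num_iters + 1) if i % gradient_accumulation_iters == 0]
    # Assumption: every device contributes one micro-batch per iteration.
    samples_per_step = micro_batch_size * gradient_accumulation_iters * devices
    return gradient_accumulation_iters, step_iters, samples_per_step


accum, step_iters, samples = step_schedule(batch_size=8, micro_batch_size=4, devices=2, num_iters=6)
print(accum, step_iters, samples)  # 2 [2, 4, 6] 16
```

With these values an optimizer step lands on every second iteration, which matches the `(optimizer.step)` markers in the logs above.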
# run LoRA training
CUDA_VISIBLE_DEVICES=1,2 python finetune/lora.py --checkpoint_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/
Observe that the loss does not improve.
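To make "does not improve" less of an eyeball judgment, the printed log can be parsed and summarized. This is a minimal sketch, assuming the log format shown earlier in this thread (`iter N step M: loss X, iter time: Yms`); `parse_losses` and `summarize` are hypothetical helper names, not part of lit-gpt. It compares the mean loss at the end of the run with the mean at the start, and counts how often each loss value recurs, since the DDP LoRA log above keeps showing the same values (e.g. 3.5551, 3.3320).

```python
import re
from collections import Counter

# Matches lines such as "iter 2 step 1: loss 3.3320, iter time: 5421.28ms (optimizer.step)"
LOG_LINE = re.compile(r"iter (\d+) step (\d+): loss ([\d.]+), iter time: ([\d.]+)ms")


def parse_losses(log_text: str) -> list:
    """Extract the loss value from every matching log line, in order."""
    losses = []
    for line in log_text.splitlines():
        match = LOG_LINE.search(line)
        if match:
            losses.append(float(match.group(3)))
    return losses


def summarize(losses: list, window: int = 2) -> dict:
    """Ratio of late to early mean loss, plus how often each loss value repeats."""
    head, tail = losses[:window], losses[-window:]
    ratio = (sum(tail) / len(tail)) / (sum(head) / len(head))
    return {"tail_over_head": ratio, "value_counts": Counter(losses)}


# Four lines copied from the "lora.py with DDP" log above.
sample = """\
iter 1 step 0: loss 3.5551, iter time: 167.67ms
iter 2 step 1: loss 3.3320, iter time: 5421.28ms (optimizer.step)
iter 199 step 99: loss 3.3320, iter time: 108.26ms
iter 200 step 100: loss 3.7451, iter time: 188.62ms (optimizer.step)
"""

losses = parse_losses(sample)
summary = summarize(losses)
print(losses)
print(summary["tail_over_head"] > 1.0, summary["value_counts"].most_common(1))
```

A ratio at or above 1 means the loss at the end of the run is no lower than at the start. A small set of exactly repeating values may be worth checking as a hint that the weights are not changing, though the sketch itself cannot confirm the cause.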
On the other hand, if you do:
diff --git a/finetune/full.py b/finetune/full.py
index b88d19a..35d1b6e 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -31,12 +31,12 @@ save_interval = 1000
eval_iters = 100
eval_max_new_tokens = 100
log_interval = 1
-devices = 1
+devices = 2
# Hyperparameters
learning_rate = 3e-3
-batch_size = 64 / devices
-micro_batch_size = 1
+batch_size = 8
+micro_batch_size = 4
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_seq_length = None # assign value to truncate
@@ -69,6 +69,7 @@ def setup(
else:
strategy = "auto"
+ strategy="ddp"
logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
fabric.print(hparams)
# run full training
CUDA_VISIBLE_DEVICES=1,2 python finetune/full.py --checkpoint_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T/
The model will overfit as expected.
Hey @v-dicicco and @carmocca, do you have any update on how/if you were able to solve this? I have the same problem: the DDP strategy gets worse performance compared to training on a single GPU.