[Flux] Incorrect loss after loading from checkpoint
Bug description
After loading from checkpoint, the loss spikes and then returns to expected values after a few steps.
To repro, run a first training that stores a checkpoint, and then a second one that loads from the same dump_folder.
First command:
NGPU=1 torchtitan/experiments/flux/run_train.sh --training.dataset_path= --training.dataset=cc12m-wds --checkpoint.enable_checkpoint --checkpoint.interval=5 --job.dump_folder=test_debug_ckpt --training.deterministic --training.seed=42 --eval.eval_freq=20 --training.steps=10
Log at step 6:
step: 6 loss: 2.4664 memory: 38.28GiB(80.80%) tps: 96,480 tflops: 0.00 mfu: 0.00%
Second command adds --checkpoint.load_step=5
Log at step 6:
[rank0]:[titan] 2025-05-21 16:49:42,812 - root - INFO - Loading the checkpoint at step 5.
[rank0]:[titan] 2025-05-21 16:49:47,235 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-05-21 16:49:47,235 - root - INFO - Finished loading the checkpoint in 4.42 seconds.
[rank0]:[titan] 2025-05-21 16:49:47,235 - root - INFO - Training starts at step 6.
[rank0]:[titan] 2025-05-21 16:49:48,549 - root - INFO - step: 6 loss: 2.1726 memory: 37.98GiB(80.18%) tps: 14,593 tflops: 0.00 mfu: 0.00%
@wwwjn
Versions
Trained using the flux-train branch, commit fa5c2012f68330022fa57e8fa1f5c68b135e9533
@wwwjn Can this be related to the rng state not being saved?
Yes, this is related to #1194
This might be a factor, but I would not expect the loss to spike so much because of this. The spike seems to suggest some more important component is not being correctly loaded / saved (e.g. optimizer parameters?)
@CarlosGomes98 I did some tests before #1195 got merged. There are several sources of non-determinism in the loss (ideally, we should be able to reproduce the loss with --training.deterministic enabled):
1. The random seed set on each rank. You can get rid of this randomness by enabling --training.deterministic.
2. Dataloader order. I observed before that the dataloader might return data samples in a random order when streaming from Hugging Face. You can get rid of this randomness by using a local dataset (just specify a local dir).
3. RNG states are not saved to and loaded from the checkpoint (#1194), and we are trying to solve this now. It matters for the Flux model because we generate random noise as part of the input. This does not affect large-scale training, where we want randomized noise. You can get rid of this by calling set_deterministic() (link) again at step 6 for both runs, to make the RNG states identical between runs at step 6.
I think in your case, the loss difference comes from 2 and 3. Once you remove the randomness from 2 and 3, you should be able to reproduce the training loss curve in #1195.
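For point 3, a minimal sketch of what saving and restoring the RNG state around a checkpoint could look like (plain torch calls only; the dict layout here is just an assumption for illustration, not torchtitan's checkpoint format):

```python
import torch

def get_rng_state() -> dict:
    # Capture RNG states before saving a checkpoint (CPU + current CUDA device).
    state = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda"] = torch.cuda.get_rng_state()
    return state

def set_rng_state(state: dict) -> None:
    # Restore them after loading, so the noise drawn at the next step matches
    # what the uninterrupted run would have produced.
    torch.set_rng_state(state["cpu"])
    if torch.cuda.is_available() and "cuda" in state:
        torch.cuda.set_rng_state(state["cuda"])
```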
I also dug into all the states (optimizer, model weights, dataloader, lr_scheduler) and calculated hashes before saving and after loading at step 6. Here's an example of how I calculated the hash:
- For llama model: https://github.com/pytorch/torchtitan/blob/flux-checkpointing/torchtitan/train.py#L377-L410.
- For flux model: https://github.com/pytorch/torchtitan/blob/flux-checkpointing/torchtitan/experiments/flux/train.py#L189-L225
The saved and loaded optimizer states are exactly the same, and this is easy to reproduce using the code from the links above.
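For reference, a simplified sketch of that kind of hashing (flatten each tensor to raw bytes and hash; this is an illustration, not the exact code from the links above):

```python
import hashlib
import torch

def hash_state_dict(state_dict: dict) -> str:
    """Hash a (possibly nested) state dict so it can be compared
    before saving and after loading a checkpoint."""
    h = hashlib.sha256()
    for key in sorted(state_dict.keys(), key=str):
        value = state_dict[key]
        h.update(str(key).encode())
        if isinstance(value, torch.Tensor):
            # Reinterpret the raw bytes so dtypes like bf16 also work.
            t = value.detach().cpu().contiguous().view(-1)
            h.update(t.view(torch.uint8).numpy().tobytes())
        elif isinstance(value, dict):
            h.update(hash_state_dict(value).encode())
        else:
            h.update(repr(value).encode())
    return h.hexdigest()
```

Comparing the hex digests printed right before saving and right after loading makes it easy to spot which component diverges.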
Adding more comments from offline discussion:
The RNG state change is deterministic, since we call set_deterministic() every time we initialize a trainer.
Run 1: without checkpoint load/save
- [On rank 0] We set seed = 0 at step 0 when initializing the FluxTrainer. The RNG state is the initial state x.
- Then we train 5 steps, calling the RNG t times, so the RNG state is y at step 6. The random noise at step 6 will come from RNG state y.
Run 2: load checkpoint from step 5
- We set seed = 0 when initializing the FluxTrainer and load the states from step 5. The RNG has not been called yet, so the RNG state is still the initial state x.
- At step 6 of training, the random noise will come from RNG state x.
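A tiny standalone illustration of the two runs' RNG states (pure torch, assuming one noise draw per step just to make states x and y concrete):

```python
import torch

torch.manual_seed(0)
state_x = torch.get_rng_state().clone()   # state right after seeding (Run 2 at step 6)

for _ in range(5):                        # pretend each training step draws noise once
    _ = torch.randn(4)
state_y = torch.get_rng_state().clone()   # state after 5 steps (Run 1 at step 6)

print(torch.equal(state_x, state_y))      # False: the two runs draw different noise at step 6

# Reseeding at step 6 (e.g. calling set_deterministic() again) puts both runs
# back on the same state, which is the workaround mentioned above.
torch.manual_seed(0)
print(torch.equal(torch.get_rng_state(), state_x))  # True
```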
I think I may have found the reason for this. For the last step in each training job, the loss seems to be incorrect, at least in the plot. The way I added this to the plot was `{f"hparams/lr_{i}": scheduler.get_last_lr()[0] for i, scheduler in enumerate(self.lr_schedulers)}`. See the tb plot below. @wwwjn can we reopen to look into this?
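For context, a minimal sketch of how that comprehension could feed a TensorBoard writer (the helper name and the `writer` / `step` wiring here are assumptions for illustration, not the actual metrics code):

```python
from torch.utils.tensorboard import SummaryWriter

def log_last_lrs(writer: SummaryWriter, lr_schedulers, step: int) -> None:
    # Mirror the dict comprehension above: one scalar per scheduler.
    lr_metrics = {
        f"hparams/lr_{i}": scheduler.get_last_lr()[0]
        for i, scheduler in enumerate(lr_schedulers)
    }
    for tag, value in lr_metrics.items():
        writer.add_scalar(tag, value, global_step=step)
```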
@wwwjn can you take a look and confirm Flux checkpoint is actually working? Thanks!
> I think I may have found the reason for this. For the last step in each training job, the loss seems to be incorrect, at least in the plot. The way I added this to the plot was `{f"hparams/lr_{i}": scheduler.get_last_lr()[0] for i, scheduler in enumerate(self.lr_schedulers)}`. See the tb plot below. @wwwjn can we reopen to look into this?
Thanks for flagging! From the plot you showed, it seems that it is the learning rate that drops at the last step, not the loss. If so, this could be a bug in the lr_scheduler.
Also, how did you notice that the loss is incorrect at the last step? Could you provide more details, e.g., with which model you saw the abnormal loss at the last step, and how many steps you ran?
And a side note: flux-train is a stale branch and is behind the main branch. The most up-to-date FLUX changes are on the main branch; could you also try running with the main branch? @CarlosGomes98
@wwwjn @fegin I believe I've found the root cause. It seems to be a bug in the lr_scheduler implementation, specifically the `+ 1` adjustments it applies.
Let's set up an example scenario where our lr schedule has warmup_steps of 5 and a target lr of 0.0008.
We would expect the lr to increase in each step by 0.0008/5 = 0.00016. However, what actually happens is shown below:
[rank0]:[titan] 2025-06-04 14:42:54,945 - root - INFO - Step 1 lr: 0.00013333333333333334
[rank0]:[titan] 2025-06-04 14:42:54,946 - root - INFO - step: 1 loss: 1.7905 memory: 19.94GiB(42.08%) tps: 519,779 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:54,946 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-04 14:42:55,316 - root - INFO - Step 2 lr: 0.0002666666666666667
[rank0]:[titan] 2025-06-04 14:42:55,316 - root - INFO - step: 2 loss: 1.9953 memory: 20.45GiB(43.17%) tps: 2,121,432 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:55,686 - root - INFO - Step 3 lr: 0.0004
[rank0]:[titan] 2025-06-04 14:42:55,686 - root - INFO - step: 3 loss: 1.7081 memory: 20.45GiB(43.17%) tps: 2,125,984 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:56,079 - root - INFO - Step 4 lr: 0.0005333333333333334
[rank0]:[titan] 2025-06-04 14:42:56,079 - root - INFO - step: 4 loss: 1.6267 memory: 20.45GiB(43.17%) tps: 2,004,664 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:56,476 - root - INFO - Step 5 lr: 0.0008
[rank0]:[titan] 2025-06-04 14:42:56,477 - root - INFO - step: 5 loss: 1.5076 memory: 20.45GiB(43.17%) tps: 1,977,805 tflops: 0.00 mfu: 0.00%
This is clearly wrong: the lr increases by 0.00013 at every step except the last, where it increases by double that amount.
We can easily resolve this by removing the `+ 1` adjustments described in the comments (`# 0-indexed step, hence + 1 adjustments`).
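As a sanity check, here is a minimal standalone sketch of a linear warmup without the extra `+ 1` in the denominator (a plain LambdaLR with a dummy optimizer, not the torchtitan scheduler itself; warmup_steps and lr are the values from the example above):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

warmup_steps, base_lr = 5, 8e-4
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

def linear_warmup(current_step: int) -> float:
    # current_step is 0-indexed; dividing by warmup_steps (not warmup_steps + 1)
    # makes the factor hit 1.0 exactly on the last warmup step.
    return min(float(current_step + 1) / warmup_steps, 1.0)

scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup)
for step in range(1, warmup_steps + 1):
    print(f"Step {step} lr: {scheduler.get_last_lr()[0]:.5f}")  # 0.00016, 0.00032, ..., 0.00080
    optimizer.step()   # no gradients here; just advances the schedule for the sketch
    scheduler.step()
```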
Making those changes, we get the desired behaviour of:
[rank0]:[titan] 2025-06-04 14:46:48,496 - root - INFO - Step 1 lr: 0.00016
[rank0]:[titan] 2025-06-04 14:46:48,497 - root - INFO - step: 1 loss: 1.8309 memory: 19.94GiB(42.08%) tps: 512,095 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:48,497 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-04 14:46:48,867 - root - INFO - Step 2 lr: 0.00032
[rank0]:[titan] 2025-06-04 14:46:48,867 - root - INFO - step: 2 loss: 1.9606 memory: 20.45GiB(43.17%) tps: 2,124,707 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:49,233 - root - INFO - Step 3 lr: 0.00048
[rank0]:[titan] 2025-06-04 14:46:49,233 - root - INFO - step: 3 loss: 1.6091 memory: 20.45GiB(43.17%) tps: 2,147,417 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:49,640 - root - INFO - Step 4 lr: 0.00064
[rank0]:[titan] 2025-06-04 14:46:49,641 - root - INFO - step: 4 loss: 1.4623 memory: 20.45GiB(43.17%) tps: 1,931,065 tflops: 0.00 mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:50,039 - root - INFO - Step 5 lr: 0.0008
[rank0]:[titan] 2025-06-04 14:46:50,040 - root - INFO - step: 5 loss: 1.6065 memory: 20.45GiB(43.17%) tps: 1,971,288 tflops: 0.00 mfu: 0.00%
This same behaviour was causing the bug I described above: at the final step, we went 1 over the expected max_steps, which set the multiplicative lr adjustment factor to 0.
I'll submit a PR with a fix; it should be straightforward.
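To make the last-step effect concrete, here is a generic linear-decay factor (an assumed formula for illustration, not torchtitan's exact code), showing how an index one past the intended final step clamps the factor, and hence the lr, to 0:

```python
total_steps, decay_steps = 10, 5  # hypothetical schedule for illustration

def decay_factor(step_index: int) -> float:
    # Assumed generic linear decay over the last `decay_steps` steps, clamped at 0.
    return max(0.0, (total_steps - step_index) / decay_steps)

final_step = total_steps - 1          # 0-indexed final training step
print(decay_factor(final_step))       # 0.2 -> the non-zero lr the last step should use
print(decay_factor(final_step + 1))   # 0.0 -> the off-by-one zeroes the lr instead
```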
> We would expect the lr to increase in each step by 0.0008/5 = 0.00016
I tried to reproduce this with the main branch. Here's my setting:
[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8
[lr_scheduler]
warmup_steps = 5 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.0
Here are my lr results:
[rank0]:[titan] 2025-06-11 15:19:25,547 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-11 15:19:25,547 - root - INFO - Step 1, Learning rate: 0.00013333333333333334
[rank0]:[titan] 2025-06-11 15:19:26,917 - root - INFO - step: 1 loss: 8.2204 memory: 1.23GiB(1.55%) tps: 9,507 tflops: 0.68 mfu: 0.22%
[rank0]:[titan] 2025-06-11 15:19:26,918 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-11 15:19:26,918 - root - INFO - Step 2, Learning rate: 0.0002666666666666667
[rank0]:[titan] 2025-06-11 15:19:26,997 - root - INFO - step: 2 loss: 8.1997 memory: 1.23GiB(1.56%) tps: 206,210 tflops: 14.83 mfu: 4.75%
[rank0]:[titan] 2025-06-11 15:19:26,998 - root - INFO - Step 3, Learning rate: 0.0004
[rank0]:[titan] 2025-06-11 15:19:27,077 - root - INFO - step: 3 loss: 8.1729 memory: 1.23GiB(1.56%) tps: 207,709 tflops: 14.94 mfu: 4.79%
[rank0]:[titan] 2025-06-11 15:19:27,077 - root - INFO - Step 4, Learning rate: 0.0005333333333333334
[rank0]:[titan] 2025-06-11 15:19:27,173 - root - INFO - step: 4 loss: 8.1171 memory: 1.23GiB(1.56%) tps: 171,954 tflops: 12.36 mfu: 3.96%
[rank0]:[titan] 2025-06-11 15:19:27,173 - root - INFO - Step 5, Learning rate: 0.0006666666666666668
[rank0]:[titan] 2025-06-11 15:19:27,253 - root - INFO - step: 5 loss: 8.0356 memory: 1.23GiB(1.56%) tps: 204,498 tflops: 14.70 mfu: 4.71%
[rank0]:[titan] 2025-06-11 15:19:27,253 - root - INFO - Step 6, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,332 - root - INFO - step: 6 loss: 7.9143 memory: 1.23GiB(1.56%) tps: 210,034 tflops: 15.10 mfu: 4.84%
[rank0]:[titan] 2025-06-11 15:19:27,332 - root - INFO - Step 7, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,418 - root - INFO - step: 7 loss: 7.7103 memory: 1.23GiB(1.56%) tps: 191,524 tflops: 13.77 mfu: 4.41%
[rank0]:[titan] 2025-06-11 15:19:27,418 - root - INFO - Step 8, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,500 - root - INFO - step: 8 loss: 7.4662 memory: 1.23GiB(1.56%) tps: 200,762 tflops: 14.44 mfu: 4.63%
[rank0]:[titan] 2025-06-11 15:19:27,500 - root - INFO - Step 9, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,579 - root - INFO - step: 9 loss: 7.1613 memory: 1.23GiB(1.56%) tps: 208,261 tflops: 14.97 mfu: 4.80%
[rank0]:[titan] 2025-06-11 15:19:27,579 - root - INFO - Step 10, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,660 - root - INFO - step: 10 loss: 7.0049 memory: 1.23GiB(1.56%) tps: 203,108 tflops: 14.60 mfu: 4.68%
@CarlosGomes98 could you confirm whether this test uses the same setup as yours? Thanks!