
[Flux] Incorrect loss after loading from checkpoint

Open CarlosGomes98 opened this issue 7 months ago • 6 comments

Bug description

After loading from checkpoint, the loss spikes and then returns to expected values after a few steps.

To reproduce, run a first training that stores a checkpoint, then a second one that loads from the same dump_folder.

First command:

NGPU=1 torchtitan/experiments/flux/run_train.sh --training.dataset_path= --training.dataset=cc12m-wds --checkpoint.enable_checkpoint --checkpoint.interval=5 --job.dump_folder=test_debug_ckpt --training.deterministic --training.seed=42 --eval.eval_freq=20 --training.steps=10 

Log at step 6: step: 6 loss: 2.4664 memory: 38.28GiB(80.80%) tps: 96,480 tflops: 0.00 mfu: 0.00%

Second command adds --checkpoint.load_step=5 Log at step 6:

[rank0]:[titan] 2025-05-21 16:49:42,812 - root - INFO - Loading the checkpoint at step 5.
[rank0]:[titan] 2025-05-21 16:49:47,235 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-05-21 16:49:47,235 - root - INFO - Finished loading the checkpoint in 4.42 seconds.
[rank0]:[titan] 2025-05-21 16:49:47,235 - root - INFO - Training starts at step 6.
[rank0]:[titan] 2025-05-21 16:49:48,549 - root - INFO - step:  6  loss:  2.1726  memory: 37.98GiB(80.18%)  tps: 14,593  tflops: 0.00  mfu: 0.00%

@wwwjn

Versions

Trained using the flux-train branch, commit fa5c2012f68330022fa57e8fa1f5c68b135e9533

CarlosGomes98 avatar May 21 '25 14:05 CarlosGomes98

@wwwjn Can this be related to the rng state not being saved?

fegin avatar May 21 '25 16:05 fegin

Yes, this is related to #1194

wwwjn avatar May 21 '25 17:05 wwwjn

This might be a factor, but I would not expect the loss to spike so much because of it. The spike seems to suggest that some more important component is not being correctly loaded/saved (e.g., optimizer parameters?).

CarlosGomes98 avatar May 21 '25 17:05 CarlosGomes98

@CarlosGomes98 I did some tests before #1195 got merged. There are several sources of non-determinism in the loss (ideally, we should be able to reproduce the loss with --training.deterministic enabled):

  1. The random seed set on each rank. You can get rid of this randomness by enabling --training.deterministic.
  2. The dataloader's order. I have observed that the dataloader may return data samples in a random order when streaming from Hugging Face. You can get rid of this randomness by using a local dataset (just specify a local dir).
  3. RNG states are not saved to and loaded from the checkpoint (#1194), and we are trying to solve this now. It matters for the Flux model because we generate random noise as part of the input. This does not affect large-scale training, where we want randomized noise anyway. You can get rid of this by calling set_deterministic() (link) again at step 6 for both runs, to make the RNG states identical between runs at step 6 (see the sketch after this list).
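
As an illustration of point 3, here is a minimal sketch using plain PyTorch seeding (not torchtitan's set_deterministic(); the resync_step and base_seed names are made up for the example) of re-seeding at a fixed step so that both runs draw identical noise from that step onward:

import torch

def maybe_resync_rng(step: int, resync_step: int = 6, base_seed: int = 42) -> None:
    # Re-seed at `resync_step` so a fresh run and a checkpoint-resumed run
    # sample the same noise from this step onward.
    if step == resync_step:
        torch.manual_seed(base_seed + step)  # seeds CPU and all CUDA generators

# Usage inside the training loop (illustrative):
# for step in range(start_step, total_steps + 1):
#     maybe_resync_rng(step)
#     noise = torch.randn_like(latents)  # identical in both runs at step 6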

I think in your case the loss difference comes from 2 and 3. Once you remove the randomness from 2 and 3, you should be able to reproduce the training loss curve in #1195.

wwwjn avatar May 21 '25 18:05 wwwjn

I also dug into all the states (optimizer, model weights, dataloader, lr_scheduler) and calculated hashes before saving and after loading at step 6. Here's an example of how I calculated the hash:

  1. For llama model: https://github.com/pytorch/torchtitan/blob/flux-checkpointing/torchtitan/train.py#L377-L410.
  2. For flux model: https://github.com/pytorch/torchtitan/blob/flux-checkpointing/torchtitan/experiments/flux/train.py#L189-L225

The saved and loaded optimizer states are exactly the same, and this is easy to reproduce using the code from the links above.
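
For reference, a simplified sketch of such a hash check (assuming plain, non-distributed state dicts; not the exact script in the links above):

import hashlib
import torch

def state_dict_hash(state_dict: dict) -> str:
    # Deterministically hash a (possibly nested) state dict; tensors are cast to
    # float32 purely for hashing simplicity, which is fine for equality checks.
    h = hashlib.sha256()
    for key in sorted(state_dict.keys(), key=str):
        value = state_dict[key]
        h.update(str(key).encode())
        if isinstance(value, torch.Tensor):
            h.update(value.detach().float().cpu().contiguous().numpy().tobytes())
        elif isinstance(value, dict):
            h.update(state_dict_hash(value).encode())
        else:
            h.update(repr(value).encode())
    return h.hexdigest()

# Log state_dict_hash(model.state_dict()) and state_dict_hash(optimizer.state_dict())
# right before saving and right after loading the checkpoint, then compare.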

wwwjn avatar May 21 '25 18:05 wwwjn

Adding more comments from an offline discussion:

The RNG state evolution is deterministic, since we call set_deterministic() every time we initialize a trainer.

Run 1: without checkpoint load/save

  • [On rank 0] We set seed = 0 at step 0 when initializing the FluxTrainer. The RNG state is the initial state x.
  • Then we train 5 steps, during which the RNG is called t times, so the RNG state is y at step 6. The random noise will come from RNG state y.

Run 2: load checkpoint from step 5

  • We set seed = 0 when initializing the FluxTrainer, and load the states from step 5. The RNG has not been called yet, so the RNG state is still x.
  • At step 6 of training, the random noise will come from RNG state x (a small sketch below illustrates this divergence).
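
A toy illustration of this divergence (plain PyTorch, not the FluxTrainer):

import torch

def noise_at_step_6(steps_already_run: int) -> torch.Tensor:
    torch.manual_seed(0)              # both runs start from RNG state x
    for _ in range(steps_already_run):
        torch.randn(4, 4)             # each training step advances the RNG
    return torch.randn(4, 4)          # the noise sampled at "step 6"

run1 = noise_at_step_6(steps_already_run=5)  # RNG has reached state y
run2 = noise_at_step_6(steps_already_run=0)  # resumed run: RNG still in state x
print(torch.equal(run1, run2))               # False: the two runs see different noise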

wwwjn avatar May 21 '25 18:05 wwwjn

I think I may have found the reason for this. For the last step in each training job, the loss seems to be incorrect, at least in the plot. The way I added the learning rate to the plot was {f"hparams/lr_{i}": scheduler.get_last_lr()[0] for i, scheduler in enumerate(self.lr_schedulers)}. See the TensorBoard plot below. @wwwjn can we reopen this issue to look into it?

Image
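
For context, a hedged sketch of how that per-scheduler learning-rate dict could be written to TensorBoard (using torch.utils.tensorboard.SummaryWriter; the log dir and lr_schedulers list are assumptions, and torchtitan's own metrics logger may differ):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="test_debug_ckpt/tb")  # hypothetical log dir

def log_learning_rates(writer: SummaryWriter, lr_schedulers, step: int) -> None:
    # One scalar per scheduler, mirroring the dict comprehension quoted above.
    lrs = {f"hparams/lr_{i}": s.get_last_lr()[0] for i, s in enumerate(lr_schedulers)}
    for tag, lr in lrs.items():
        writer.add_scalar(tag, lr, global_step=step)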

CarlosGomes98 avatar May 28 '25 06:05 CarlosGomes98

@wwwjn can you take a look and confirm Flux checkpoint is actually working? Thanks!

fegin avatar May 28 '25 23:05 fegin

I think I may have found the reason for this. For the last step in each training job, the loss seems to be incorrect, at least in the plot. The way I added the learning rate to the plot was {f"hparams/lr_{i}": scheduler.get_last_lr()[0] for i, scheduler in enumerate(self.lr_schedulers)}. See the TensorBoard plot below. @wwwjn can we reopen this issue to look into it?

Thanks for flagging! From the plot you showed, it looks like it is the learning rate that drops at the last step, not the loss you mentioned. If so, this could be a bug in the lr_scheduler.

And how did you notice the loss is incorrect at the last step? Could you provide more details, e.g., which model showed the abnormal loss at the last step, and how many steps did you run?

And a side note: flux-train is a stale branch and is behind the main branch. The most up-to-date FLUX changes are on the main branch; could you also try running with the main branch? @CarlosGomes98

wwwjn avatar May 29 '25 11:05 wwwjn

@wwwjn @fegin I believe I've found the root cause. It seems to be a bug in the lr_scheduler implementation, specifically the + 1 adjustments that happen.

Let's set up an example scenario where our lr schedule has warmup_steps = 5 and a target lr of 0.0008. We would expect the lr to increase each step by 0.0008/5 = 0.00016. However, what actually happens is shown below:

[rank0]:[titan] 2025-06-04 14:42:54,945 - root - INFO - Step 1 lr: 0.00013333333333333334
[rank0]:[titan] 2025-06-04 14:42:54,946 - root - INFO - step:  1  loss:  1.7905  memory: 19.94GiB(42.08%)  tps: 519,779  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:54,946 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-04 14:42:55,316 - root - INFO - Step 2 lr: 0.0002666666666666667
[rank0]:[titan] 2025-06-04 14:42:55,316 - root - INFO - step:  2  loss:  1.9953  memory: 20.45GiB(43.17%)  tps: 2,121,432  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:55,686 - root - INFO - Step 3 lr: 0.0004
[rank0]:[titan] 2025-06-04 14:42:55,686 - root - INFO - step:  3  loss:  1.7081  memory: 20.45GiB(43.17%)  tps: 2,125,984  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:56,079 - root - INFO - Step 4 lr: 0.0005333333333333334
[rank0]:[titan] 2025-06-04 14:42:56,079 - root - INFO - step:  4  loss:  1.6267  memory: 20.45GiB(43.17%)  tps: 2,004,664  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:56,476 - root - INFO - Step 5 lr: 0.0008
[rank0]:[titan] 2025-06-04 14:42:56,477 - root - INFO - step:  5  loss:  1.5076  memory: 20.45GiB(43.17%)  tps: 1,977,805  tflops: 0.00  mfu: 0.00%

This is clearly wrong: the lr increases by roughly 0.00013 each step, except for the last step, where it jumps by double that amount.
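
To illustrate, here is a minimal, self-contained sketch (built on torch.optim.lr_scheduler.LambdaLR, not the actual torchtitan scheduler) contrasting a linear warmup with and without the + 1 in the denominator, using warmup_steps = 5 and a base lr of 8e-4 as above:

import torch

base_lr, warmup_steps = 8e-4, 5
params = [torch.nn.Parameter(torch.zeros(1))]

def warmup_scheduler(denominator: int) -> torch.optim.lr_scheduler.LambdaLR:
    opt = torch.optim.AdamW(params, lr=base_lr)
    # `step` is 0-indexed; the warmup factor ramps linearly up to 1.0.
    return torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda step: min((step + 1) / denominator, 1.0)
    )

with_plus_one = warmup_scheduler(warmup_steps + 1)  # increments of 8e-4/6 ~= 0.000133
without_plus_one = warmup_scheduler(warmup_steps)   # increments of 8e-4/5 = 0.00016

for step in range(1, warmup_steps + 1):
    print(f"step {step}: {with_plus_one.get_last_lr()[0]:.6f} "
          f"vs {without_plus_one.get_last_lr()[0]:.6f}")
    with_plus_one.step()   # optimizer.step() omitted; we only inspect the schedule
    without_plus_one.step()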

We can easily resolve this by removing the assumption noted in the comments (# 0-indexed step, hence + 1 adjustments). Making that change, we get the desired behaviour:

[rank0]:[titan] 2025-06-04 14:46:48,496 - root - INFO - Step 1 lr: 0.00016
[rank0]:[titan] 2025-06-04 14:46:48,497 - root - INFO - step:  1  loss:  1.8309  memory: 19.94GiB(42.08%)  tps: 512,095  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:48,497 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-04 14:46:48,867 - root - INFO - Step 2 lr: 0.00032
[rank0]:[titan] 2025-06-04 14:46:48,867 - root - INFO - step:  2  loss:  1.9606  memory: 20.45GiB(43.17%)  tps: 2,124,707  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:49,233 - root - INFO - Step 3 lr: 0.00048
[rank0]:[titan] 2025-06-04 14:46:49,233 - root - INFO - step:  3  loss:  1.6091  memory: 20.45GiB(43.17%)  tps: 2,147,417  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:49,640 - root - INFO - Step 4 lr: 0.00064
[rank0]:[titan] 2025-06-04 14:46:49,641 - root - INFO - step:  4  loss:  1.4623  memory: 20.45GiB(43.17%)  tps: 1,931,065  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:50,039 - root - INFO - Step 5 lr: 0.0008
[rank0]:[titan] 2025-06-04 14:46:50,040 - root - INFO - step:  5  loss:  1.6065  memory: 20.45GiB(43.17%)  tps: 1,971,288  tflops: 0.00  mfu: 0.00%

This same behaviour was causing the bug I described above: in the final step, we went one step beyond the expected max_steps, which set the multiplicative lr adjustment factor to 0.

I'll submit a PR with a fix; it should be straightforward.

CarlosGomes98 avatar Jun 04 '25 12:06 CarlosGomes98

We would expect the lr to increase each step by 0.0008/5 = 0.00016

I tried to reproduce with main branch. Here's my setting:

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 5  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.0

Here is my lr results:

[rank0]:[titan] 2025-06-11 15:19:25,547 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-11 15:19:25,547 - root - INFO - Step 1, Learning rate: 0.00013333333333333334
[rank0]:[titan] 2025-06-11 15:19:26,917 - root - INFO - step:  1  loss:  8.2204  memory:  1.23GiB(1.55%)  tps: 9,507  tflops: 0.68  mfu: 0.22%
[rank0]:[titan] 2025-06-11 15:19:26,918 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-11 15:19:26,918 - root - INFO - Step 2, Learning rate: 0.0002666666666666667
[rank0]:[titan] 2025-06-11 15:19:26,997 - root - INFO - step:  2  loss:  8.1997  memory:  1.23GiB(1.56%)  tps: 206,210  tflops: 14.83  mfu: 4.75%
[rank0]:[titan] 2025-06-11 15:19:26,998 - root - INFO - Step 3, Learning rate: 0.0004
[rank0]:[titan] 2025-06-11 15:19:27,077 - root - INFO - step:  3  loss:  8.1729  memory:  1.23GiB(1.56%)  tps: 207,709  tflops: 14.94  mfu: 4.79%
[rank0]:[titan] 2025-06-11 15:19:27,077 - root - INFO - Step 4, Learning rate: 0.0005333333333333334
[rank0]:[titan] 2025-06-11 15:19:27,173 - root - INFO - step:  4  loss:  8.1171  memory:  1.23GiB(1.56%)  tps: 171,954  tflops: 12.36  mfu: 3.96%
[rank0]:[titan] 2025-06-11 15:19:27,173 - root - INFO - Step 5, Learning rate: 0.0006666666666666668
[rank0]:[titan] 2025-06-11 15:19:27,253 - root - INFO - step:  5  loss:  8.0356  memory:  1.23GiB(1.56%)  tps: 204,498  tflops: 14.70  mfu: 4.71%
[rank0]:[titan] 2025-06-11 15:19:27,253 - root - INFO - Step 6, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,332 - root - INFO - step:  6  loss:  7.9143  memory:  1.23GiB(1.56%)  tps: 210,034  tflops: 15.10  mfu: 4.84%
[rank0]:[titan] 2025-06-11 15:19:27,332 - root - INFO - Step 7, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,418 - root - INFO - step:  7  loss:  7.7103  memory:  1.23GiB(1.56%)  tps: 191,524  tflops: 13.77  mfu: 4.41%
[rank0]:[titan] 2025-06-11 15:19:27,418 - root - INFO - Step 8, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,500 - root - INFO - step:  8  loss:  7.4662  memory:  1.23GiB(1.56%)  tps: 200,762  tflops: 14.44  mfu: 4.63%
[rank0]:[titan] 2025-06-11 15:19:27,500 - root - INFO - Step 9, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,579 - root - INFO - step:  9  loss:  7.1613  memory:  1.23GiB(1.56%)  tps: 208,261  tflops: 14.97  mfu: 4.80%
[rank0]:[titan] 2025-06-11 15:19:27,579 - root - INFO - Step 10, Learning rate: 0.0008
[rank0]:[titan] 2025-06-11 15:19:27,660 - root - INFO - step: 10  loss:  7.0049  memory:  1.23GiB(1.56%)  tps: 203,108  tflops: 14.60  mfu: 4.68%

@CarlosGomes98 could you confirm whether this test uses the same setup as yours? Thanks!

wwwjn avatar Jun 11 '25 22:06 wwwjn