
[Flux] Test and enable checkpointing for Flux model

Open wwwjn opened this issue 7 months ago • 2 comments

Context

  1. Enabled and tested checkpointing for the Flux model.
  2. Changed the Flux model's last layer reshard_after_forward to True, to avoid a checkpointing error after the forward pass in the evaluation loop.
  3. Fixed a minor HSDP issue: dp_mesh_dim_names = ("dp_replicate", "dp_shard"). A sketch of both changes follows this list.
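
A minimal sketch of changes 2 and 3, assuming PyTorch's FSDP2 fully_shard API; FluxModel and its blocks attribute are hypothetical stand-ins for the real model code:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # older PyTorch: torch.distributed._composable.fsdp

# Assumes torch.distributed is already initialized (e.g. via torchrun).
world_size = dist.get_world_size()

# HSDP fix: the 2-D data-parallel mesh dims are named
# ("dp_replicate", "dp_shard"); the 2-way replicate split is illustrative.
mesh = init_device_mesh(
    "cuda",
    (2, world_size // 2),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)

model = FluxModel()  # hypothetical constructor standing in for the real Flux model

# Shard each block, then the root module. reshard_after_forward=True on the
# last layer frees its unsharded parameters right after forward, so a
# forward-only eval pass does not leave them unsharded when a checkpoint
# is saved afterwards.
for block in model.blocks:
    fully_shard(block, mesh=mesh, reshard_after_forward=True)
fully_shard(model, mesh=mesh, reshard_after_forward=True)
```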

Test

Observation:

  1. The dataloader (Hugging Face dataset loader) behavior is not deterministic when the dataset is downloaded on the fly, which causes slight differences in the loss curve.
  2. [Debugging] When we save and then load a checkpoint at step=100, the training curve is not the same (even with training.deterministic = True). [Screenshot: diverging loss curves after checkpoint load, 2025-05-07]

After turning off classifier-free guidance (in the dataloader) and the eval steps, and loading from a pre-downloaded dataset, the hash of each batch is identical across different runs. The remaining issue is with the deterministic algorithms.

wwwjn · May 06 '25 04:05
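
For reference, a minimal sketch of the kind of per-batch hash check described above; the dataloader variable and the dict-of-tensors batch format are assumptions:

```python
import hashlib

import torch

def batch_hash(batch: dict[str, torch.Tensor]) -> str:
    """Hash the raw bytes of every tensor in a batch, in key order."""
    h = hashlib.sha256()
    for key in sorted(batch):
        t = batch[key].detach().cpu().contiguous()
        if t.dtype == torch.bfloat16:
            t = t.float()  # numpy has no bfloat16; the cast is lossless
        h.update(key.encode())
        h.update(t.numpy().tobytes())
    return h.hexdigest()

# Print the first few hashes; if they match across two runs, the dataloader
# is deterministic and any loss divergence must come from the compute side.
for step, batch in enumerate(dataloader):  # `dataloader` assumed to be in scope
    if step >= 5:
        break
    print(step, batch_hash(batch))
```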

> After turning off classifier-free guidance (in the dataloader) and the eval steps, and loading from a pre-downloaded dataset, the hash of each batch is identical across different runs. The remaining issue is with the deterministic algorithms.

> Is the loss curve with checkpointing the same if we turn off all the features you mentioned here?

I figured out later that it's because the image encoding from the autoencoder is not deterministic; it might not be related to checkpoint save/load.

wwwjn · May 08 '25 17:05
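
A quick way to test that hypothesis is to encode the same image twice and compare the results bitwise; a minimal sketch, assuming an autoencoder object already loaded in eval mode whose encode returns a tensor:

```python
import torch

torch.manual_seed(0)
image = torch.randn(1, 3, 256, 256, device="cuda")  # stand-in input image

# Encode the same image twice in a single process; any mismatch means the
# encoding itself is non-deterministic, independent of checkpointing.
with torch.no_grad():
    latents_a = autoencoder.encode(image)
    latents_b = autoencoder.encode(image)

print("bitwise equal:", torch.equal(latents_a, latents_b))
```

If the encoder samples from a latent distribution (as KL-style VAE encoders typically do), passing a fixed random generator to the sampling step, or using the distribution mean instead of a sample, makes the encoding reproducible.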

But how come the loss curve matches (only slightly off) without checkpointing, yet is way off with checkpointing?

fegin · May 08 '25 18:05