[Flux] Test and enable checkpointing for Flux model
Context
- Enabled and tested checkpointing in the Flux model.
- Changed the Flux model's last layer `reshard_after_forward` to `True`. This avoids a checkpointing error after the forward pass in the evaluation loop (see the sketch after this list).
- Fixed a minor HSDP issue: the data-parallel mesh dimension names are now `dp_mesh_dim_names = ("dp_replicate", "dp_shard")` (see the sketch after this list).
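
Below is a minimal sketch of the `reshard_after_forward` change, assuming the FSDP2 `fully_shard` API; the module attribute names (`blocks`, `final_layer`) are illustrative, not the actual torchtitan code.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 per-module API

def apply_fsdp(model: nn.Module, dp_mesh) -> None:
    """Illustrative wrapping order; attribute names are hypothetical."""
    for block in model.blocks:
        # Default behavior: parameters are resharded right after each forward.
        fully_shard(block, mesh=dp_mesh, reshard_after_forward=True)
    # The last layer is sometimes left unsharded after forward as a prefetch
    # optimization, but an eval-only forward (no backward) then leaves it in
    # the unsharded state and the subsequent checkpoint save fails. Setting
    # it back to True is the change described above.
    fully_shard(model.final_layer, mesh=dp_mesh, reshard_after_forward=True)
    fully_shard(model, mesh=dp_mesh)  # root wrap
```

And a sketch of the HSDP mesh-name fix, assuming an 8-GPU job split into 2 replica groups of 4 shards (the mesh shape is only an example):

```python
from torch.distributed.device_mesh import init_device_mesh

# 2-D data-parallel mesh for HSDP: outer dim replicates, inner dim shards.
world_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(2, 4),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)

# The fix: both dimension names must be used when slicing out the DP sub-mesh.
dp_mesh_dim_names = ("dp_replicate", "dp_shard")
dp_mesh = world_mesh[dp_mesh_dim_names]  # 2-D mesh passed to fully_shard
```

With the 2-D `dp_mesh`, `fully_shard` replicates parameters across `dp_replicate` and shards them across `dp_shard`, which is the HSDP layout.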
Test
Observations:
1. The dataloader (Hugging Face dataset loader) behavior is not deterministic when the dataset is downloaded on the fly, which causes slight differences in the loss curve.
2. [Debugging] When saving and then loading the checkpoint at step=100, the training curve is not the same, even with `training.deterministic = True`.
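
For reference, a generic sketch of what a `training.deterministic = True` style switch usually enables; this is a standard PyTorch determinism recipe, not necessarily the exact torchtitan implementation.

```python
import os
import torch

def set_determinism(seed: int = 0) -> None:
    # Seed all relevant RNGs.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Force deterministic kernels; non-deterministic ops raise an error
    # instead of silently introducing run-to-run noise.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Required for deterministic cuBLAS GEMMs.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
```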
After turning off classifier-free guidance (in the dataloader) and the eval steps, and loading from a pre-downloaded dataset, the hash of each batch is identical across different runs, so the remaining issue is around the deterministic-algorithm path.
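
A minimal sketch of the batch-hash check used for that comparison; the helper name and the assumption that a batch is a dict of tensors are hypothetical.

```python
import hashlib
import torch

def batch_hash(batch: dict) -> str:
    # Hash the raw bytes of every tensor in the batch so two runs can be
    # compared step by step before suspecting the checkpoint save/load path.
    h = hashlib.sha256()
    for key in sorted(batch):
        h.update(str(key).encode())
        value = batch[key]
        if torch.is_tensor(value):
            raw = value.detach().cpu().contiguous().flatten().view(torch.uint8)
            h.update(raw.numpy().tobytes())
        else:
            h.update(str(value).encode())
    return h.hexdigest()

# Usage inside the training loop (names are illustrative):
# for step, batch in enumerate(dataloader):
#     print(f"step={step} batch_hash={batch_hash(batch)}")
```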
Is the loss curve with checkpointing the same if we turn off all the features you mentioned here?
I figured out later that it is because the image encoding from the autoencoder is not deterministic, so it might not be related to checkpoint save/load.
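
A quick way to confirm that, sketched against a hypothetical `autoencoder.encode` that returns latents directly; VAE-style encoders often sample from a latent distribution, which is non-deterministic unless the sampling is seeded or the distribution mean is used.

```python
import torch

@torch.no_grad()
def check_encoder_determinism(autoencoder, images: torch.Tensor) -> None:
    # Encode the same batch twice; any difference means the latents themselves
    # add run-to-run noise, independent of checkpoint save/load.
    z1 = autoencoder.encode(images)
    z2 = autoencoder.encode(images)
    max_diff = (z1 - z2).abs().max().item()
    print(f"max |z1 - z2| = {max_diff:.3e}")  # 0.0 only if encoding is deterministic
```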
But how come the loss curve matches (only slightly off) without checkpointing, while with checkpointing it is way off?