
[Flux] Flux Issue Tracking

Open wwwjn opened this issue 8 months ago • 2 comments

Create a centralized tracker of all the enhancements and TODOs for the Flux model in torchtitan.

Current status

  1. Implemented the Flux model and integrated it with the torchtitan architecture.
  2. Enabled FSDP; training on 8 nodes.

TODOs

Features

  • [ ] Calculate nflops_per_token for the Flux model

  • [ ] Validation:

    • [ ] Introduce a validation loop and validation metrics
    • [ ] Adapt the existing inference code to work with batches (set up a validation set and run inference on all of its samples)
    • [ ] Calculate loss at fixed t values / noise levels: for each validation sample, compute the loss at a few fixed values of t and average them (see the sketch after this list). Similar issue: #1150
    • [ ] FID and CLIP scores
  • [x] Flux Dataset & Dataloader:

    • [x] Modify the logic of Flux dataset checkpoint loading and sample_idx so that it works with non-infinite datasets (e.g. for validation); essentially, sample_idx should reset once the dataset is exhausted
    • [ ] Dataloader fast save / load following https://github.com/pytorch/torchtitan/pull/1082
  • [x] Preprocessing

    • [x] Create a script and dataloader to save and read the preprocessed data, and load it into training.
  • [ ] CI for torchtitan

    • [ ] Download a small dataset that fits within the CI environment's constraints
  • [ ] Checkpoint Format Conversion

    • [ ] Load weights from HF model weights
    • [ ] Convert DCP-format checkpoints into HF model weights (see the sketch after this list)
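
For the fixed-t validation item above, here is a minimal sketch, assuming a rectified-flow parameterization where the model predicts the velocity between noise and the clean latent; `model`, `latents`, `text_emb`, and the chosen timesteps are hypothetical placeholders, not the actual torchtitan Flux API:

```python
import torch

# Assumed fixed noise levels in [0, 1]; actual values are a design choice.
FIXED_TIMESTEPS = [0.25, 0.5, 0.75]

@torch.no_grad()
def validation_loss(model, latents, text_emb):
    """Average rectified-flow MSE loss over a few fixed t values,
    so the metric is not dominated by the randomness of sampled timesteps.

    latents: clean image latents, shape [B, ...]
    text_emb: precomputed text conditioning (e.g. T5 encodings)
    """
    losses = []
    for t in FIXED_TIMESTEPS:
        noise = torch.randn_like(latents)
        t_vec = torch.full((latents.shape[0],), t, device=latents.device)
        # Linear interpolation between clean latents and noise (rectified flow).
        noised = (1.0 - t) * latents + t * noise
        target = noise - latents  # velocity target
        pred = model(noised, t_vec, text_emb)  # hypothetical model signature
        losses.append(torch.nn.functional.mse_loss(pred, target))
    return torch.stack(losses).mean()
```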
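For the DCP → HF direction, one possible path (a sketch under assumptions, not the planned implementation) is to first materialize the sharded DCP checkpoint as a regular torch.save file with torch.distributed.checkpoint.format_utils.dcp_to_torch_save, then remap keys to the HF Flux layout and write safetensors. The paths, the top-level checkpoint layout, and the key mapping below are all placeholders:

```python
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from safetensors.torch import save_file

# 1. Collapse the sharded DCP checkpoint into a single torch.save file.
dcp_to_torch_save("checkpoints/step-1000", "flux_consolidated.pt")

# 2. Load the consolidated checkpoint (CPU is fine for conversion).
ckpt = torch.load("flux_consolidated.pt", map_location="cpu")
state_dict = ckpt.get("model", ckpt)  # top-level layout depends on what torchtitan saves

# 3. Rename keys from torchtitan's module names to the HF Flux layout.
#    The real mapping depends on both model definitions; identity here.
hf_state_dict = {k: v.contiguous() for k, v in state_dict.items()}

# 4. Write HF-style safetensors weights.
save_file(hf_state_dict, "flux_hf/diffusion_pytorch_model.safetensors")
```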

Minor Bugs & Fixes

  • Dtype mismatch: #1137
  • Setting the sequence length of flux-schnell to be 256: #1146

wwwjn avatar Apr 28 '25 20:04 wwwjn

cc @CarlosGomes98 @tianyu-l , here's a centralized tracker of Flux issues and next steps.

wwwjn avatar Apr 28 '25 20:04 wwwjn

Preprocessing code is here: https://github.com/pytorch/torchtitan/tree/flux-train. The preprocessed data will take a large amount of storage, because the generated T5 encoding for each sample is 256 * 4096.
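
For a rough sense of scale (assuming the encodings are stored in bf16, which may not match the actual preprocessing script, and using 1M samples purely as an illustrative dataset size):

```python
seq_len, hidden = 256, 4096                    # T5 encoding shape per sample
bytes_per_sample = seq_len * hidden * 2        # bf16 = 2 bytes per value
print(bytes_per_sample / 2**20)                # ~2 MiB per sample
print(1_000_000 * bytes_per_sample / 2**40)    # ~1.9 TiB for 1M samples
```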

wwwjn avatar May 05 '25 21:05 wwwjn