[Flux] Flux Issue Tracking
A centralized tracker of all the enhancements and TODOs for the Flux model in torchtitan.
Current status
- Implemented the Flux model and integrated it with the torchtitan architecture.
- Enabled FSDP; training runs on 8 nodes.
TODOs
Features
- [ ] Calculate nflops_per_token for the Flux model
- [ ] Validation:
  - [ ] Introduce a validation loop and validation metrics
  - [ ] Adapt the existing inference code to work with batches (set up a validation set, and run inference on all samples from it)
  - [ ] Calculate loss at fixed t values / noise levels: for each validation sample, compute the loss at a few values of t, then average them. Similar issue: #1150
  - [ ] FID and CLIP scores
- [x] Flux Dataset & Dataloader:
  - [x] Modify the logic of Flux dataset checkpoint loading and sample_idx so that it works with non-infinite datasets (e.g., for validation); sample_idx should reset once the dataset is exhausted
  - [ ] Dataloader fast save/load, following https://github.com/pytorch/torchtitan/pull/1082
- [x] Preprocessing
  - [x] Create a script and dataloader to save and read the preprocessed data, and load it into training
- [ ] CI for torchtitan
  - [ ] Download a small dataset suitable for the CI environment
- [ ] Checkpoint Format Conversion
  - [ ] Load weights from HF model weights
  - [ ] Convert DCP format into HF model weights
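For the nflops_per_token item, one possible starting point is the common "6N plus attention" approximation used for dense transformers. This is a hedged sketch only: the function name, signature, and all example numbers below are made up for illustration, and Flux's mix of double-stream and single-stream blocks would need its own per-block accounting.

```python
# Sketch of a per-token training-FLOPs estimate for a dense transformer.
# 6 * num_params approximates forward + backward matmul FLOPs per token;
# the second term adds attention score/value FLOPs, which grow with seq_len.
# NOTE: hypothetical helper, not a torchtitan API; Flux's block structure
# (double-stream vs. single-stream) would require a more careful breakdown.

def nflops_per_token(num_params: int, num_layers: int, num_heads: int,
                     head_dim: int, seq_len: int) -> int:
    return 6 * num_params + 12 * num_layers * num_heads * head_dim * seq_len

# Example with illustrative (not real Flux) numbers:
flops = nflops_per_token(num_params=12_000_000_000, num_layers=57,
                         num_heads=24, head_dim=128, seq_len=4352)
```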
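The fixed-t validation idea above could look roughly like the following. This is a minimal sketch under assumptions: the `model` signature, the rectified-flow interpolation, and the velocity target are placeholders and may not match torchtitan's actual Flux training objective.

```python
import torch

# Hedged sketch: evaluate the denoising loss at a few fixed t values per
# validation batch and average, so the metric is not dominated by the
# variance of randomly sampled timesteps. The model call signature and the
# rectified-flow-style target below are assumptions, not torchtitan code.

@torch.no_grad()
def validation_loss(model, latents, cond, t_values=(0.25, 0.5, 0.75)):
    losses = []
    for t_val in t_values:
        t = torch.full((latents.shape[0],), t_val, device=latents.device)
        noise = torch.randn_like(latents)
        # Linear interpolation between data and noise (rectified-flow style).
        tb = t.view(-1, 1, 1, 1)
        noised = (1 - tb) * latents + tb * noise
        pred = model(noised, t, cond)
        target = noise - latents  # velocity target under this parameterization
        losses.append(torch.nn.functional.mse_loss(pred, target))
    return torch.stack(losses).mean()
```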
Minor Bugs & Fixes
- Dtype mismatch: #1137
- Setting the sequence length of flux-schnell to be 256: #1146
cc @CarlosGomes98 @tianyu-l , here's a centralized tracker of Flux issues and next steps.
Preprocessing code is here: https://github.com/pytorch/torchtitan/tree/flux-train. The preprocessed data will require substantial storage, because the generated T5 encoding for each sample is 256 × 4096.
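To put the storage cost in concrete terms, here is a back-of-envelope estimate. The bfloat16 on-disk dtype is an assumption, not something confirmed by the branch above.

```python
# Storage estimate for the preprocessed T5 encodings: each sample is a
# 256 x 4096 tensor. Assuming bfloat16 (2 bytes/element) on disk.
tokens, dim, bytes_per_elem = 256, 4096, 2
per_sample = tokens * dim * bytes_per_elem       # 2_097_152 bytes = 2 MiB
per_million = per_sample * 1_000_000 / 2**40     # TiB for 1M samples
print(per_sample, round(per_million, 2))         # -> 2097152 1.91
```

So at one million samples, the T5 encodings alone would occupy roughly 2 TiB, which explains the storage concern.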