conditional-flow-matching
add support for distributed data parallel training
This PR adds support for distributed data parallel (DDP) training and replaces `DataParallel` with `DistributedDataParallel` in `train_cifar10.py`, enabled via the flag `parallel`. To achieve this, the code is refactored, and the flags `master_addr` and `master_port` are added.
I tested the changes: on a single GPU, I get an FID of 3.74 (with the OT-CFM method); on two GPUs with DDP, I get an FID of 3.81.
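For reference, here is a minimal sketch of the standard PyTorch DDP wiring this kind of change involves, with a toy model and dataset standing in for the CIFAR-10 training code. The function names (`ddp_setup`, `train`) and the default address/port are illustrative, not the PR's exact code:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def ddp_setup(rank: int, world_size: int, master_addr: str, master_port: str) -> None:
    # Every process must agree on the same rendezvous address before init.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def train(rank: int, world_size: int, master_addr: str, master_port: str) -> None:
    ddp_setup(rank, world_size, master_addr, master_port)

    # Toy data and model just to show the wiring.
    dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
    # DistributedSampler gives each rank a disjoint shard of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    # DDP wraps the per-GPU model replica; unlike DataParallel, each process
    # owns one GPU, and gradients are all-reduced across ranks in backward().
    model = DDP(torch.nn.Linear(10, 1).to(rank), device_ids=[rank])
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for x, y in loader:
            x, y = x.to(rank), y.to(rank)
            loss = F.mse_loss(model(x), y)
            optim.zero_grad()
            loss.backward()  # gradient all-reduce happens here
            optim.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # One process per GPU; mp.spawn passes the rank as the first argument.
    mp.spawn(train, args=(world_size, "localhost", "12355"), nprocs=world_size)
```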
Before submitting
- [x] Did you make sure title is self-explanatory and the description concisely explains the PR?
- [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
- [x] Did you list all the breaking changes introduced by this pull request?
- [x] Did you test your PR locally with the `pytest` command?
- [x] Did you run pre-commit hooks with the `pre-commit run -a` command?
Hi, thank you for your contribution!
I had an internal implementation with Fabric from Lightning, but I'd like to rely only on PyTorch for this example. I need some time to review it (a few days/weeks). I will come back to it soon.
I like the new changes. @atong01 do you mind having a look? I also think it would be great to keep the original `train_cifar10.py`.
While I like this code, it is slightly more complicated than the previous one, so I would keep both. The idea of this package is that any master's student can easily understand it in 1 hour. @ImahnShekhzadeh, can you rename this file to `train_cifar10_ddp.py`, please? And re-add the previous file? Thanks
> @ImahnShekhzadeh, can you rename this file to `train_cifar10_ddp.py`, please? And re-add the previous file? Thanks
Done
LGTM. Thanks for the contribution @ImahnShekhzadeh