conditional-flow-matching

Add support for distributed data parallel training

Open • ImahnShekhzadeh opened this pull request • 3 comments

This PR adds support for distributed data parallel (DDP) training and replaces DataParallel with DistributedDataParallel in train_cifar10.py. DDP is enabled via the parallel flag. To achieve this, the code is refactored, and the flags master_addr and master_port are added.

I tested the changes with the OT-CFM method: on a single GPU, I get an FID of 3.74, and on two GPUs with DDP, I get an FID of 3.81.

Before submitting

  • [x] Did you make sure the title is self-explanatory and the description concisely explains the PR?
  • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [x] Did you list all the breaking changes introduced by this pull request?
  • [x] Did you test your PR locally with the pytest command?
  • [x] Did you run pre-commit hooks with the pre-commit run -a command?

ImahnShekhzadeh, May 21, 2024

Hi, thank you for your contribution!

I had an internal implementation using Fabric from Lightning, but I would like to rely only on PyTorch for this example. I need some time to review it (a few days or weeks). I will come back to it soon.

kilianFatras, May 21, 2024

I like the new changes. @atong01, do you mind having a look? I also think it would be great to keep the original train_cifar10.py.

While I like this code, it is slightly more complicated than the previous version, so I would keep both. The idea of this package is that any master's student can easily understand it in one hour. @ImahnShekhzadeh, can you please rename this file to train_cifar10_ddp.py and re-add the previous file? Thanks!

kilianFatras, August 2, 2024

> @ImahnShekhzadeh can you rename this file train_cifar10_ddp.py please? and re-add the previous file? Thanks

Done

ImahnShekhzadeh, August 8, 2024

LGTM. Thanks for the contribution, @ImahnShekhzadeh!

atong01, August 21, 2024