pytorch-meta

Using distributed training with Torchmeta (e.g. DDP, distributed RPC)

[Open] brando90 opened this issue 4 years ago • 19 comments

I've noticed that my training takes a long time, probably because when using higher one has to loop through each task in the batch individually, as in

https://github.com/tristandeleu/pytorch-meta/blob/389e35ef9aa812f07ce50a3f3bd253c4efb9765c/examples/maml-higher/train.py#L113

I noticed that one could parallelize this part of the code, but since Torchmeta doesn't work with DDP (or at least it's unclear whether it does), I decided that distributed RPC might be a good way to do it, though it has not been as trivial as I thought. So I was wondering whether this path is worth it, whether anyone has tried it, or whether someone with more experience with distributed training could pitch in.


  • related question in the learn2learn repo: https://github.com/learnables/learn2learn/issues/197#issuecomment-784380659
  • higher related issue: https://github.com/facebookresearch/higher/issues/99
  • related: https://stackoverflow.com/questions/69730835/how-does-one-create-a-distributed-data-loader-with-pytorchs-torchmeta

brando90, Feb 23 '21

In the (non-higher) example of MAML you are also looping through each task individually. Unfortunately this is necessary in PyTorch because each task has its own set of parameters, so there is no easy way to vectorize it (unlike JAX, where vmap lets you work with "batches of parameters").
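A minimal sketch of why the loop is hard to avoid, in plain PyTorch rather than Torchmeta or higher: each task takes its own inner gradient step, producing task-specific parameters, and the outer loss backpropagates through that step. The functional linear model, the MSE loss, the single inner step, and the names `forward` and `maml_outer_loss` are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def forward(params, x):
    # Tiny functional linear model: params is a (weight, bias) tuple.
    weight, bias = params
    return F.linear(x, weight, bias)


def maml_outer_loss(params, tasks, inner_lr=0.1):
    """Mean query loss over tasks after one inner SGD step per task.

    Each task produces its own adapted parameters, hence the Python loop.
    """
    outer_loss = 0.0
    for support_x, support_y, query_x, query_y in tasks:
        support_loss = F.mse_loss(forward(params, support_x), support_y)
        # create_graph=True keeps the inner update differentiable, so the
        # outer loss can backpropagate through it (second-order MAML).
        grads = torch.autograd.grad(support_loss, params, create_graph=True)
        adapted = tuple(p - inner_lr * g for p, g in zip(params, grads))
        outer_loss = outer_loss + F.mse_loss(forward(adapted, query_x), query_y)
    return outer_loss / len(tasks)


torch.manual_seed(0)
params = (torch.randn(1, 3, requires_grad=True), torch.zeros(1, requires_grad=True))
# Four toy regression tasks: (support_x, support_y, query_x, query_y).
tasks = [tuple(torch.randn(5, d) for d in (3, 1, 3, 1)) for _ in range(4)]
loss = maml_outer_loss(params, tasks)
loss.backward()  # meta-gradient w.r.t. the shared initialization
```

Because `adapted` differs per task, the forward passes cannot simply be stacked into one batched call the way ordinary minibatch examples can.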

DDP is unfortunately not supported in Torchmeta at the moment, but even if it were, I doubt it would easily solve looping through the tasks, unless you use low-level distributed routines from PyTorch. Unfortunately, I have no idea how to handle the distributed setting with the stateless approach in Torchmeta (maybe the stateful approach of higher or learn2learn would be easier to handle).
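One reading of "low-level distributed routines" is to split the tasks of each meta-batch across processes, run the per-task loop locally, and average the meta-gradients yourself with `torch.distributed.all_reduce` instead of wrapping the model in DDP. A minimal sketch, assuming a process group has already been initialized (e.g. with `torch.distributed.init_process_group`) and that every rank gets the same number of tasks; `shard_tasks` and `average_gradients` are hypothetical helper names, not Torchmeta API.

```python
import torch
import torch.distributed as dist


def shard_tasks(tasks, rank, world_size):
    # Hypothetical helper: round-robin split of a meta-batch across ranks.
    return tasks[rank::world_size]


def average_gradients(params):
    """All-reduce each parameter's .grad and divide by the world size.

    No-op when torch.distributed is not initialized (single process).
    Collectives must be issued by every rank in the same order; if one rank
    skips an all_reduce (e.g. because a grad is None only there), the others
    block forever, which shows up as a deadlock.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    world_size = dist.get_world_size()
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)


# Per training step on each rank (sketch):
#   my_tasks = shard_tasks(tasks, dist.get_rank(), dist.get_world_size())
#   compute the mean outer loss over my_tasks, then loss.backward()
#   average_gradients(model_params)
#   optimizer.step()
```

With equal shard sizes, averaging the per-rank mean gradients gives the same result as the mean over the full meta-batch.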

tristandeleu, Mar 18 '21

@tristandeleu sorry for reviving this thread now. Is torchmeta still not compatible with DDP?

brando90, Sep 09 '21

Just out of curiosity, why wouldn't Torchmeta's data loader work with DDP?

brando90, Sep 09 '21

The data-loaders should definitely work with DDP, since they inherit from the regular PyTorch DataLoader. However, when I said that DDP was not supported in Torchmeta at the moment, I meant that there is no equivalent of DDP that would work with Torchmeta's MetaModule. That would require a version of DDP that can handle the params argument in forward, similar to DataParallel, and no such version exists in Torchmeta.
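To make "the params argument in forward" concrete, here is a minimal sketch of the pattern in plain PyTorch: the module uses its own registered parameters by default, but a caller can pass a dict of adapted tensors instead. `LinearWithParams` is a hypothetical name and this is not Torchmeta's actual MetaModule code; it only illustrates the extra argument that a DDP equivalent for MetaModules would have to handle.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearWithParams(nn.Linear):
    """Linear layer whose forward optionally takes replacement parameters."""

    def forward(self, x, params=None):
        if params is None:
            # Default: use the module's own registered parameters.
            params = OrderedDict(self.named_parameters())
        return F.linear(x, params["weight"], params.get("bias"))


torch.manual_seed(0)
model = LinearWithParams(3, 2)
x = torch.randn(4, 3)
# "Fast" weights, e.g. produced by an inner-loop update (here a fixed shift).
fast = OrderedDict((name, p - 0.1) for name, p in model.named_parameters())
y_default = model(x)
y_fast = model(x, params=fast)
```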

tristandeleu, Sep 20 '21

Hi Tristan,

Sorry for reviving this thread, but just to clarify my thinking (and planning), is the following true? (Numbered to ease responses.)

  1. Is Torchmeta compatible with DDP (distributed data loading)?
  2. Is the issue with DDP in higher rather than in Torchmeta?

Thanks again for your responses!

brando90, Sep 28 '21

Partially answering 2 myself: I think it's unclear (my guess: unlikely) that higher is compatible with DDP; see https://github.com/facebookresearch/higher/issues/116 and https://github.com/facebookresearch/higher/issues/98

brando90, Sep 28 '21

  1. There are two things in Torchmeta: data-loaders (e.g. for Omniglot, MiniImagenet, etc.), and MetaModules, used to create models where you can backpropagate through the learning rule (e.g. MAML). Data-loaders should be fully compatible with DDP, since they are nothing more than regular PyTorch DataLoaders. MetaModules, though, are not compatible with DDP. Concretely, you should be able to use Torchmeta for data loading with DDP without issues, but you can't write a model with MetaModule that will be compatible with DDP.
  2. I don't have any experience with DDP and higher, so I cannot comment, sorry.
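A minimal sketch of sharding a map-style dataset across DDP processes with PyTorch's `DistributedSampler`. Passing `num_replicas` and `rank` explicitly lets it run without a process group; in a real job they would come from `torch.distributed`. The toy `TensorDataset` stands in for a dataset of tasks and the helper names are hypothetical. Torchmeta's meta-data-loaders may set up their own task sampler, so treat this as an illustration under those assumptions rather than a drop-in recipe for them.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy map-style dataset standing in for a dataset of tasks (assumption).
dataset = TensorDataset(torch.arange(10))


def loader_for_rank(rank, world_size, epoch=0):
    # Explicit num_replicas/rank: no initialized process group is needed here.
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
    )
    sampler.set_epoch(epoch)  # same epoch on all ranks -> same permutation
    return DataLoader(dataset, batch_size=2, sampler=sampler)


def items_for_rank(rank, world_size, epoch=0):
    # Flatten the batches this rank would see into a list of ints.
    return [int(v) for (batch,) in loader_for_rank(rank, world_size, epoch) for v in batch]


print(items_for_rank(0, 2), items_for_rank(1, 2))
```

Every rank should call `set_epoch` with the same epoch number, so the shards stay disjoint while the order changes between epochs.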

One easy way to verify that the data-loaders are compatible with DDP is to test them on an algorithm that doesn't require backpropagation through the learning rule, such as Prototypical Networks.
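A minimal sketch of the Prototypical Networks loss in plain PyTorch, assuming squared Euclidean distances and a placeholder linear encoder. The loss backpropagates only through the encoder, with no inner-loop update, so the encoder is an ordinary `nn.Module` that could in principle be wrapped in DDP as usual. `prototypical_loss` and the episode shapes are hypothetical.

```python
import torch
import torch.nn.functional as F


def prototypical_loss(support_emb, support_labels, query_emb, query_labels, num_classes):
    # Prototype of each class = mean embedding of its support examples.
    prototypes = torch.stack(
        [support_emb[support_labels == c].mean(dim=0) for c in range(num_classes)]
    )
    # Squared Euclidean distance from every query to every prototype: (Q, C).
    dists = ((query_emb.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=-1)
    # Closer prototype -> larger logit.
    return F.cross_entropy(-dists, query_labels)


torch.manual_seed(0)
encoder = torch.nn.Linear(8, 4)  # placeholder encoder
n_way, k_shot, n_query = 3, 2, 5
support_x = torch.randn(n_way * k_shot, 8)
support_y = torch.arange(n_way).repeat_interleave(k_shot)
query_x = torch.randn(n_way * n_query, 8)
query_y = torch.arange(n_way).repeat_interleave(n_query)
loss = prototypical_loss(encoder(support_x), support_y, encoder(query_x), query_y, n_way)
loss.backward()
```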

tristandeleu, Sep 28 '21

@tristandeleu sorry for the ping.

I am confused: is it higher or Torchmeta's data loaders that are incompatible with DDP?

AFAIK higher is, for sure. In summary: are Torchmeta's data loaders incompatible with DDP (or just untested, so unknown)?

brando90, Oct 26 '21

  1. Data-loaders should be fully compatible with DDP, since they are nothing more than regular PyTorch DataLoader

OK, I think that answers my question: it should work, but it's untested.

brando90, Oct 26 '21

related: https://github.com/learnables/learn2learn/issues/272 from learn2learn

brando90, Oct 26 '21

@tristandeleu I made an attempt to wrap a distributed sampler using your helper function for mini-imagenet, but it failed: it deadlocked.

Here is the attempt in case you are interested: https://github.com/brando90/ultimate-utils/blob/master/tutorials_for_myself/my_torchmeta/torchmeta_ddp.py

brando90, Oct 26 '21

@brando90 this is something I am interested in as well. I got DP to work, but I would prefer using DDP for the speed. Were you able to test out your code?

ojss, Nov 26 '21

Yes, I did. If you look around in the issue, you can see an SO question where I say it deadlocks.

Perhaps you could help me fix it?

https://stackoverflow.com/questions/69730835/how-does-one-create-a-distributed-data-loader-with-pytorchs-torchmeta-for-meta

brando90, Nov 27 '21

Yeah, I can check it out and see how it goes. Is there anything you have already tried?

ojss, Nov 27 '21

BTW, I'm not sure whether you are using higher for MAML, but I was told that DDP wouldn't work with it; that's why I stopped (especially after the deadlock). Do you know anything about this? Is this an issue for you?

brando90, Nov 27 '21

what is DP?

brando90, Nov 27 '21

here is the SO question: https://stackoverflow.com/questions/69730835/how-does-one-create-a-distributed-data-loader-with-pytorchs-torchmeta-for-meta

brando90, Nov 27 '21

That is DataParallel, the old way.

ojss, Nov 30 '21

sounds good.

I will try to put another 500-point bounty on the SO question to see if it helps. I don't have the cycles to help more, but I hope it helps.

brando90, Nov 30 '21