pytorch-meta
Using distributed training with torch meta (e.g. DDP, distributed RPC)
I've noticed that my training takes a long time, probably because when using higher one has to loop through each task in the batch individually, as in
https://github.com/tristandeleu/pytorch-meta/blob/389e35ef9aa812f07ce50a3f3bd253c4efb9765c/examples/maml-higher/train.py#L113
I noticed that one could parallelize this part of the code, but since torchmeta doesn't work with DDP (or at least it's unclear whether it does), I thought that using distributed RPC might be a good way to do it - though it hasn't been as trivial as I expected. So I was wondering whether this path is worth pursuing, whether anyone has tried it, or whether someone with more experience with distributed training could pitch in.
- related question in the learn2learn repo: https://github.com/learnables/learn2learn/issues/197#issuecomment-784380659
- higher related issue: https://github.com/facebookresearch/higher/issues/99
- related: https://stackoverflow.com/questions/69730835/how-does-one-create-a-distributed-data-loader-with-pytorchs-torchmeta
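For context, the loop in question looks roughly like the following. This is an illustrative sketch only (not the repo's exact code, and using plain autograd in place of higher) of the serial per-task inner loop that I would like to parallelize:

```python
import torch
import torch.nn.functional as F

# Illustrative stand-in for the per-task loop in examples/maml-higher/train.py:
# each task gets its own inner-loop adaptation, one task at a time.
model = torch.nn.Linear(3, 2)
# A meta-batch of 3 tasks, each with a (support, query) pair of inputs.
meta_batch = [(torch.randn(4, 3), torch.randn(2, 3)) for _ in range(3)]

outer_loss = 0.0
for support, query in meta_batch:  # <-- tasks are processed serially
    inner_loss = model(support).pow(2).mean()  # stand-in for the task loss
    grads = torch.autograd.grad(inner_loss, list(model.parameters()),
                                create_graph=True)
    # One inner SGD step produces per-task "fast weights".
    fast_w, fast_b = (p - 0.1 * g for p, g in zip(model.parameters(), grads))
    outer_loss = outer_loss + F.linear(query, fast_w, fast_b).pow(2).mean()

outer_loss.backward()  # meta-gradient w.r.t. the original parameters
```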
In the (non-higher) example of MAML you are also looping through each task individually. Unfortunately this is necessary in PyTorch, because each task has its own set of parameters, so there is no easy way to vectorize it (contrary to a `vmap` option in Jax, where you can have "batches of parameters").
DDP is unfortunately not supported in Torchmeta at the moment, but even if it did, I doubt this would easily solve looping through the tasks, unless you use low-level distributed routines from PyTorch. I have unfortunately no idea how to handle the distributed setting using the stateless approach in Torchmeta (maybe the stateful approach from Higher or Learn2learn might be easier to handle).
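For what it's worth, recent PyTorch releases (this sketch assumes PyTorch >= 2.0) expose `torch.func.vmap` and `torch.func.functional_call`, which do allow "batches of parameters" in PyTorch as well. A minimal, hedged sketch; the 5-task stacking here is purely illustrative:

```python
import torch
from torch.func import functional_call, vmap

model = torch.nn.Linear(3, 2)
num_tasks = 5
# Stack one copy of the parameters per task ("batches of parameters").
params = {name: p.detach().repeat(num_tasks, *([1] * p.dim()))
          for name, p in model.named_parameters()}
x = torch.randn(num_tasks, 4, 3)  # 4 examples per task

# vmap over the leading (task) dimension of both the params and the inputs.
out = vmap(lambda p, xi: functional_call(model, p, (xi,)))(params, x)
print(out.shape)  # torch.Size([5, 4, 2])
```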
@tristandeleu sorry for reviving this thread now. Is torchmeta still not compatible with DDP?
Just out of curiosity - why wouldn't torchmeta's data loaders work with DDP?
The data-loaders should definitely work with DDP, since they inherit from the regular PyTorch `DataLoader`. However, when I said that DDP was not supported in Torchmeta at the moment, I meant that there is no equivalent of DDP that would work with Torchmeta's `MetaModule`. That would require a version of DDP that can handle the `params` argument in `forward`, similar to `DataParallel`, which does not exist in Torchmeta.
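To illustrate the point, here is a hedged sketch (not Torchmeta's actual implementation) of what a `MetaModule`-style layer with a `params` argument in `forward` looks like, and why a DDP wrapper would have to thread that extra argument through:

```python
import torch
import torch.nn.functional as F

class FunctionalLinear(torch.nn.Linear):
    """Illustrative MetaModule-style layer: forward can optionally take
    "fast weights" instead of the registered parameters."""
    def forward(self, x, params=None):
        if params is None:
            return super().forward(x)
        return F.linear(x, params["weight"], params["bias"])

layer = FunctionalLinear(3, 2)
x = torch.randn(4, 3)

# Inner-loop adaptation: compute fast weights from a task loss...
inner_loss = layer(x).pow(2).mean()
grads = torch.autograd.grad(inner_loss, list(layer.parameters()),
                            create_graph=True)
fast = {name: p - 0.1 * g
        for (name, p), g in zip(layer.named_parameters(), grads)}

# ...and use them in the outer loop. A DDP wrapper would need to pass this
# extra argument along, which stock DistributedDataParallel knows nothing about.
out = layer(x, params=fast)
```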
Hi Tristan,
sorry for reviving this thread, but just to clarify my thinking (and planning), is the following true? (numbered to ease response)
1. Is Torchmeta compatible with DDP (distributed data loading)?
2. Is the DDP issue with PyTorch's higher instead, and not Torchmeta?
Thanks again for your responses!
To answer 2 myself: I think it's unclear (my guess: unlikely) that higher is compatible with DDP, see https://github.com/facebookresearch/higher/issues/116, https://github.com/facebookresearch/higher/issues/98
- There are two things in Torchmeta: data-loaders (e.g. `Omniglot`, `MiniImagenet`, etc.), and `MetaModule`s to create models where you can backpropagate through the learning rule (e.g. MAML). Data-loaders should be fully compatible with DDP, since they are nothing more than regular PyTorch `DataLoader`s. `MetaModule`s, though, are not compatible with DDP. Concretely, you should be able to use Torchmeta for data-loading with DDP without issues, but you can't write a model with `MetaModule` which will be compatible with DDP.
- I don't have any experience with DDP and higher, so I cannot comment, sorry.
One easy way to verify that the data-loaders are compatible with DDP is to test it on an algorithm that doesn't require backpropagation through the learning rule, such as Prototypical Networks.
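As a concrete starting point, here is a hedged single-process smoke test of the `DistributedSampler` + `DataLoader` wiring that DDP training relies on (assumptions: world size 1, the gloo backend, and a plain `TensorDataset` standing in for a Torchmeta dataset). Since Torchmeta's loaders inherit from `DataLoader`, the same pattern should apply:

```python
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Single-process "cluster" so the sketch runs anywhere (gloo is available
# in CPU-only builds of PyTorch).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

dataset = TensorDataset(torch.arange(16, dtype=torch.float32))
sampler = DistributedSampler(dataset, shuffle=False)  # shards data per rank
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

batches = [batch[0] for batch in loader]
print(len(batches))  # with world_size=1 this rank sees all 16 items: 4 batches

dist.destroy_process_group()
```

With more than one rank, each process would construct the same sampler and receive a disjoint shard of the dataset.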
@tristandeleu sorry for the ping.
I am confused: is it higher or torchmeta's data loaders that are incompatible with DDP?
Afaik higher is for sure. In summary: are torchmeta's data loaders incompatible with DDP, or just untested, so unknown?
- Data-loaders should be fully compatible with DDP, since they are nothing more than regular PyTorch `DataLoader`s.
Ok, I think that answers my question. It should work, but it's untested.
related: https://github.com/learnables/learn2learn/issues/272 from learn2learn
@tristandeleu I made an attempt to wrap a distributed sampler using your helper function for mini-imagenet, but it failed: it deadlocked.
Here is the attempt in case you are interested: https://github.com/brando90/ultimate-utils/blob/master/tutorials_for_myself/my_torchmeta/torchmeta_ddp.py
@brando90 this is something I am interested in as well. I got DP to work but I would prefer using DDP for the speed, were you able to test out your code?
Yes I did; if you look around in the issue you can see a SO question where I say it deadlocks.
Perhaps you could help me fix it?
https://stackoverflow.com/questions/69730835/how-does-one-create-a-distributed-data-loader-with-pytorchs-torchmeta-for-meta
Yeah, I can check it out and see how it goes. Are there some things you have already tried?
Btw, not sure if you are using higher for MAML, but I was told that DDP wouldn't work with it... that's why I stopped (especially after the deadlock). Do you know anything about this? Is this an issue for you?
@brando90 this is something I am interested in as well. I got DP to work but I would prefer using DDP for the speed, were you able to test out your code?
what is DP?
here is the SO question: https://stackoverflow.com/questions/69730835/how-does-one-create-a-distributed-data-loader-with-pytorchs-torchmeta-for-meta
That is DataParallel, the old way.
sounds good.
I will try to put another 500-point bounty on the SO question to see if it helps. I don't have the cycles to help more, but I hope it helps.