
About multi-gpu training

Open d-li14 opened this issue 5 years ago • 5 comments

Thanks for your awesome work! Is there any plan to support multi-GPU training? As you know, training ResNet-101 on ImageNet with a single GPU is unacceptably slow.

d-li14 avatar Jun 21 '20 08:06 d-li14

Hi, thanks for having a look at the code. I did not test multi-GPU training, and RN101 indeed takes quite some time on a single GPU (~2 weeks). I have not put in the effort to implement multi-GPU support, since I needed the other available GPUs in our lab for other runs/experiments. I suspect some changes are needed in the loss. I was planning to look into it in the coming weeks anyway, so I'll let you know!

I also plan to release a trained mobilenetv2 with the optimized CUDA code integrated.

thomasverelst avatar Jun 21 '20 20:06 thomasverelst

Hi @thomasverelst, thanks for your prompt reply and for sharing! I understand your concern about computational resources, but two weeks is still a fairly long experimental period :).

Furthermore, I have made some attempts at multi-GPU training by simply wrapping the model with torch.nn.DataParallel, but got stuck on a couple of issues (see the sketch after this list):

  • gathering the output dict meta across GPUs (I may have solved this already)
  • the weights of self-constructed tensors here probably cannot be replicated from GPU 0 to the other GPUs
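Roughly, the attempt looks like the sketch below; the toy model, its forward signature, and the contents of meta are placeholders for illustration, not the actual dynconv code:

```python
import torch
import torch.nn as nn

# Toy stand-in for a dynconv-style network that returns (output, meta dict);
# the real forward signature and meta contents may differ.
class ToyDynModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x, meta):
        # a dynconv-style model appends per-layer masks to meta
        meta['masks'] = meta.get('masks', []) + [torch.ones_like(x[:, :1])]
        return self.fc(x), meta

model = nn.DataParallel(ToyDynModel()).cuda()
x = torch.randn(16, 8).cuda()
out, meta = model(x, {'masks': []})
# Each replica builds its own meta. DataParallel's default gather concatenates
# tensor entries along the batch dimension, but entries that are not plain
# tensors, and tensors constructed inside the model itself, may not be
# gathered/replicated the way the single-GPU code expects -- hence the two
# issues above.
```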

Looking forward to your good news! Also congratulations on the upcoming MobileNetV2 CUDA code!

d-li14 avatar Jun 22 '20 01:06 d-li14

I've pushed a new branch multigpu. I haven't tested training accuracy yet, but it runs. The only problems I had were with gathering the output dict meta. I considered subclassing DataParallel to support meta, but decided to just change the internal handling of meta so PyTorch wouldn't complain. Note that the pretrained checkpoints differ from the ones on the master branch (URL in the README).
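For reference, the subclassing option would look roughly like the sketch below. This is only an illustration of the idea, not what the multigpu branch ended up doing, and the meta layout ({'masks': [...]}) is an assumption:

```python
import torch
from torch.nn.parallel import DataParallel
from torch.nn.parallel.scatter_gather import gather

# Illustration only: a DataParallel subclass that merges the per-replica
# (output, meta) pairs by hand instead of relying on the default gather.
class MetaDataParallel(DataParallel):
    def gather(self, outputs, output_device):
        # gather the regular output tensors as usual
        outs = gather([o[0] for o in outputs], output_device, dim=self.dim)
        # merge the meta dicts explicitly; here the masks are simply collected
        # into one list and moved to the output device (the single-GPU code
        # may instead expect per-layer concatenation)
        masks = [m.to(output_device) for o in outputs for m in o[1]['masks']]
        return outs, {'masks': masks}
```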

thomasverelst avatar Jun 22 '20 17:06 thomasverelst

Yeah, it seems to work now. I have successfully run this branch with ResNet-32 on CIFAR for fast prototyping (with matched accuracy and reduced FLOPs). As an additional note, the problem of the FLOPs counter reporting zero can be fixed by changing the following line, https://github.com/thomasverelst/dynconv/blob/multigpu/classification/main_cifar.py#L204, from `model = flopscounter.add_flops_counting_methods(model)` to `model = flopscounter.add_flops_counting_methods(model.module)`, because of the DataParallel wrapping.
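For context, the relevant lines end up looking roughly like this (simplified; only the wrapping and the counter call matter here):

```python
# simplified from main_cifar.py on the multigpu branch
model = torch.nn.DataParallel(model).cuda()

# before: the FLOPs counts stay at zero when the methods are added to the wrapper
# model = flopscounter.add_flops_counting_methods(model)

# after: add them to the wrapped network (model.module) instead
model = flopscounter.add_flops_counting_methods(model.module)
```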

d-li14 avatar Jun 23 '20 02:06 d-li14

Thanks a lot, that fixed it.

thomasverelst avatar Jun 23 '20 08:06 thomasverelst