DeepInversion

Calculate batch norm statistic loss on parallel training

Open · dohe0342 opened this issue 3 years ago · 3 comments

Hello, I have a question about the batch norm statistic loss.

Consider parallel training: I have 8 GPUs, and each GPU can fit a batch size of 128.

However, the batch norm statistic loss is computed separately on each GPU from its local mini-batch of 128, and the GPUs share only their gradients; the statistics are never computed over the whole batch of 1024. I think this can degrade image quality.
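For reference, the DeepInversion paper writes the batch norm statistic (feature-distribution) regularizer as a sum over BN layers l, with the expectations approximated by each layer's stored running mean and running variance. In the setup described above, GPU k evaluates the batch statistics over only its own mini-batch B_k; the activation notation a_l below is introduced here for illustration:

```latex
\mathcal{R}_{\text{feature}}(\hat{x}) = \sum_{l} \left\lVert \mu_l(\hat{x}) - \mathbb{E}\left[\mu_l(x) \mid \mathcal{X}\right] \right\rVert_2
  + \sum_{l} \left\lVert \sigma_l^2(\hat{x}) - \mathbb{E}\left[\sigma_l^2(x) \mid \mathcal{X}\right] \right\rVert_2

% per-GPU statistics of the input activations a_l of BN layer l on GPU k
\mu_l^{(k)} = \frac{1}{n_k} \sum_{i \in B_k} \sum_{h,w} a_{l}(\hat{x})_{i,:,h,w},
\qquad
\sigma_l^{2\,(k)} = \frac{1}{n_k} \sum_{i \in B_k} \sum_{h,w} \left(a_{l}(\hat{x})_{i,:,h,w} - \mu_l^{(k)}\right)^2,
\qquad n_k = |B_k| \, H_l W_l
```

Because the norms make the loss nonlinear in the statistics, averaging the per-GPU losses (or their gradients) is not the same as evaluating the loss on whole-batch statistics.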

So my question is: how can I compute the batch norm statistic loss during parallel training as if it were computed over the whole batch, rather than over each GPU's mini-batch?
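The whole-batch statistics can in principle be recovered from per-GPU partial sums, because the per-channel mean and (biased) variance depend only on the count, the sum, and the sum of squares. A minimal NumPy sketch, assuming small shapes as stand-ins for 8 GPUs × 128 images, a single layer, and running statistics of 0 and 1; `shard_partial_sums`, `global_stats`, and `bn_stat_loss` are hypothetical helpers, not functions from the DeepInversion repository:

```python
import numpy as np


def shard_partial_sums(feats):
    """Per-channel count, sum, and sum of squares for one GPU's activations.

    feats has shape (batch, channels, height, width), like a BatchNorm2d input.
    """
    n = feats.shape[0] * feats.shape[2] * feats.shape[3]
    return n, feats.sum(axis=(0, 2, 3)), (feats ** 2).sum(axis=(0, 2, 3))


def global_stats(partials):
    """Combine per-GPU partial sums into whole-batch mean and biased variance."""
    n = sum(p[0] for p in partials)
    s = sum(p[1] for p in partials)
    sq = sum(p[2] for p in partials)
    mean = s / n
    var = sq / n - mean ** 2  # E[x^2] - E[x]^2
    return mean, var


def bn_stat_loss(mean, var, running_mean, running_var):
    # DeepInversion-style term for one layer: L2 distance of the means plus of the variances.
    return np.linalg.norm(mean - running_mean) + np.linalg.norm(var - running_var)


rng = np.random.default_rng(0)
num_gpus, per_gpu_batch, channels = 8, 4, 3  # small stand-ins for 8 x 128
shards = [rng.normal(size=(per_gpu_batch, channels, 5, 5)) for _ in range(num_gpus)]
running_mean, running_var = np.zeros(channels), np.ones(channels)  # assumed BN buffers

# Whole-batch statistics built from three small per-channel arrays per GPU.
mean, var = global_stats([shard_partial_sums(x) for x in shards])
whole_batch_loss = bn_stat_loss(mean, var, running_mean, running_var)

# What each GPU computes on its own, averaged over GPUs.
per_gpu_losses = [
    bn_stat_loss(x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3)), running_mean, running_var)
    for x in shards
]
avg_per_gpu_loss = float(np.mean(per_gpu_losses))
```

Only the per-channel sums would need to be exchanged between GPUs, not the activations, and the averaged per-GPU loss generally differs from the whole-batch loss.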

dohe0342, May 06 '21

If you are using DistributedDataParallel, try converting the BatchNorm layers to SyncBatchNorm.
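For reference, the conversion suggested above is done with `torch.nn.SyncBatchNorm.convert_sync_batchnorm`. A minimal sketch, assuming a toy `nn.Sequential` as a hypothetical stand-in for the teacher network:

```python
import torch.nn as nn

# Hypothetical stand-in for the teacher network; any nn.Module with BatchNorm layers works.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
)

# Replace every BatchNorm*d layer with SyncBatchNorm; running stats and affine params are copied.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

num_sync = sum(isinstance(m, nn.SyncBatchNorm) for m in model.modules())

# In an actual DDP run, each process would then move the model to its GPU and wrap it, e.g.:
# model = nn.parallel.DistributedDataParallel(model.cuda(rank), device_ids=[rank])
```

Note that SyncBatchNorm only synchronizes batch statistics in training mode with an initialized process group; in eval mode it normalizes with the running statistics, like BatchNorm. If the statistic loss is computed from each layer's input inside a forward hook while the teacher is in eval mode, the conversion alone would not make that loss see the whole batch.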

hkunzhe, May 13 '21

I know about SyncBatchNorm, but DeepInversion needs to compute the loss with respect to every pixel of the synthesized images, and my GPU doesn't have enough memory for it.
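One possible workaround, sketched here under stated assumptions and not taken from the DeepInversion repository, is to compute the statistics inside a forward hook and all-reduce only the per-channel count, sum, and sum of squares, leaving the BN layers unchanged. `GlobalBNStatHook` is a hypothetical name. The distributed branch assumes `torch.distributed` is already initialized (for example by the DDP launcher) and uses the differentiable `torch.distributed.nn.functional.all_reduce`; treat it as untested on multiple GPUs. Without a process group it falls back to local statistics:

```python
import torch
import torch.distributed as dist
import torch.nn as nn


class GlobalBNStatHook:
    """Hypothetical forward hook: BN statistic loss on whole-batch statistics.

    Each rank exchanges only a count plus 2 * C per-channel numbers per layer,
    never the activations themselves.
    """

    def __init__(self, bn: nn.BatchNorm2d):
        self.handle = bn.register_forward_hook(self.hook)
        self.loss = None

    def hook(self, module, inputs, output):
        x = inputs[0]  # BN input, shape (N, C, H, W)
        count = torch.tensor(float(x.numel() // x.shape[1]), device=x.device)
        s = x.sum(dim=(0, 2, 3))
        sq = (x * x).sum(dim=(0, 2, 3))
        if dist.is_available() and dist.is_initialized():
            # Differentiable all-reduce, so gradients flow back to every rank's pixels.
            from torch.distributed.nn.functional import all_reduce

            s = all_reduce(s)
            sq = all_reduce(sq)
            dist.all_reduce(count)  # in-place sum; no gradient needed for the count
            # Each rank backprops the same loss and the all-reduce backward sums the
            # gradients, so pixel gradients should come out scaled by the world size.
        mean = s / count
        var = sq / count - mean * mean  # biased variance via E[x^2] - E[x]^2
        self.loss = torch.norm(mean - module.running_mean, 2) + torch.norm(
            var - module.running_var, 2
        )

    def remove(self):
        self.handle.remove()


# Single-process usage: no process group is initialized, so local statistics are used.
torch.manual_seed(0)
bn = nn.BatchNorm2d(4).eval()
stat_hook = GlobalBNStatHook(bn)
x = torch.randn(2, 4, 5, 5, requires_grad=True)
bn(x)
stat_hook.loss.backward()
stat_hook.remove()
```

This does not lower the memory needed to backpropagate to each GPU's own pixels; it only keeps the extra communication per layer small.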

dohe0342, May 28 '21

Hi @dohe0342, one option is to reduce the batch size to ease the load on the GPU. You can also try the 2k-iteration setting to save GPU resources. Alternatively, you can use the synthesized dataset we provide in the repository. Let me know if it helps.

hongxuyin, Jul 30 '22