
MaxVIT Model - BatchNorm momentum is incorrect

Open hassonofer opened this issue 1 year ago • 6 comments

🐛 Describe the bug

The current BatchNorm momentum is set to 0.99 here; as noted, this value was taken from the original implementation here.

But due to the difference between the PyTorch and TensorFlow definitions of BatchNorm momentum, the value in the TorchVision implementation should be 1 - momentum, as is done (correctly, to my understanding) in the MnasNet implementation. See the sketch below for the two conventions.
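To make the convention difference concrete, here is a minimal illustrative sketch (not code from either library):

```python
# Both frameworks compute an exponential moving average of the batch statistics,
# but they weight the terms differently:
#   TensorFlow/Keras: running = momentum * running + (1 - momentum) * batch   (momentum = 0.99)
#   PyTorch:          running = (1 - momentum) * running + momentum * batch   (momentum = 0.01)
# So a TF momentum of 0.99 corresponds to a PyTorch momentum of 1 - 0.99 = 0.01.

def tf_style_update(running, batch_stat, momentum=0.99):
    return momentum * running + (1 - momentum) * batch_stat

def torch_style_update(running, batch_stat, momentum=0.01):
    return (1 - momentum) * running + momentum * batch_stat

assert tf_style_update(1.0, 2.0) == torch_style_update(1.0, 2.0)  # both give 1.01
```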

Versions

Collecting environment information...
PyTorch version: 2.1.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.31

Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-6.1.0-0.deb11.13-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A5000
GPU 1: NVIDIA RTX A5000

Nvidia driver version: 545.23.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] flake8-pep585==0.1.7
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] onnx==1.15.0
[pip3] torch==2.1.2+cu118
[pip3] torch-model-archiver==0.9.0
[pip3] torch-workflow-archiver==0.2.11
[pip3] torchaudio==2.1.2+cu118
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.0.post0
[pip3] torchserve==0.9.0
[pip3] torchvision==0.16.2+cu118
[pip3] triton==2.1.0
[conda] Could not collect

hassonofer · Feb 02 '24 21:02

Thank you for the report @hassonofer.

Just checking with @TeodorPoncu before moving forward: Teodor, do you remember discussing this during reviews?

NicolasHug · Feb 05 '24 09:02

Hey @NicolasHug! I personally was not aware of that difference in parameter specification between PyTorch and TensorFlow.

I do not recall that coming up during reviews (I've double-checked against the original PR).

I assume that detail might've flown under the radar because we obtained results comparable to those reported for the tiny variant in the paper (83.7 for the torchvision weights vs. 83.62 in the paper).

TeodorPoncu · Feb 05 '24 10:02

Thank you for your quick reply @TeodorPoncu !

Since the momentum parameter only affects training of the model, and not inference (right?), we can probably fix the default from 0.99 to 0.01, and that would still keep the pre-trained weights' performance the same (i.e. it would still be 83.7). WDYT?

NicolasHug · Feb 05 '24 10:02

@NicolasHug, yes. At inference time the momentum parameter has no effect on batch norm: the layer uses the stored running means and variances, so evaluation performance will be exactly the same. The momentum parameter only affects how those statistics are estimated during training.
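A minimal sketch illustrating this (an editorial check, not from the original thread): two layers that share the same state but differ in momentum produce identical outputs in eval mode, because only the running statistics are used there.

```python
import torch
import torch.nn as nn

x = torch.randn(8, 3, 32, 32)

bn_a = nn.BatchNorm2d(3, momentum=0.99)
bn_b = nn.BatchNorm2d(3, momentum=0.01)
bn_b.load_state_dict(bn_a.state_dict())  # same running stats and affine params

bn_a.eval()
bn_b.eval()

# In eval mode only running_mean/running_var are used, so momentum is irrelevant.
assert torch.allclose(bn_a(x), bn_b(x))
```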

Momentum was introduced in this paper to counteract small batch sizes relative to the dataset size. The reason is that, by default, the running mean and variance of the BatchNorm layer are computed with a non-stationary momentum of 1 / gamma (torch reference implementation here), where gamma is incremented by 1 at every forward pass during training.
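A small sketch of that behaviour (assuming torch.nn.BatchNorm2d, where this cumulative average is what you get with momentum=None): num_batches_tracked plays the role of gamma, so the effective update factor shrinks as training goes on.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, momentum=None)  # momentum=None -> cumulative moving average
bn.train()

for _ in range(5):
    bn(torch.randn(8, 3, 16, 16))
    gamma = bn.num_batches_tracked.item()
    # The running stats are updated with factor 1 / gamma on this step.
    print(f"gamma={gamma}, effective update factor={1.0 / gamma}")
# gamma=1, effective update factor=1.0
# gamma=2, effective update factor=0.5
# ...
```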

As such, the longer the training run goes on, the less a given batch contributes to the statistics update when no momentum value is set. Depending on how the underlying torch implementation works (I could not find whether it does something like running_mean = momentum * running_mean + (1 - momentum) * batch_mean, as the reference implementation falls back to an F binding), changing the momentum to something like 0.01 might affect users who perform subsequent fine-tuning runs.

For instance, if the default torch implementation does the above (which is the same way the paper describes it in eq. 8 and algorithm 1), users might notice unfavourable results when fine-tuning, depending on how they configure the momentum parameter.

For a small dataset, if they do not change the momentum inside the batch-norm layers, the learned statistics for the ImageNet solution space will be immediately washed away (since we would be assigning a weight of 0.01 to them), thus missing out on the benefits of transfer learning.

I would recommend changing the default value to 0.01 if and only if the actual torch implementation does running_mean = (1 - momentum) * running_mean + momentum * batch_mean, given the side effects it can lead to in fine-tuning scenarios.

TeodorPoncu · Feb 05 '24 11:02

Thanks @TeodorPoncu - yeah, as far as I can tell from the Note in https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html, the formula is as you described.
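For completeness, here is a quick numerical check of that update rule (a minimal sketch, not from the docs, using BatchNorm1d with affine disabled just to isolate the running-mean update):

```python
import torch
import torch.nn as nn

momentum = 0.01
bn = nn.BatchNorm1d(1, momentum=momentum, affine=False)
bn.train()

old_running_mean = bn.running_mean.clone()  # initialised to 0.0
x = torch.full((4, 1), 10.0)                # batch mean is exactly 10.0
bn(x)

# PyTorch convention: running = (1 - momentum) * running + momentum * batch_mean
expected = (1 - momentum) * old_running_mean + momentum * x.mean()
assert torch.allclose(bn.running_mean, expected)  # 0.99 * 0.0 + 0.01 * 10.0 = 0.1
```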

NicolasHug · Feb 05 '24 14:02

@NicolasHug, in that case yes, the appropriate default momentum value should be set to 0.01, and it shouldn't have any side effects (inference- or fine-tuning-wise).

TeodorPoncu · Feb 05 '24 16:02