MaxViT Model - BatchNorm momentum is incorrect
🐛 Describe the bug
The current BatchNorm momentum is set to 0.99 here; as noted, this value was taken from the original implementation here.
However, due to the difference between the PyTorch and TensorFlow implementations of BatchNorm, the momentum in the TorchVision implementation should be 1 - momentum.
This is done (correctly, to my understanding) in the MnasNet implementation.
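For illustration, a minimal sketch of the conversion (the eps value below is an assumption based on what I believe the MaxViT default uses, and is not essential to the point):

```python
import torch.nn as nn

# TensorFlow/Keras convention: new_running = momentum * running + (1 - momentum) * batch
# PyTorch convention:          new_running = (1 - momentum) * running + momentum * batch
# So a TF momentum of 0.99 corresponds to a PyTorch momentum of 1 - 0.99 = 0.01.
tf_momentum = 0.99
torch_momentum = 1.0 - tf_momentum  # 0.01

# Same idea as MnasNet, which (if I recall correctly) uses _BN_MOMENTUM = 1 - 0.9997.
bn = nn.BatchNorm2d(64, eps=1e-3, momentum=torch_momentum)
```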
Versions
Collecting environment information...
PyTorch version: 2.1.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.31

Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-6.1.0-0.deb11.13-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A5000
GPU 1: NVIDIA RTX A5000

Nvidia driver version: 545.23.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] flake8-pep585==0.1.7
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] onnx==1.15.0
[pip3] torch==2.1.2+cu118
[pip3] torch-model-archiver==0.9.0
[pip3] torch-workflow-archiver==0.2.11
[pip3] torchaudio==2.1.2+cu118
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.0.post0
[pip3] torchserve==0.9.0
[pip3] torchvision==0.16.2+cu118
[pip3] triton==2.1.0
[conda] Could not collect
Thank you for the report @hassonofer.
Just checking with @TeodorPoncu before moving forward: Teodor, do you remember discussing this during reviews?
Hey @NicolasHug! I personally was not aware of that difference in parameter specification between PyTorch and TensorFlow.
I do not recall that coming up during reviews (I've double checked with the original PR).
I assume that detail might've flown under the radar because we obtained results comparable to the tiny variant from the paper (83.7 for the torchvision weights vs. 83.62 for the paper).
Thank you for your quick reply @TeodorPoncu !
Since the momentum parameter only affects the training of the model, and not inference (right?), we can probably fix the default from 0.99 to 0.01, and that would still keep the pre-trained weights' performance the same (i.e. it would still be 83.7). WDYT?
@NicolasHug, yes. At inference time the momentum parameter has no effect on batch-norm, since it uses the running means and variances, so the evaluation performance will be exactly the same. The momentum parameter only affects how these statistics are estimated during training.
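A quick, self-contained sanity check of that point (plain nn.BatchNorm2d layers with dummy shapes, not the actual MaxViT ones):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 16, 4, 4)

# Two BatchNorm layers that differ only in momentum.
bn_a = nn.BatchNorm2d(16, momentum=0.99)
bn_b = nn.BatchNorm2d(16, momentum=0.01)
# Give them identical (pretend "pretrained") affine params and running statistics.
bn_b.load_state_dict(bn_a.state_dict())

bn_a.eval()
bn_b.eval()
with torch.no_grad():
    assert torch.allclose(bn_a(x), bn_b(x))  # momentum plays no role in eval mode
```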
Momentum was introduced in this paper to counteract small batch sizes relative to the dataset size. The reason behind this is that the default way of computing the running mean and variance for the Batch Norm layer is via a non-stationary momentum of 1 / gamma (torch reference implementation here), where gamma is incremented by 1 at every forward pass during training.
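A small sketch of that cumulative behaviour, using a plain nn.BatchNorm2d with momentum=None rather than the MaxViT blocks:

```python
import torch
import torch.nn as nn

# With momentum=None, the running statistics are a cumulative moving average:
# the effective update factor is 1 / num_batches_tracked, which shrinks every step,
# so later batches contribute less and less.
bn = nn.BatchNorm2d(3, momentum=None)
bn.train()

for step in range(1, 4):
    bn(torch.randn(32, 3, 8, 8))
    print(step, 1.0 / bn.num_batches_tracked.item())  # 1.0, 0.5, 0.333...
```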
As such, the longer a training run goes on, the less a given batch contributes to the statistics update when no momentum value is set. Depending on what the underlying torch implementation does (I could not find whether it does something like running_mean = momentum * running_mean + (1 - momentum) * batch_mean, as the reference implementation falls back to an F binding), changing the momentum to something like 0.01 might affect users that are performing subsequent fine-tuning runs.
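One way to check the convention empirically, independent of MaxViT (a minimal sketch):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.1
bn = nn.BatchNorm2d(4, momentum=momentum)
old_running_mean = bn.running_mean.clone()  # zeros at initialization

x = torch.randn(64, 4, 8, 8)
bn.train()
bn(x)

batch_mean = x.mean(dim=(0, 2, 3))
expected = (1 - momentum) * old_running_mean + momentum * batch_mean
# True -> torch uses running_mean = (1 - momentum) * running_mean + momentum * batch_mean
print(torch.allclose(bn.running_mean, expected, atol=1e-6))
```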
For instance, if the default torch implementation does the above (which is the same way the paper describes it in eq. 8 and algorithm 1), users might notice unfavourable results when performing fine-tuning, depending on how they configure the momentum parameter.
For a small dataset, if they do not change the momentum inside the batch-norm layers, then the learned statistics for the ImageNet solution space will be immediately washed away (since we would be assigning a weight of 0.01 to them), thus missing out on the benefits of transfer learning.
I would recommend changing the default value to 0.01 if and only if the actual torch implementation does running_mean = (1 - momentum) * running_mean + momentum * batch_mean, given the side effects it can lead to in fine-tuning scenarios.
Thanks @TeodorPoncu - yeah, as far as I can tell from the Note in https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html, the formula is as you described.
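For concreteness, a tiny sketch of what that formula implies for the fine-tuning scenario discussed above (illustrative numbers only):

```python
import torch
import torch.nn as nn

def running_mean_after_one_step(momentum: float) -> torch.Tensor:
    bn = nn.BatchNorm2d(1, momentum=momentum)
    bn.running_mean.fill_(5.0)        # pretend these are "pretrained" ImageNet statistics
    bn.train()
    bn(torch.zeros(16, 1, 4, 4))      # one fine-tuning batch whose mean is 0
    return bn.running_mean

print(running_mean_after_one_step(0.99))  # ~0.05: pretrained stats nearly gone after one batch
print(running_mean_after_one_step(0.01))  # ~4.95: pretrained stats largely preserved
```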
@NicolasHug, in that case yes, the appropriate default momentum value should be set to 0.01, and it shouldn't have any side effects (inference- or fine-tuning-wise).