apex
syncbn with "channel_last=True" produces wrong results when feature_num is not a power of two
Describe the Bug
When feature_num (the number of channels) is not a power of two, apex.parallel.SyncBatchNorm with channel_last=True produces wrong results.
I tested it step by step and found that syncbn.welford_mean_var_c_last returns the wrong mean and variance once feature_h and feature_w are large enough (see the minimal reproduction below).
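For reference, Welford's algorithm (which the kernel name welford_mean_var_c_last suggests is used here) keeps a running count, mean and sum of squared deviations (M2) per channel, and partial results from separate blocks are combined with the pairwise merge of Chan et al. The pure-Python sketch below is a hypothetical ground truth, not apex's implementation: welford_update, welford_merge and welford_mean_var_c_last_ref are illustrative names, and the block_rows split only mimics the idea of several blocks reducing parts of the N*H*W rows. A correct reduction returns the same mean and biased variance for any channel count and any block size.

```python
import random


def welford_update(count, mean, m2, x):
    """Fold one sample x into a running (count, mean, M2) triple."""
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)
    return count, mean, m2


def welford_merge(a, b):
    """Combine two partial (count, mean, M2) triples (pairwise merge of Chan et al.)."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2


def welford_mean_var_c_last_ref(rows, num_channels, block_rows):
    """Hypothetical reference: per-channel mean and biased variance of NHWC data.

    rows: N*H*W rows, each a list of num_channels values (the flattened NHWC tensor).
    block_rows: rows reduced per hypothetical "block" before the partials are merged.
    """
    partials = []
    for start in range(0, len(rows), block_rows):
        acc = [(0, 0.0, 0.0)] * num_channels
        for row in rows[start:start + block_rows]:
            acc = [welford_update(*acc[c], row[c]) for c in range(num_channels)]
        partials.append(acc)
    # Merge the per-block partial results channel by channel.
    total = [(0, 0.0, 0.0)] * num_channels
    for acc in partials:
        total = [welford_merge(total[c], acc[c]) for c in range(num_channels)]
    means = [t[1] for t in total]
    variances = [t[2] / t[0] for t in total]  # biased, like unbiased=False
    return means, variances


if __name__ == "__main__":
    random.seed(0)
    C = 65  # not a power of two, as in the report
    rows = [[random.random() for _ in range(C)] for _ in range(10 * 10)]
    means, variances = welford_mean_var_c_last_ref(rows, C, block_rows=7)
    print(means[:3], variances[:3])
```

Because the merge is exact up to rounding, varying C or block_rows should not change the result beyond floating-point noise; a result that does change points at how the work is split or indexed rather than at the formulas.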
Minimal Steps/Code to Reproduce the Bug
When feature_h and feature_w are small, the result is correct.
import torch
import syncbn  # apex's CUDA extension (built with --cuda_ext)

feature_size = 65  # not a power of two
feature_h = 10
feature_w = 10
# With small feature_h and feature_w, the mean and variance are correct.
# The syncbn kernels run on the GPU, so the input is created on a CUDA device.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # NCHW -> NHWC
var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # True
print(torch.allclose(var, var_apex))
When feature_h and feature_w are large, the result is wrong.
import torch
import syncbn  # apex's CUDA extension (built with --cuda_ext)

feature_size = 65  # not a power of two
feature_h = 100
feature_w = 100
# With large feature_h and feature_w, the mean and variance are wrong.
# The syncbn kernels run on the GPU, so the input is created on a CUDA device.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # NCHW -> NHWC
var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # False
print(torch.allclose(var, var_apex))
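To map out which shapes are affected, a small sweep along these lines may help. It is a sketch under stated assumptions: check_c_last_mean_var and torch_reference are hypothetical helpers, the channel and spatial grids are arbitrary picks that mix powers of two with non-powers of two, and the tolerances are guesses. The helper takes any function that maps an NHWC tensor to (mean, var), compares it with torch.var_mean(..., unbiased=False), and returns the failing (C, H, W) shapes.

```python
import itertools

import torch


def check_c_last_mean_var(mean_var_fn, channels=(64, 65, 128, 129), spatial=(10, 100),
                          device="cpu", rtol=1e-5, atol=1e-5):
    """Hypothetical helper: return the (C, H, W) shapes where mean_var_fn disagrees
    with torch.var_mean on a channel-last (NHWC) input."""
    failures = []
    for c, hw in itertools.product(channels, spatial):
        # Build NCHW data, then make it contiguous in NHWC order as in the report.
        x = torch.rand(1, c, hw, hw, device=device).permute(0, 2, 3, 1).contiguous()
        var_ref, mean_ref = torch.var_mean(x, dim=[0, 1, 2], unbiased=False)
        mean, var = mean_var_fn(x)
        ok = (torch.allclose(mean, mean_ref, rtol=rtol, atol=atol)
              and torch.allclose(var, var_ref, rtol=rtol, atol=atol))
        if not ok:
            failures.append((c, hw, hw))
    return failures


def torch_reference(x):
    """Pure-PyTorch stand-in returning (mean, var) in the same order as the apex call."""
    var, mean = torch.var_mean(x, dim=[0, 1, 2], unbiased=False)
    return mean, var


if __name__ == "__main__":
    print(check_c_last_mean_var(torch_reference))
    # With apex installed and a GPU available, the kernel from the report could be
    # checked the same way:
    #   import syncbn
    #   print(check_c_last_mean_var(syncbn.welford_mean_var_c_last, device="cuda"))
```

Comparing both the mean and the variance, rather than only the mean, makes it easier to tell whether the variance is wrong independently of the mean.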
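Expected Behavior
syncbn.welford_mean_var_c_last should return the same per-channel mean and biased variance as torch.var_mean(..., unbiased=False), regardless of the channel count or the spatial size.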
Environment
ngc_23.11
cc @jjsjann123