syncbn with "channel_last=True" produces wrong results when feature_num is not a power of two

Open Zehaos opened this issue 1 year ago • 1 comment

Describe the Bug

When feature_num is not a power of two, apex.parallel.SyncBatchNorm with channel_last=True produces wrong results.

Testing it step by step, I found that it produces a wrong mean and var when feature_h and feature_w are large enough (see the minimal reproduction code below).

Minimal Steps/Code to Reproduce the Bug

When feature_h and feature_w are small, it produces the correct result.

import torch
import syncbn  # CUDA extension built with apex

feature_size = 65  # not a power of two
feature_h = 10
feature_w = 10
# With small feature_h and feature_w, the mean and var are correct.
# The syncbn kernels are CUDA-only, so the input must live on the GPU.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # NHWC layout
var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # True

When feature_h and feature_w are large, it produces a wrong result.

import torch
import syncbn  # CUDA extension built with apex

feature_size = 65  # not a power of two
feature_h = 100
feature_w = 100
# With large feature_h and feature_w, the mean and var are wrong.
# The syncbn kernels are CUDA-only, so the input must live on the GPU.
input = torch.rand(1, feature_size, feature_h, feature_w, device="cuda")
input_clast = input.permute([0, 2, 3, 1]).contiguous()  # NHWC layout
var, mean = torch.var_mean(input_clast, dim=[0, 1, 2], unbiased=False)
mean_apex, var_apex = syncbn.welford_mean_var_c_last(input_clast)
print(torch.allclose(mean, mean_apex))  # False
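
For anyone cross-checking the numbers without a GPU, here is a rough, dependency-free sketch of the reduction that a Welford-style channels-last mean/var is expected to compute: per-chunk summaries merged with the standard parallel update. This is not apex's CUDA kernel and makes no claim about where the bug is; the names `welford_merge`, `chunk_summary`, `channel_last_mean_var` and the `chunk_rows` parameter are hypothetical and exist only for illustration.

```python
import math
import random


def welford_merge(a, b):
    """Merge two (count, mean, m2) summaries with the parallel Welford update."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return (0, 0.0, 0.0)
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return (n, mean, m2)


def chunk_summary(values):
    """Return (count, mean, sum of squared deviations) for one chunk."""
    n = len(values)
    if n == 0:
        return (0, 0.0, 0.0)
    mean = sum(values) / n
    m2 = sum((v - mean) ** 2 for v in values)
    return (n, mean, m2)


def channel_last_mean_var(flat, channels, chunk_rows):
    """Per-channel (mean, biased var) of row-major (rows, channels) data.

    Each channel is reduced in chunks of `chunk_rows` rows, loosely mimicking
    how partial results from separate workers would be merged.
    """
    rows = len(flat) // channels
    results = []
    for c in range(channels):
        column = flat[c::channels]  # channel c across all rows (stride = channels)
        acc = (0, 0.0, 0.0)
        for start in range(0, rows, chunk_rows):
            acc = welford_merge(acc, chunk_summary(column[start:start + chunk_rows]))
        n, mean, m2 = acc
        results.append((mean, m2 / n))  # biased variance, like unbiased=False
    return results


# Usage: 65 channels (not a power of two), chunk size that does not divide rows.
random.seed(0)
demo_channels, demo_rows = 65, 40
demo = [random.random() for _ in range(demo_rows * demo_channels)]
stats = channel_last_mean_var(demo, demo_channels, chunk_rows=7)
direct_mean0 = sum(demo[0::demo_channels]) / demo_rows
print(math.isclose(stats[0][0], direct_mean0, rel_tol=1e-9))
```

The merged result should not depend on the channel count or on how rows are split into chunks, which is the property the failing case above appears to violate.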

Expected Behavior

mean_apex and var_apex should match the output of torch.var_mean for any feature_size, feature_h, and feature_w.

Environment

ngc_23.11

Zehaos, Jan 14 '24 14:01

cc @jjsjann123

Zehaos, Jan 14 '24 14:01