Sparse_SwitchNorm
Cannot apply sparse_switchnorm to 2D input?
When I try to apply sparse_switchnorm to a 2D tensor, it fails at self.var_weight and runs into the same problem as issue #2. I modified the code as follows:
class SSN(nn.Module):
    """Sparse Switchable Normalization for 2D ``(N, C)`` inputs.

    Normalizes ``x`` with a sparsestmax-weighted mixture of LayerNorm
    statistics (per-sample, over features) and BatchNorm statistics
    (per-feature, over the batch), then applies a per-channel affine
    transform.

    Args:
        num_features: number of channels ``C`` of the expected input.
        eps: small constant added to the variance for numerical stability.
        momentum: exponential-moving-average factor for running stats.
        using_moving_average: if True, running stats are EMAs; otherwise
            they are plain accumulations (see NOTE in ``forward``).
        last_gamma: if True, zero-initialize the affine scale (useful when
            this layer ends a residual branch).
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.997,
                 using_moving_average=True, last_gamma=False):
        # The pasted code read `def init` / `super(SSN1d, self).init()`:
        # markdown stripped the dunder underscores, and `SSN1d` was a
        # leftover class name that raises NameError. Both are fixed here.
        super(SSN, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.last_gamma = last_gamma
        # Per-channel affine parameters, shaped (1, C) to broadcast over N.
        self.weight = nn.Parameter(torch.ones(1, num_features))
        self.bias = nn.Parameter(torch.zeros(1, num_features))
        # Importance weights over the two statistics, ordered [LN, BN].
        self.mean_weight = nn.Parameter(torch.ones(2))
        self.var_weight = nn.Parameter(torch.ones(2))
        self.register_buffer('running_mean', torch.zeros(1, num_features))
        self.register_buffer('running_var', torch.zeros(1, num_features))
        # Latches set once the sparsestmax weights have become one-hot.
        self.register_buffer('mean_fixed', torch.LongTensor([0]))
        self.register_buffer('var_fixed', torch.LongTensor([0]))
        # Sparsity radius for sparsestmax; updated externally via set_rad().
        self.register_buffer('radius', torch.zeros(1))
        # Device-agnostic placeholders (the original hard-coded
        # torch.cuda.FloatTensor, which breaks CPU-only runs); these are
        # overwritten on every forward() call.
        self.mean_weight_ = torch.ones(2)
        self.var_weight_ = torch.ones(2)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics, one-hot latches, and affine params."""
        self.running_mean.zero_()
        self.running_var.zero_()
        self.mean_fixed.fill_(0)
        self.var_fixed.fill_(0)
        # Honor last_gamma: the pasted code accepted the flag but ignored it.
        if self.last_gamma:
            self.weight.data.fill_(0)
        else:
            self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        # This layer only supports flat (N, C) inputs.
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        # LayerNorm statistics: one mean/var per sample, over features.
        mean_ln = x.mean(1, keepdim=True)
        var_ln = x.var(1, keepdim=True)
        if self.training:
            # BatchNorm statistics: one mean/var per feature, over the batch.
            # NOTE(review): x.var(0) is the *unbiased* estimator and yields
            # NaN for batch size 1; that would propagate NaN gradients into
            # mean_weight/var_weight — a likely cause of the NaN reported
            # below. Confirm the training batch size is > 1.
            mean_bn = x.mean(0, keepdim=True)
            var_bn = x.var(0, keepdim=True)
            if self.using_moving_average:
                self.running_mean.mul_(self.momentum)
                self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                self.running_var.mul_(self.momentum)
                self.running_var.add_((1 - self.momentum) * var_bn.data)
            else:
                # Plain accumulation of E[x] and E[x^2]-style terms;
                # presumably normalized by the iteration count elsewhere
                # before evaluation — verify against the caller.
                self.running_mean.add_(mean_bn.data)
                self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
        else:
            # Eval mode: use the stored running statistics.
            # (torch.autograd.Variable is a deprecated no-op wrapper; the
            # tensors can be used directly.)
            mean_bn = self.running_mean
            var_bn = self.running_var
        rad = self.radius.item()
        if not self.mean_fixed:
            self.mean_weight_ = sparsestmax(self.mean_weight, rad)
            # Once the weights become one-hot (gap >= 1), freeze them.
            if max(self.mean_weight_) - min(self.mean_weight_) >= 1:
                self.mean_fixed.fill_(1)
                self.mean_weight.data = self.mean_weight_.data
                self.mean_weight_ = self.mean_weight.detach()
        else:
            self.mean_weight_ = self.mean_weight.detach()
        if not self.var_fixed:
            # (The stray markdown '**' bold markers around this statement in
            # the paste were formatting artifacts, not code.)
            self.var_weight_ = sparsestmax(self.var_weight, rad)
            if max(self.var_weight_) - min(self.var_weight_) >= 1:
                self.var_fixed.fill_(1)
                self.var_weight.data = self.var_weight_.data
                self.var_weight_ = self.var_weight.detach()
        else:
            self.var_weight_ = self.var_weight.detach()
        # Blend LN/BN statistics and normalize.
        mean = self.mean_weight_[0] * mean_ln + self.mean_weight_[1] * mean_bn
        var = self.var_weight_[0] * var_ln + self.var_weight_[1] * var_bn
        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias

    def get_mean(self):
        """Return the current (possibly sparsified) mean mixing weights."""
        return self.mean_weight_

    def get_var(self):
        """Return the current (possibly sparsified) variance mixing weights."""
        return self.var_weight_

    def set_rad(self, rad):
        """Set the sparsestmax radius (stored in a buffer for state_dict)."""
        self.radius.fill_(rad)

    def get_rad(self):
        """Return the sparsestmax radius as a 0-dim tensor."""
        return torch.squeeze(self.radius)
It seems that the value of self.var_weight in `self.var_weight_ = sparsestmax(self.var_weight, rad)` becomes NaN. Is this caused by an error in my modified code, or can sparse_switchnorm not be applied to 2D input?