Sparse_SwitchNorm

Cannot apply sparse_switchnorm to 2D input?

Open Ning5195 opened this issue 4 years ago • 0 comments

When I try to apply sparse_switchnorm to a 2D tensor, it fails at `self.var_weight` and hits the same problem as #2. I modified the code as follows.

```python
import torch
import torch.nn as nn

# sparsestmax is assumed to be in scope (provided by this repo)


class SSN(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.997,
                 using_moving_average=True, last_gamma=False):
        super(SSN, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.weight = nn.Parameter(torch.ones(1, num_features))
        self.bias = nn.Parameter(torch.zeros(1, num_features))

        self.mean_weight = nn.Parameter(torch.ones(2))
        self.var_weight = nn.Parameter(torch.ones(2))
        self.register_buffer('running_mean', torch.zeros(1, num_features))
        self.register_buffer('running_var', torch.zeros(1, num_features))

        # self.rad = 0.
        self.register_buffer('mean_fixed', torch.LongTensor([0]))
        self.register_buffer('var_fixed', torch.LongTensor([0]))
        self.register_buffer('radius', torch.zeros(1))

        self.mean_weight_ = torch.cuda.FloatTensor([1., 1.])
        self.var_weight_ = torch.cuda.FloatTensor([1., 1.])

        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.zero_()
        self.weight.data.fill_(1)
        self.mean_fixed.data.fill_(0)
        self.var_fixed.data.fill_(0)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)

        # layer-norm statistics: per sample, over features
        mean_ln = x.mean(1, keepdim=True)
        var_ln = x.var(1, keepdim=True)

        if self.training:
            # batch-norm statistics: per feature, over the batch
            mean_bn = x.mean(0, keepdim=True)
            var_bn = x.var(0, keepdim=True)
            if self.using_moving_average:
                self.running_mean.mul_(self.momentum)
                self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                self.running_var.mul_(self.momentum)
                self.running_var.add_((1 - self.momentum) * var_bn.data)
            else:
                self.running_mean.add_(mean_bn.data)
                self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
        else:
            mean_bn = torch.autograd.Variable(self.running_mean)
            var_bn = torch.autograd.Variable(self.running_var)

        rad = self.radius.item()
        if not self.mean_fixed:
            self.mean_weight_ = sparsestmax(self.mean_weight, rad)
            if max(self.mean_weight_) - min(self.mean_weight_) >= 1:
                self.mean_fixed.data.fill_(1)
                self.mean_weight.data = self.mean_weight_.data
                self.mean_weight_ = self.mean_weight.detach()
        else:
            self.mean_weight_ = self.mean_weight.detach()

        if not self.var_fixed:
            self.var_weight_ = sparsestmax(self.var_weight, rad)  # <-- self.var_weight is nan here
            if max(self.var_weight_) - min(self.var_weight_) >= 1:
                self.var_fixed.data.fill_(1)
                self.var_weight.data = self.var_weight_.data
                self.var_weight_ = self.var_weight.detach()
        else:
            self.var_weight_ = self.var_weight.detach()

        mean = self.mean_weight_[0] * mean_ln + self.mean_weight_[1] * mean_bn
        var = self.var_weight_[0] * var_ln + self.var_weight_[1] * var_bn

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias

    def get_mean(self):
        return self.mean_weight_

    def get_var(self):
        return self.var_weight_

    def set_rad(self, rad):
        self.radius[0].fill_(rad)
        # self.rad = torch.squeeze(self.radius)

    def get_rad(self):
        return torch.squeeze(self.radius)
```
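For completeness, this is roughly how I drive the module (a sketch, assuming `sparsestmax` is importable from this repo and a CUDA device is available, since `__init__` above uses `torch.cuda.FloatTensor`):

```python
import torch

ssn = SSN(num_features=64).cuda()
ssn.train()
opt = torch.optim.SGD(ssn.parameters(), lr=0.1)
x = torch.randn(8, 64, device='cuda')

for step in range(100):
    ssn.set_rad(step * 0.01)        # gradually grow the sparsestmax radius
    opt.zero_grad()
    loss = ssn(x).pow(2).mean()
    loss.backward()
    opt.step()
    if torch.isnan(ssn.var_weight).any():
        print('var_weight became nan at step', step)
        break
```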

It seems that the value of `self.var_weight` in `self.var_weight_ = sparsestmax(self.var_weight, rad)` is nan. Is this caused by an error in my modified code, or is it that sparse_switchnorm cannot be applied to 2D input?
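One possible cause I am not sure about: `torch.Tensor.var` is unbiased by default, so the batch-norm branch in `forward` returns nan whenever the batch dimension has size 1, and that nan would then flow through the loss into the gradient of `self.var_weight`:

```python
import torch

x = torch.randn(1, 64)  # batch of a single sample
print(x.var(0, keepdim=True))                   # all nan: unbiased variance divides by N - 1 = 0
print(x.var(0, keepdim=True, unbiased=False))   # all zeros: well defined
```

If that is the cause, passing `unbiased=False` (or guarding against batches of size 1) might avoid it, but I have not verified this.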

Ning5195 · Oct 24 '20 03:10