About loss = nan
Hi, I ran into the `loss = nan` problem too. Here is my solution. The loss function is `(sqrt(g_ground) - sqrt(g_hat))^2`. The derivative of `sqrt(g_hat)` is `1 / (2 * sqrt(g_hat))`, which is infinite at `g_hat = 0`, so the gradient becomes NaN when `g_hat` is 0. The same happens in the `(1 - r_hat)` term when `r_hat` is 1. The code below may fix the problem. It replaces the power function with a linear segment near 0:

```python
import torch
import torch.nn as nn


class CustomLoss(nn.Module):
    ...

    def forward(...):
        ...
        rb = targets[:, :, 34:68]

        # try to avoid nan: below the threshold, use a linear segment
        # instead of x**gamma, whose gradient is infinite at x = 0
        mask = gb_hat < 0.0003
        gamma_gb_hat = torch.empty_like(gb_hat)
        gamma_gb_hat[mask] = 1290 * gb_hat[mask]
        mask = gb_hat >= 0.0003
        gamma_gb_hat[mask] = torch.pow(gb_hat[mask], gamma)

        # same treatment for (1 - rb_hat), which reaches 0 when rb_hat is 1
        mask = (1 - rb_hat) < 0.0003
        gamma_rb_hat = torch.empty_like(rb_hat)
        gamma_rb_hat[mask] = 1290 * (1 - rb_hat[mask])
        mask = (1 - rb_hat) >= 0.0003
        gamma_rb_hat[mask] = torch.pow(1 - rb_hat[mask], gamma)

        return torch.mean(torch.pow(torch.pow(gb, gamma) - gamma_gb_hat, 2)) \
            + C4 * torch.mean(torch.pow(torch.pow(gb, gamma) - gamma_gb_hat, 4)) \
            + torch.mean(torch.pow(torch.pow(1 - rb, gamma) - gamma_rb_hat, 2))
```
In my case, with this change the NaN loss no longer appeared on the DNS-Challenge dataset. If you have the NaN loss problem, you could try this solution.
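To see where the NaN comes from, here is a small self-contained sketch. It is not the code from the comment. It computes the gradient at `g_hat = 0`, with the target also 0, once with a plain `x**0.5` and once with a piecewise version. The upstream gradient there is 0, and backprop multiplies it by the infinite derivative of `x**0.5`, giving `0 * inf = NaN`. Assumptions: gamma is 0.5, taken from the square root in the loss formula. The helper names `safe_pow`, `naive_pow`, and `grad_at_zero` are hypothetical. There is also one deliberate deviation. With gamma = 0.5, `1290 * x` does not meet `x**0.5` at `x = 0.0003` (about 0.387 vs. about 0.0173). The sketch therefore sets the slope to `THRESH ** (GAMMA - 1)` (about 57.7) so that the two pieces join continuously.

```python
import math

import torch

# Assumption: gamma = 0.5, i.e. the square root in the loss formula.
GAMMA = 0.5
# Threshold taken from the comment's code.
THRESH = 0.0003
# Assumption (deviates from the comment's 1290): pick the slope so that
# SLOPE * THRESH == THRESH ** GAMMA, i.e. the two pieces meet at THRESH.
SLOPE = THRESH ** (GAMMA - 1.0)


def safe_pow(x: torch.Tensor, gamma: float = GAMMA) -> torch.Tensor:
    """Piecewise x**gamma: linear below THRESH, power law at or above it."""
    out = torch.empty_like(x)
    small = x < THRESH
    # Linear segment: its gradient is the finite constant SLOPE, even at x == 0.
    out[small] = SLOPE * x[small]
    # Power segment: pow only sees x >= THRESH, so its gradient stays finite.
    out[~small] = torch.pow(x[~small], gamma)
    return out


def naive_pow(x: torch.Tensor, gamma: float = GAMMA) -> torch.Tensor:
    """Plain x**gamma; its gradient gamma * x**(gamma - 1) is infinite at 0."""
    return torch.pow(x, gamma)


def grad_at_zero(power_fn) -> float:
    """d/d(pred) of mean((target**gamma - power_fn(pred))**2) at target = pred = 0."""
    target = torch.zeros(1)
    pred = torch.zeros(1, requires_grad=True)
    loss = torch.mean((torch.pow(target, GAMMA) - power_fn(pred)) ** 2)
    loss.backward()
    return pred.grad.item()


if __name__ == "__main__":
    print("naive grad is NaN:", math.isnan(grad_at_zero(naive_pow)))
    print("safe grad is finite:", math.isfinite(grad_at_zero(safe_pow)))
```

Filling the result through boolean-mask assignment means `torch.pow` never receives a value below the threshold. The masked-out positions therefore cannot leak an infinite derivative into the backward pass.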