DORN_pytorch
BerHu loss implementation
Hi, thank you for your contribution. I am not sure the BerHu loss you implemented is correct. I would suggest something like this:
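```python
import torch
import torch.nn as nn


class berHuLoss(nn.Module):
    def __init__(self):
        super(berHuLoss, self).__init__()

    def forward(self, pred, target):
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        # Only pixels with valid (positive) ground-truth depth contribute
        valid_mask = (target > 0).detach()
        x_abs = (pred - target).abs()
        x_abs = x_abs[valid_mask]
        # Threshold c is 20% of the largest absolute residual in the batch
        c = 0.2 * torch.max(x_abs)
        # L1 for residuals <= c, scaled L2 (x^2 + c^2) / (2c) above c
        loss = torch.where(x_abs > c, (x_abs ** 2 + c ** 2) / (2 * c), x_abs).mean()
        return loss
```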
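To sanity-check the piecewise definition without PyTorch, here is a minimal plain-Python sketch of the same computation. The names `berhu_penalty` and `berhu_loss` and the `c_ratio` parameter are hypothetical and exist only for this illustration. For a residual magnitude |x| and threshold c, the penalty is |x| when |x| <= c and (x^2 + c^2) / (2c) otherwise. At |x| = c both branches equal c, so the loss is continuous there.

```python
def berhu_penalty(a, c):
    """berHu (reverse Huber) penalty for one absolute residual a >= 0."""
    # L1 below the threshold, scaled L2 above it; both equal c at a == c
    return (a * a + c * c) / (2 * c) if a > c else a


def berhu_loss(pred, target, c_ratio=0.2):
    """Mean berHu loss over entries whose target is positive (hypothetical helper)."""
    # Mirror the valid-depth mask: ignore entries with target <= 0
    abs_r = [abs(p - t) for p, t in zip(pred, target) if t > 0]
    if not abs_r:
        raise ValueError("no valid target entries")
    c = c_ratio * max(abs_r)
    if c == 0:
        # All valid residuals are zero; avoid dividing by zero
        return 0.0
    return sum(berhu_penalty(a, c) for a in abs_r) / len(abs_r)


# Example: the last entry is masked out (target == 0), residuals are [0, 1, 2], c = 0.4
print(berhu_loss([1.0, 2.0, 3.0, 0.0], [1.0, 1.0, 1.0, 0.0]))
```

The `c == 0` guard and the empty-mask check are additions in this sketch. In the PyTorch version, `torch.max` raises an error on an empty tensor when no pixel is valid. When every valid residual is zero, c is zero and the unselected `(x_abs ** 2 + c ** 2) / (2 * c)` branch divides by zero, which may produce NaN gradients.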
What do you think?