megcup-feedforward
megcup-feedforward copied to clipboard
Code for Training
你好,想问下是否有模型训练代码呢?主要想参考下Loss设计部分,谢谢!
由于近期我们组员都有些其他工作需要处理,所以训练代码最近可能不能公开。 不过Loss部分您可以先参考下这个:
from abc import abstractmethod
from megengine import module as M
import megengine.functional as F
class BaseLoss(M.Module):
def __init__(self, name) -> None:
super().__init__()
self.name = name
self.losses = []
def __len__(self):
return len(self.losses)
@abstractmethod
def forward(self, x, y):
pass
class DiffLoss(BaseLoss):
def __init__(self, loss_weight):
super().__init__('diff loss')
self.loss_weight = loss_weight
def forward(self, x, y):
means = y.mean(axis=(-1, -2, -3))
weight = (1 / means) ** 0.5
diff = F.abs(x - y).mean(axis=(-1, -2, -3))
diff = diff * weight.detach()
diff = diff.mean()
return diff * self.loss_weight