megcup-feedforward icon indicating copy to clipboard operation
megcup-feedforward copied to clipboard

Code for Training

Open Guaishou74851 opened this issue 3 years ago • 1 comments

你好,想问下是否有模型训练代码呢?主要想参考下Loss设计部分,谢谢!

Guaishou74851 avatar Apr 28 '22 03:04 Guaishou74851

由于近期我们组员都有些其他工作需要处理,所以训练代码最近可能不能公开。 不过Loss部分您可以先参考下这个:

from abc import abstractmethod

from megengine import module as M
import megengine.functional as F

class BaseLoss(M.Module):
    def __init__(self, name) -> None:
        super().__init__()
        self.name = name
        self.losses = []

    def __len__(self):
        return len(self.losses)

    @abstractmethod
    def forward(self, x, y):
        pass

class DiffLoss(BaseLoss):
    def __init__(self, loss_weight):
        super().__init__('diff loss')
        self.loss_weight = loss_weight

    def forward(self, x, y):
        means = y.mean(axis=(-1, -2, -3))
        weight = (1 / means) ** 0.5
        diff = F.abs(x - y).mean(axis=(-1, -2, -3))
        diff = diff * weight.detach()
        diff = diff.mean()
        return diff * self.loss_weight

Srameo avatar Apr 29 '22 04:04 Srameo