CE-Net

Question about MulticlassDiceLoss

Status: Open. Chao86 opened this issue 5 years ago (0 comments).

First of all, I would like to thank you for sharing your code, which gives me a chance to learn. After running your code, I have a question about `MulticlassDiceLoss`:

```python
import torch.nn as nn

# `dice(...)` is not defined in this snippet; it is assumed to return the
# Dice loss for a single class channel.


class MulticlassDiceLoss(nn.Module):
    """
    Requires a one-hot encoded target. Applies DiceLoss on each class iteratively.
    Requires input.shape[0:2] and target.shape[0:2] to be (N, C), where N is the
    batch size and C is the number of classes.
    """

    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        C = target.shape[1]
        totalLoss = 0
        # Sum the Dice loss of every class channel, optionally weighted per class.
        for i in range(C):
            diceLoss = dice(input[:, i, :, :], target[:, i, :, :])
            if weights is not None:
                diceLoss *= weights[i]
            totalLoss += diceLoss
        return totalLoss
```

My ground truth has only one channel. How can I transform it into as many channels as the predicted segmentation map has?
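Since the docstring asks for a one-hot encoded target, one common approach is to one-hot encode the single-channel label map. The sketch below is not code from the repository; it assumes the ground truth is an integer class-index tensor of shape (N, 1, H, W) (or (N, H, W)) with values in [0, C), and the helper name `to_one_hot` is hypothetical.

```python
import torch
import torch.nn.functional as F


def to_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: convert an (N, 1, H, W) or (N, H, W) integer label
    map into an (N, C, H, W) float one-hot tensor.

    Assumes every value in `labels` is a class index in [0, num_classes).
    """
    if labels.dim() == 4:
        # Drop the singleton channel dimension: (N, 1, H, W) -> (N, H, W)
        labels = labels.squeeze(1)
    # F.one_hot needs int64 input and appends the class axis last:
    # (N, H, W) -> (N, H, W, C)
    one_hot = F.one_hot(labels.long(), num_classes=num_classes)
    # Move the class axis next to the batch axis: (N, H, W, C) -> (N, C, H, W)
    return one_hot.permute(0, 3, 1, 2).float()


# Example: a batch of 2 masks, 4x4 pixels, 3 classes
gt = torch.randint(0, 3, (2, 1, 4, 4))
target = to_one_hot(gt, num_classes=3)
print(target.shape)  # torch.Size([2, 3, 4, 4])
```

The resulting `target` has C channels, so it can be passed to `MulticlassDiceLoss` together with a C-channel prediction (for a Dice loss, usually the softmax probabilities). On PyTorch versions without `F.one_hot`, `torch.zeros(N, C, H, W).scatter_(1, gt.long(), 1)` produces the same tensor from an (N, 1, H, W) label map.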

Chao86, Jul 15 '19 03:07