CE-Net
Question about MulticlassDiceLoss
First of all, I would like to thank you for releasing your code; it gives me a chance to learn. After running your code, I have a question about `MulticlassDiceLoss`:

```python
import torch.nn as nn


class MulticlassDiceLoss(nn.Module):
    """
    requires one hot encoded target. Applies DiceLoss on each class iteratively.
    requires input.shape[0:2] and target.shape[0:2] to be (N, C) where N is
    batch size and C is number of classes
    """
    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        C = target.shape[1]  # number of classes = channels of the one-hot target
        totalLoss = 0
        for i in range(C):
            # `dice` is a per-class Dice loss defined elsewhere (not shown here)
            diceLoss = dice(input[:, i, :, :], target[:, i, :, :])
            if weights is not None:
                diceLoss *= weights[i]  # optional per-class weight
            totalLoss += diceLoss
        return totalLoss
```

The ground truth I feed in has only one channel. How do I convert that single channel into as many channels as the predicted segmentation has?
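The docstring asks for a one-hot `target` of shape (N, C, H, W), so a single-channel label map must be expanded before it reaches the loss. Below is a minimal sketch of that conversion using PyTorch's `torch.nn.functional.one_hot`. It assumes each pixel of the mask already stores an integer class index in `[0, num_classes)` (e.g. 0 and 1 for a binary task; a mask saved as 0/255 would need thresholding first) and that the prediction passed as `input` has the same C channels. The helper name `to_one_hot` is hypothetical, not part of the CE-Net repo.

```python
import torch
import torch.nn.functional as F


def to_one_hot(label_map: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: convert an integer label map to a one-hot target.

    Accepts shape (N, 1, H, W) or (N, H, W), where each pixel holds a class
    index in [0, num_classes). Returns a float tensor of shape
    (N, num_classes, H, W), so channel i lines up with input[:, i].
    """
    if label_map.dim() == 4:
        # Drop the singleton channel dimension: (N, 1, H, W) -> (N, H, W)
        label_map = label_map.squeeze(1)
    # F.one_hot appends the class axis last: (N, H, W) -> (N, H, W, C)
    one_hot = F.one_hot(label_map.long(), num_classes=num_classes)
    # Move the class axis to position 1 and cast to float: (N, C, H, W)
    return one_hot.permute(0, 3, 1, 2).float()


if __name__ == "__main__":
    # Example: batch of 2 single-channel masks, 3 classes, 4x4 pixels
    labels = torch.randint(0, 3, (2, 1, 4, 4))
    target = to_one_hot(labels, num_classes=3)
    print(target.shape)  # torch.Size([2, 3, 4, 4])
```

`F.one_hot` is used instead of building the channels by hand because it rejects out-of-range indices, which surfaces masks that were not converted to class indices. The resulting `target` can then be passed directly as the second argument of `MulticlassDiceLoss`.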