
Help with the weights of the APro loss and the partial cross-entropy loss?

Open · lauraset opened this issue 2 months ago · 0 comments

Hi, @CircleRadon. Thank you for your great work. I am not clear about the weights of the APro loss terms or their implementation. Following issue #3, my implementation of the APro loss is:

import torch
import torch.nn as nn
# Global_APro, Local_APro, and MinimumSpanningTree are the tree-filter
# modules from the APro repository (import path omitted here)

class AproLoss(nn.Module):
    def __init__(self, ignore_index=255):
        super().__init__()
        # partial cross-entropy: pixels equal to ignore_index contribute no gradient
        self.partialCE = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # global and local affinity propagation (APro) modules
        self.global_apro = Global_APro()
        self.local_apro = Local_APro(kernel_size=5, zeta_s=0.15)  # set kernel_size and zeta_s
        self.mst = MinimumSpanningTree(Global_APro.norm2_distance)
        # pca n_component
        # self.q = 1
        self.ignore_index = ignore_index

    def forward(self, x, y_hat, y):
        # x: input image, (B, C, H, W)
        # y_hat: raw logits from the segmentation network, (B, classes, H, W)
        # partial cross-entropy on the labelled pixels only
        partial = self.partialCE(y_hat, y)
        # compute PCA
        # B, 1, H, W
        # pca_imgs = self.compute_pca(x)

        # compute image tree
        # I think directly using x is also fine
        img_mst_tree = self.mst(x)
        # img_mst_tree = self.mst(pca_imgs)

        # y: B, H, W
        # y = y.float()
        y_hat = torch.softmax(y_hat, dim=1) # convert to probability [0,1]

        # pseudo labels for the global term
        # first pass: affinities from the low-level feature (the input image)
        soft_pseudo = self.global_apro(y_hat, x, img_mst_tree, zeta_g=0.001)
        # second pass: affinities from the "deep feature", here set to y_hat itself
        soft_pseudo = self.global_apro(soft_pseudo, y_hat, img_mst_tree, zeta_g=0.05)

        # mask of unlabelled pixels, (B, 1, H, W); broadcasts over the class dim
        unlabelled_regions = (y.unsqueeze(1) == self.ignore_index)

        # L1 difference between the generated pseudo labels and the prediction,
        # restricted to the unlabelled region
        loss_global_term = torch.abs(soft_pseudo - y_hat) * unlabelled_regions
        # normalize by the number of unlabelled pixels
        n_regions = unlabelled_regions.sum().clamp(min=1)
        loss_global = loss_global_term.sum() / n_regions

        # local term; the PCA image is commented out above, so use x directly
        soft_pseudo = self.local_apro(x, y_hat)
        loss_local_term = torch.abs(y_hat - soft_pseudo) * unlabelled_regions
        loss_local = loss_local_term.sum() / unlabelled_regions.sum().clamp(min=1)
        # the three terms are summed with unit weights
        return partial + loss_global + loss_local
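
For completeness, this is roughly how I invoke it during training. It is a minimal sketch with dummy tensors: the shapes follow the comments above, the class count 21 is just a placeholder, and it assumes the APro modules from the repository are importable.

criterion = AproLoss(ignore_index=255)

# dummy batch: x is the input image, y_hat the raw logits, y the sparse
# annotations with 255 everywhere except a few labelled pixels
x = torch.randn(2, 3, 128, 128)
y_hat = torch.randn(2, 21, 128, 128, requires_grad=True)
y = torch.full((2, 128, 128), 255, dtype=torch.long)
y[:, 60:64, 60:64] = 1  # a small labelled scribble

loss = criterion(x, y_hat, y)
loss.backward()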

I have several questions:

  1. How should the weights of the partial cross-entropy loss and the global/local APro losses be set? (In my code above they are simply summed with unit weights; see the sketch after this list.)
  2. For the global APro, the deep feature is set directly to y_hat. Is this the default setting in your paper? Why should it be set to the last feature map from the segmentation network?
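
To make question 1 concrete, I imagine a weighted variant would look like the snippet below, but lambda_g and lambda_l are placeholders I made up, not values from the paper:

# hypothetical weighting; lambda_g and lambda_l are placeholders,
# not values taken from the paper
lambda_g, lambda_l = 1.0, 1.0
loss = partial + lambda_g * loss_global + lambda_l * loss_local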

Thank you in advance.

lauraset · Apr 22 '24 06:04