
_computePerceptionAwareLoss remains low (~1e-2) and stops decreasing during training

Open • six-wood opened this issue 4 months ago • 0 comments

Problem Description

I'm experiencing an issue with the _computePerceptionAwareLoss function in my multi-modal (point cloud + image) fusion segmentation model. During training, the perception-aware loss consistently stays around 1e-2 and does not decrease further, which makes me suspect a numerical issue or an implementation problem.

  • The _computePerceptionAwareLoss value stays around 1e-2 throughout training
  • The loss doesn't show a meaningful decrease even when other loss components are improving
  • This behavior persists across different hyperparameter settings

Code Implementation

Here's my current implementation of _computePerceptionAwareLoss:

```python
from typing import Tuple

from torch import Tensor


def _computePerceptionAwareLoss(
    self,
    pcd_entropy: Tensor,
    img_entropy: Tensor,
    pcd_pred: Tensor,
    pcd_pred_log: Tensor,
    img_pred: Tensor,
    img_pred_log: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    # Calculate confidence (1 - entropy); this is a [0, 1] confidence
    # only if the entropy is normalized to [0, 1]
    pcd_confidence = 1 - pcd_entropy
    img_confidence = 1 - img_entropy

    # Calculate information importance (> 0 where the point cloud is more confident)
    information_importance = pcd_confidence - img_confidence

    # Create guide masks based on tau threshold
    pcd_guide_mask = pcd_confidence.ge(self.tau).float()  # tau = 0.7
    img_guide_mask = img_confidence.ge(self.tau).float()

    # Calculate guide weights: |confidence gap| where that modality is
    # both the more confident one and above tau
    pcd_guide_weight = (
        information_importance.gt(0).float()
        * information_importance.abs()
        * pcd_guide_mask
    )
    img_guide_weight = (
        information_importance.lt(0).float()
        * information_importance.abs()
        * img_guide_mask
    )

    # Calculate KL divergence losses (self.kl_loss is nn.KLDivLoss(reduction="none"))
    # Point-cloud branch is pulled toward the image prediction where the image is more confident
    loss_per_pcd = (
        self.kl_loss(pcd_pred_log, img_pred.detach())
        * img_guide_weight.unsqueeze(1)
    )
    valid_pcd_count = img_guide_weight.gt(0).sum()
    loss_per_pcd = (
        loss_per_pcd.sum() / valid_pcd_count
        if valid_pcd_count > 0
        else loss_per_pcd.sum()
    )

    # Image branch is pulled toward the point-cloud prediction where the point cloud is more confident
    loss_per_img = (
        self.kl_loss(img_pred_log, pcd_pred.detach())
        * pcd_guide_weight.unsqueeze(1)
    )
    valid_img_count = pcd_guide_weight.gt(0).sum()
    loss_per_img = (
        loss_per_img.sum() / valid_img_count
        if valid_img_count > 0
        else loss_per_img.sum()
    )

    return loss_per_pcd, loss_per_img, pcd_guide_weight, img_guide_weight
```
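One debugging approach is to log the pieces of this loss separately: the fraction of points that pass each gate, the mean guide weight on those points, and the unweighted per-point KL. The sketch below is a standalone re-implementation for that purpose, not the original method. Everything in it beyond the formulas in the code above is an assumption for illustration: the helper names `normalized_entropy` and `perception_aware_diagnostics` are hypothetical, entropy is assumed to be normalized by log(num_classes) (raw entropy over C classes ranges up to log C, so without normalization `1 - entropy` is not a [0, 1] confidence and tau = 0.7 gates a different set of points than intended), and the confidences are detached. In the code above the guide weights are not detached, so if `pcd_entropy` and `img_entropy` are computed from the predictions being trained, the gradient also reaches the weights, and the optimizer can lower the loss by narrowing the confidence gap or pushing confidence below tau instead of aligning the two predictions.

```python
import math

import torch
import torch.nn as nn
from torch import Tensor


def normalized_entropy(probs: Tensor, eps: float = 1e-8) -> Tensor:
    """Per-point entropy of [N, C] probabilities, divided by log(C) so it lies in [0, 1]."""
    num_classes = probs.shape[1]
    ent = -(probs * (probs + eps).log()).sum(dim=1)
    return ent / math.log(num_classes)


def perception_aware_diagnostics(
    pcd_logits: Tensor, img_logits: Tensor, tau: float = 0.7
) -> dict:
    """Hypothetical debugging helper: recomputes the guide weights and KL terms
    of _computePerceptionAwareLoss from [N, C] logits and reports each piece."""
    kl_loss = nn.KLDivLoss(reduction="none")
    pcd_pred, img_pred = pcd_logits.softmax(dim=1), img_logits.softmax(dim=1)
    pcd_log, img_log = pcd_logits.log_softmax(dim=1), img_logits.log_softmax(dim=1)

    # Confidence = 1 - normalized entropy. Detaching (an assumption, not in the
    # original) keeps gradients from flowing into the guide weights.
    pcd_conf = 1 - normalized_entropy(pcd_pred).detach()
    img_conf = 1 - normalized_entropy(img_pred).detach()
    importance = pcd_conf - img_conf

    # Same weights as the original: |gap| where that modality is more confident and >= tau
    pcd_w = importance.clamp(min=0) * pcd_conf.ge(tau).float()
    img_w = (-importance).clamp(min=0) * img_conf.ge(tau).float()

    # Unweighted per-point KL, summed over classes -> shape [N]
    kl_pcd = kl_loss(pcd_log, img_pred.detach()).sum(dim=1)  # pcd learns from img
    kl_img = kl_loss(img_log, pcd_pred.detach()).sum(dim=1)  # img learns from pcd

    stats = {}
    for name, w, kl in (("pcd", img_w, kl_pcd), ("img", pcd_w, kl_img)):
        valid = w.gt(0)
        n_valid = int(valid.sum())
        stats[f"{name}_valid_frac"] = n_valid / w.numel()
        stats[f"{name}_mean_weight"] = w[valid].mean().item() if n_valid else 0.0
        stats[f"{name}_mean_kl"] = kl[valid].mean().item() if n_valid else 0.0
        # Matches the original term: sum(w * KL) / count(w > 0), or 0 if no point passes
        stats[f"{name}_loss"] = (w * kl).sum().item() / max(n_valid, 1)
    return stats


torch.manual_seed(0)
stats = perception_aware_diagnostics(torch.randn(1024, 20) * 3, torch.randn(1024, 20) * 3)
for key, value in stats.items():
    print(f"{key}: {value:.4f}")
```

Logged every few hundred steps, these numbers can show whether the ~1e-2 plateau comes from the gate (few points pass tau, or the confidence gap on those points is small) or from the two branches still disagreeing on the gated points (mean KL stays high). If the weights turn out to be shrinking over training, detaching the entropies inside `_computePerceptionAwareLoss` itself, as done here, is a cheap experiment.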

Hyperparameters

  • tau = 0.7 (confidence threshold)
  • gamma = 1.0 (perception loss weight)
  • self.kl_loss = nn.KLDivLoss(reduction="none")
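Because the whole loss rests on `nn.KLDivLoss`, it can be worth confirming its argument convention in isolation. With the default `log_target=False`, the first argument is expected to be log-probabilities and the second plain probabilities, and `reduction="none"` returns the elementwise term target * (log(target) - input), so the per-point KL is the sum over the class dimension. Passing the arguments the other way round does not raise an error; it silently computes a different quantity. The check below compares this against a manual computation on random data; the tensor shapes and variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
kl_loss = nn.KLDivLoss(reduction="none")

student_logits = torch.randn(8, 5)  # e.g. the branch being trained, [N, C]
teacher_logits = torch.randn(8, 5)  # e.g. the detached guiding branch, [N, C]

log_q = student_logits.log_softmax(dim=1)  # input: log-probabilities
p = teacher_logits.softmax(dim=1)          # target: probabilities

elementwise = kl_loss(log_q, p)         # [N, C], p * (log p - log q)
per_point_kl = elementwise.sum(dim=1)   # [N], KL(p || q) for each point

manual = (p * (p.log() - log_q)).sum(dim=1)
print(torch.allclose(per_point_kl, manual, atol=1e-6))   # True
print(bool((per_point_kl >= -1e-6).all()))               # True: KL is non-negative
print(abs(kl_loss(p.log(), p).sum().item()) < 1e-6)      # True: identical distributions give ~0
```

The argument order in the implementation above (`self.kl_loss(pcd_pred_log, img_pred.detach())`) matches this convention, so a plateau is more likely to come from the weighting than from the KL call itself.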

Has anyone encountered similar issues with perception-aware losses in multi-modal settings? Any insights on potential fixes or debugging approaches would be greatly appreciated.

six-wood • Aug 22 '25 15:08