PMF
_computePerceptionAwareLoss remains low (~1e-2) and stops decreasing during training
Problem Description
I'm experiencing an issue with the _computePerceptionAwareLoss function in my multi-modal fusion segmentation model. During training, the perception-aware loss consistently remains around 1e-2 and doesn't decrease further, which suggests there might be numerical issues or implementation problems.
- The _computePerceptionAwareLoss value stays around 1e-2 throughout training
- The loss doesn't show a meaningful decrease even when other loss components are improving
- This behavior persists across different hyperparameter settings
Code Implementation
Here's my current implementation of _computePerceptionAwareLoss:
def _computePerceptionAwareLoss(
    self,
    pcd_entropy: Tensor,
    img_entropy: Tensor,
    pcd_pred: Tensor,
    pcd_pred_log: Tensor,
    img_pred: Tensor,
    img_pred_log: Tensor,
) -> tuple:
    # Calculate confidence (1 - entropy)
    pcd_confidence = 1 - pcd_entropy
    img_confidence = 1 - img_entropy

    # Calculate information importance
    information_importance = pcd_confidence - img_confidence

    # Create guide masks based on tau threshold
    pcd_guide_mask = pcd_confidence.ge(self.tau).float()  # tau = 0.7
    img_guide_mask = img_confidence.ge(self.tau).float()

    # Calculate guide weights
    pcd_guide_weight = (
        information_importance.gt(0).float() *
        information_importance.abs() *
        pcd_guide_mask
    )
    img_guide_weight = (
        information_importance.lt(0).float() *
        information_importance.abs() *
        img_guide_mask
    )

    # Calculate KL divergence losses
    loss_per_pcd = (
        self.kl_loss(pcd_pred_log, img_pred.detach()) *
        img_guide_weight.unsqueeze(1)
    )
    valid_pcd_count = img_guide_weight.gt(0).sum()
    loss_per_pcd = loss_per_pcd.sum() / valid_pcd_count if valid_pcd_count > 0 else loss_per_pcd.sum()

    loss_per_img = (
        self.kl_loss(img_pred_log, pcd_pred.detach()) *
        pcd_guide_weight.unsqueeze(1)
    )
    valid_img_count = pcd_guide_weight.gt(0).sum()
    loss_per_img = loss_per_img.sum() / valid_img_count if valid_img_count > 0 else loss_per_img.sum()

    return loss_per_pcd, loss_per_img, pcd_guide_weight, img_guide_weight
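To make the weighting easier to inspect, here is a minimal pure-Python sketch of the guide-weight logic for a single point. It assumes the entropies passed to the function are normalized by log(C) so that `1 - entropy` lies in [0, 1]. The post does not say how they are computed, so this is an assumption, and the helper names `normalized_entropy` and `guide_weights` are hypothetical.

```python
import math


def normalized_entropy(probs):
    """Shannon entropy of a class distribution, divided by log(C) so it lies in [0, 1]."""
    num_classes = len(probs)
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return entropy / math.log(num_classes)


def guide_weights(pcd_probs, img_probs, tau=0.7):
    """Return (pcd_guide_weight, img_guide_weight) for one point, mirroring the masks above."""
    pcd_conf = 1.0 - normalized_entropy(pcd_probs)
    img_conf = 1.0 - normalized_entropy(img_probs)
    importance = pcd_conf - img_conf
    # A branch guides the other only if it is more confident AND above the tau threshold.
    pcd_w = abs(importance) if (importance > 0 and pcd_conf >= tau) else 0.0
    img_w = abs(importance) if (importance < 0 and img_conf >= tau) else 0.0
    return pcd_w, img_w


# Point cloud branch is confident, image branch is not: only pcd_guide_weight is non-zero.
pcd_w, img_w = guide_weights([0.98, 0.01, 0.01], [0.5, 0.3, 0.2])
print(pcd_w, img_w)

# Both branches are confident: the weight shrinks to the small confidence gap.
pcd_w2, img_w2 = guide_weights([0.98, 0.01, 0.01], [0.97, 0.02, 0.01])
print(pcd_w2, img_w2)
```

Under these assumptions, each KL term is scaled by the confidence gap between the two branches. Once both branches become confident, the weighted loss may therefore be small by construction, while the division by the valid count does not undo that scaling.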
Hyperparameters
- tau = 0.7 (confidence threshold)
- gamma = 1.0 (perception loss weight)
- Using nn.KLDivLoss(reduction="none")
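For reference, `nn.KLDivLoss(reduction="none")` with the default `log_target=False` returns one term per element, `target * (log(target) - input)`, where `input` holds log-probabilities. The pure-Python sketch below mirrors that pointwise formula for one point. The distributions and the guide weight of 0.04 are made-up illustrative values, and `kl_div_elementwise` is a hypothetical helper, not part of the model.

```python
import math


def kl_div_elementwise(log_input, target):
    """Per-class terms target * (log(target) - log_input), as in KLDivLoss(reduction="none")."""
    return [
        t * (math.log(t) - li) if t > 0.0 else 0.0
        for li, t in zip(log_input, target)
    ]


# Hypothetical student (log-probabilities) and detached teacher (probabilities).
student_log = [math.log(p) for p in [0.90, 0.05, 0.05]]
teacher = [0.97, 0.02, 0.01]

terms = kl_div_elementwise(student_log, teacher)
kl = sum(terms)        # summing over the class dimension gives KL(teacher || student)
weighted = 0.04 * kl   # 0.04 is an assumed, illustrative guide weight
print(terms, kl, weighted)
```

Individual per-class terms can be negative, and only their sum over the class dimension is a non-negative KL value. Multiplying that sum by a small guide weight shrinks it further, so it may help to log the unweighted KL alongside the weighted loss.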
Has anyone encountered similar issues with perception-aware losses in multi-modal settings? Any insights on potential fixes or debugging approaches would be greatly appreciated.