pytorch-lanenet

Discriminative Loss cannot calculate dist loss...

IrohXu opened this issue 3 years ago • 3 comments

During training, I found that the instance part of your program is not accurate: the instance loss converges to 0 too quickly. So I checked the loss.py file of your code and found that the parameter num_lanes inside the DiscriminativeLoss class is always 1 during training. This is likely caused by your one-hot representation of the instance target.
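A minimal sketch of the failure mode (hypothetical tensors, not the repo's actual dataloader output): if the target effectively merges all lanes into a single foreground label, torch.unique sees only one instance:

import torch

# hypothetical instance target where every lane pixel carries the same label
seg_gt_b = torch.zeros(8, 8)
seg_gt_b[2, :] = 1.0  # pixels of lane 1
seg_gt_b[5, :] = 1.0  # pixels of lane 2, encoded with the same value

labels = torch.unique(seg_gt_b[seg_gt_b != 0])  # exclude background 0
print(len(labels))  # 1 -> num_lanes == 1, so the dist branch below never runs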

Because num_lanes = 1, this branch is never executed:

if num_lanes > 1:
  centroid_mean1 = centroid_mean.reshape(-1, 1, embed_dim)
  centroid_mean2 = centroid_mean.reshape(1, -1, embed_dim)
  dist = torch.norm(centroid_mean1 - centroid_mean2, dim=2)  # shape (num_lanes, num_lanes)
  dist = dist + torch.eye(num_lanes, dtype=dist.dtype, device=dist.device) * self.delta_dist  # diagonal elements are 0; shift them above delta_dist so they stay out of the hinge
  dist_loss = dist_loss + torch.sum(F.relu(-dist + self.delta_dist) ** 2) / (
          num_lanes * (num_lanes - 1)) / 2
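For reference, this is the distance term of the discriminative loss (De Brabandere et al., 2017). As I read the code, with C = num_lanes lane centroids mu_a it computes

L_dist = 1 / (2 * C * (C - 1)) * sum_{a != b} max(0, delta_dist - ||mu_a - mu_b||)^2

(the paper's hinge uses a margin of 2 * delta_d). So with num_lanes stuck at 1, nothing ever pushes the lane embeddings apart, which is why the instance loss collapses to 0 so quickly.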

#################################################

To solve this problem, I think we either need to fix DiscriminativeLoss or build a new dataloader. I changed DiscriminativeLoss a little, and this version may work (I am not sure):

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class DiscriminativeLoss(_Loss):

    def __init__(self, delta_var=0.5, delta_dist=1.5, norm=2, alpha=1.0, beta=1.0, gamma=0.001,
                 usegpu=False, size_average=True):
        super(DiscriminativeLoss, self).__init__(reduction='mean')
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.usegpu = usegpu
        assert self.norm in [1, 2]

    def forward(self, input, target):

        return self._discriminative_loss(input, target)

    def _discriminative_loss(self, embedding, seg_gt):
        batch_size, embed_dim, H, W = embedding.shape
        embedding = embedding.reshape(batch_size, embed_dim, H*W)
        # assumes seg_gt is a one-hot instance map with embed_dim channels
        seg_gt = seg_gt.reshape(batch_size, embed_dim, H*W)

        var_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
        dist_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
        reg_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)

        for b in range(batch_size):
            embedding_b = embedding[b]  # (embed_dim, H*W)
            # collapse the one-hot channels into an integer label map:
            # pixels in channel j get label j+1, background stays 0
            seg_gt_b = torch.zeros(H * W, dtype=seg_gt.dtype, device=seg_gt.device)

            for j in range(0, embed_dim):
                seg_gt_b += seg_gt[b][j] * (j+1)

            labels = torch.unique(seg_gt_b[seg_gt_b != 0])  # background label 0 is not a lane
            num_lanes = len(labels)
            if num_lanes == 0:
                # please refer to issue here: https://github.com/harryhan618/LaneNet/issues/12
                _nonsense = embedding.sum()
                _zero = torch.zeros_like(_nonsense)
                var_loss = var_loss + _nonsense * _zero
                dist_loss = dist_loss + _nonsense * _zero
                reg_loss = reg_loss + _nonsense * _zero
                continue

            centroid_mean = []
            for lane_idx in labels:
                seg_mask_i = (seg_gt_b == lane_idx)

                if not seg_mask_i.any():
                    continue
                
                embedding_i = embedding_b * seg_mask_i
                mean_i = torch.sum(embedding_i, dim=1) / torch.sum(seg_mask_i)

                centroid_mean.append(mean_i)

                # ---------- var_loss -------------
                var_loss = var_loss + torch.sum(F.relu(
                    torch.norm(embedding_i[:,seg_mask_i] - mean_i.reshape(embed_dim, 1), dim=0) - self.delta_var) ** 2) / torch.sum(seg_mask_i) / num_lanes
            centroid_mean = torch.stack(centroid_mean)  # (n_lane, embed_dim)

            if num_lanes > 1:
                centroid_mean1 = centroid_mean.reshape(-1, 1, embed_dim)
                centroid_mean2 = centroid_mean.reshape(1, -1, embed_dim)

                dist = torch.norm(centroid_mean1 - centroid_mean2, dim=2)  # shape (num_lanes, num_lanes)
                dist = dist + torch.eye(num_lanes, dtype=dist.dtype,
                                        device=dist.device) * self.delta_dist  # diagonal elements are 0, now mask above delta_d

                # divided by two for double calculated loss above, for implementation convenience
                dist_loss = dist_loss + torch.sum(F.relu(-dist + self.delta_dist) ** 2) / (
                        num_lanes * (num_lanes - 1)) / 2

            # reg_loss is not used in original paper
            # reg_loss = reg_loss + torch.mean(torch.norm(centroid_mean, dim=1))

        var_loss = var_loss / batch_size
        dist_loss = dist_loss / batch_size
        reg_loss = reg_loss / batch_size

        return var_loss, dist_loss, reg_loss
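A minimal smoke test (random tensors; the shapes and lane positions are just made-up examples) to check that dist_loss now comes out non-zero:

import torch

batch_size, embed_dim, H, W = 2, 4, 32, 64
embedding = torch.randn(batch_size, embed_dim, H, W)

# synthetic one-hot instance target with the same channel count as the
# embedding: channel j marks the pixels of lane j+1
seg_gt = torch.zeros(batch_size, embed_dim, H, W)
seg_gt[:, 0, 5:8, :] = 1.0    # lane 1
seg_gt[:, 1, 20:23, :] = 1.0  # lane 2

criterion = DiscriminativeLoss(delta_var=0.5, delta_dist=1.5)
var_loss, dist_loss, reg_loss = criterion(embedding, seg_gt)
print(var_loss.item(), dist_loss.item(), reg_loss.item())  # dist_loss should be > 0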

Can anyone help me test it? Thanks a lot.

IrohXu avatar Jun 18 '21 08:06 IrohXu

When I try my own dataset, which has the same structure as your dataset, I am getting this error:

RuntimeError: shape '[4, 131072]' is invalid for input of size 1572864

lanenet-lane-detection-pytorch/model/lanenet/loss.py", line 64, in _discriminative_loss
    seg_gt = seg_gt.reshape(batch_size, H*W)

How can I fix it?

Rakuzan-Developer avatar Feb 24 '22 09:02 Rakuzan-Developer

Thanks for this (unfortunately I missed it last year :O). @IrohXu, it would be awesome if you created a PR for this, and I'll review it.

klintan avatar Feb 24 '22 18:02 klintan

@IrohXu, what does

for j in range(0, embed_dim):
    seg_gt_b += seg_gt[b][j] * (j+1)

mean in your code?

qq852518421 avatar Mar 01 '22 14:03 qq852518421