
Issue about the addition of 'v_cls_loss' and 'i_cls_loss'

Open · workingcoder opened this issue 3 years ago • 1 comment

In the file `MPANet/models/baseline.py`, the `train_forward` function computes the final loss from multiple components. In the following code snippet, I noticed that `v_cls_loss` and `i_cls_loss` are added to the loss twice.

        if self.mutual_learning:
            # cam_ids = kwargs.get('cam_ids')
            # sub = (cam_ids == 3) + (cam_ids == 6)
            
            logits_v = self.visible_classifier(feat[sub == 0])
            v_cls_loss = self.id_loss(logits_v.float(), labels[sub == 0])
            loss += v_cls_loss * self.weight_sid  # <-- 'v_cls_loss' added here (first time)
            logits_i = self.infrared_classifier(feat[sub == 1])
            i_cls_loss = self.id_loss(logits_i.float(), labels[sub == 1])
            loss += i_cls_loss * self.weight_sid  # <-- 'i_cls_loss' added here (first time)

            logits_m = torch.cat([logits_v, logits_i], 0).float()
            with torch.no_grad():
                self.infrared_classifier_.weight.data = self.infrared_classifier_.weight.data * (1 - self.update_rate) \
                                                 + self.infrared_classifier.weight.data * self.update_rate
                self.visible_classifier_.weight.data = self.visible_classifier_.weight.data * (1 - self.update_rate) \
                                                 + self.visible_classifier.weight.data * self.update_rate

                logits_v_ = self.infrared_classifier_(feat[sub == 0])
                logits_i_ = self.visible_classifier_(feat[sub == 1])

                logits_m_ = torch.cat([logits_v_, logits_i_], 0).float()
            logits_m = F.softmax(logits_m, 1)
            logits_m_ = F.log_softmax(logits_m_, 1)
            mod_loss = self.KLDivLoss(logits_m_, logits_m) 

            loss += mod_loss * self.weight_KL + (v_cls_loss + i_cls_loss) * self.weight_sid  # <-- and both added again here (second time)
            metric.update({'ce-v': v_cls_loss.data})
            metric.update({'ce-i': i_cls_loss.data})
            metric.update({'KL': mod_loss.data})

Did you apply the `self.weight_sid` weighting twice on purpose?
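
If it was not intentional, the fix would be to add each classification loss only once. A minimal sketch of the deduplicated block (assuming a single `weight_sid` weighting per loss was intended; this is a hypothetical patch, not tested against the paper):

        if self.mutual_learning:
            logits_v = self.visible_classifier(feat[sub == 0])
            v_cls_loss = self.id_loss(logits_v.float(), labels[sub == 0])
            logits_i = self.infrared_classifier(feat[sub == 1])
            i_cls_loss = self.id_loss(logits_i.float(), labels[sub == 1])
            # ... EMA classifier update and KL term unchanged ...
            # each classification loss now enters the total exactly once
            loss += mod_loss * self.weight_KL + (v_cls_loss + i_cls_loss) * self.weight_sid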

workingcoder avatar Oct 30 '21 08:10 workingcoder

It is my fault, and all of my experiments were run with this code.
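
In other words, the reported numbers were obtained with an effective weight of 2 * `weight_sid` on the single-modality classification losses. A one-line sketch of how to preserve that behavior after removing the duplicate additions (equivalent by construction to the original double addition, but untested):

            # equivalent to adding each classification loss twice at weight_sid:
            loss += mod_loss * self.weight_KL + (v_cls_loss + i_cls_loss) * (2 * self.weight_sid)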

DoubtedSteam avatar Nov 13 '21 08:11 DoubtedSteam