SSL4MIS

train_cross_pseudo_supervision.py: unnecessary slicing in the supervised CE loss

Open · IcecreamArtist opened this issue 1 year ago · 0 comments

Hi,

In train_cross_pseudo_supervision.py, the supervised losses for the two networks are computed as:

loss1 = 0.5 * (ce_loss(outputs1[:args.labeled_bs],
                       label_batch[:][:args.labeled_bs].long()) + dice_loss(
    outputs_soft1[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1)))
loss2 = 0.5 * (ce_loss(outputs2[:args.labeled_bs],
                       label_batch[:][:args.labeled_bs].long()) + dice_loss(
    outputs_soft2[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1)))

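For ce_loss, the target is taken as 'label_batch[:][:args.labeled_bs]', which can be simplified to 'label_batch[:args.labeled_bs]'. The leading '[:]' only selects the whole batch (a view of the tensor, not a subset), so slicing it again with '[:args.labeled_bs]' returns exactly the same first labeled_bs samples as a single slice. The dice_loss calls on the next line already use the simplified form.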
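The equivalence can be checked with a small sketch. It uses NumPy instead of PyTorch, on the assumption that basic slicing behaves the same way here (both libraries return views for plain slices); the batch size, label shapes, and the variable 'labeled_bs' (standing in for 'args.labeled_bs') are made up for illustration and are not taken from the training script.

```python
import numpy as np

# Hypothetical stand-ins for the script's variables: a batch of 4 label maps
# (2x2 pixels each), of which the first 2 are labeled (like args.labeled_bs = 2).
labeled_bs = 2
label_batch = np.arange(4 * 2 * 2).reshape(4, 2, 2)

# label_batch[:] is a view of the whole batch, so slicing it again with
# [:labeled_bs] selects the same first labeled_bs samples as a single slice.
double_slice = label_batch[:][:labeled_bs]
single_slice = label_batch[:labeled_bs]

print(double_slice.shape)                            # (2, 2, 2)
print(np.array_equal(double_slice, single_slice))    # True
print(np.shares_memory(double_slice, label_batch))   # True: a view, no copy
```

Since both forms produce the same view, dropping the extra '[:]' should not change the loss values; it only removes a redundant indexing step.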

Discussion is welcome if I have made a mistake.

IcecreamArtist, Oct 17 '23 07:10