SSL4MIS
train_cross_pseudo_supervision.py unnecessary code
Hi,
In the code:
loss1 = 0.5 * (ce_loss(outputs1[:args.labeled_bs],
                       label_batch[:][:args.labeled_bs].long()) + dice_loss(
    outputs_soft1[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1)))
loss2 = 0.5 * (ce_loss(outputs2[:args.labeled_bs],
                       label_batch[:][:args.labeled_bs].long()) + dice_loss(
    outputs_soft2[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1)))
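For ce_loss, the target is fetched as 'label_batch[:][:args.labeled_bs]', which can be simplified to 'label_batch[:args.labeled_bs]' (the form dice_loss already uses). The leading '[:]' selects the entire batch, so the following '[:args.labeled_bs]' slices the same first axis and the extra '[:]' has no effect.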
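A quick check of the equivalence, as a minimal sketch. It uses NumPy arrays as a stand-in for the PyTorch label tensor (assumption: basic slicing on a torch tensor follows the same first-axis semantics), and the batch shape and labeled_bs value are hypothetical:

```python
import numpy as np

# Hypothetical stand-ins: a batch of 4 integer label maps of size 3x3,
# where the first labeled_bs samples are treated as labeled.
labeled_bs = 2
label_batch = np.random.randint(0, 4, size=(4, 3, 3))

# `[:]` returns the whole batch (a view), so `[:labeled_bs]` then slices
# the first axis exactly as a single `[:labeled_bs]` would.
with_extra_slice = label_batch[:][:labeled_bs]
simplified = label_batch[:labeled_bs]

# Contrast: `[:, :labeled_bs]` would slice the SECOND axis instead.
second_axis_slice = label_batch[:, :labeled_bs]

print(with_extra_slice.shape, simplified.shape, second_axis_slice.shape)
print(np.array_equal(with_extra_slice, simplified))
```

Both expressions select the same labeled samples, so the simplification changes readability only, not the loss value.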
Discussion is welcome if I have made a mistake.