pytorch-unsupervised-segmentation-tip
loss_fn_scr Errors
Traceback (most recent call last):
File "demo.py", line 145, in
I got an error on the same line, but mine was: RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'
The cause was target_scr[inds_scr]. From the initial step it has dtype int32, but the loss function expects int64 (Long) class-index targets. So I changed the code to:
loss = args.stepsize_sim * loss_fn(output[inds_sim], target[inds_sim]) + \
args.stepsize_scr * loss_fn_scr(output[inds_scr], target_scr[inds_scr].type(torch.int64)) + \
args.stepsize_con * (lhpy + lhpz)
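The failure and the cast can be reproduced in isolation with a small sketch. It assumes target_scr is built with torch.from_numpy from an int32 NumPy array (for example, NumPy's default integer is 32-bit on Windows in NumPy versions before 2.0), and that 255 marks unscribbled pixels. The tensor sizes, the mask values, and the use of nn.CrossEntropyLoss as loss_fn_scr are illustrative stand-ins, not the repo's exact setup.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-ins for demo.py's tensors: 6 pixels, 3 classes.
output = torch.randn(6, 3)  # per-pixel class scores
# Assumed convention: 255 = pixel without a scribble label.
mask = np.array([0, 255, 1, 255, 2, 0], dtype=np.int32)

target_scr = torch.from_numpy(mask)  # inherits int32 from the NumPy array
inds_scr = torch.from_numpy(np.where(mask != 255)[0])  # indices of scribbled pixels

loss_fn_scr = nn.CrossEntropyLoss()

# Cast the indexed targets to int64 (torch.long) before passing them to the loss.
targets = target_scr[inds_scr].type(torch.int64)
loss = loss_fn_scr(output[inds_scr], targets)
print(targets.dtype, loss.item())
```

Only the targets need the cast; the scores in output stay floating point.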
There is a simpler solution: .long() is shorthand for .type(torch.int64).
loss = args.stepsize_sim * loss_fn(output[ inds_sim ], target[ inds_sim ]) + \
args.stepsize_scr * loss_fn_scr(output[ inds_scr ], target_scr[ inds_scr ].long()) + \
args.stepsize_con * (lhpy + lhpz)
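A quick check that the two casts agree, plus one possible alternative, is sketched below. The mask array is hypothetical, and casting once when target_scr is created (target_scr_long here) is an assumption about how one might restructure the code, not what the repository does; it simply avoids converting the targets on every iteration.

```python
import numpy as np
import torch

# Hypothetical scribble labels stored as int32.
mask = np.array([0, 255, 1, 255, 2, 0], dtype=np.int32)
target_scr = torch.from_numpy(mask)

# .long() and .type(torch.int64) produce the same int64 tensor.
a = target_scr.long()
b = target_scr.type(torch.int64)
print(a.dtype, torch.equal(a, b))

# Alternative (assumption, not the repo's code): cast once at construction,
# so the training loop can index target_scr without any per-iteration cast.
target_scr_long = torch.from_numpy(mask.astype(np.int64))
print(target_scr_long.dtype)
```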