BalancedMSE
About the "bmc_loss_md"
Thanks for your work.
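```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal as MVN

def bmc_loss_md(pred, target, noise_var):
    I = torch.eye(pred.shape[-1])
    # logits[i, j] = log N(target[j]; pred[i], noise_var * I)
    logits = MVN(pred.unsqueeze(1), noise_var*I).log_prob(target.unsqueeze(0))
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))
    loss = loss * (2 * noise_var).detach()  # noise_var must be a tensor here
    return loss
```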
My `pred` and `target` both have shape [30, 3, 256, 256]. When I run `loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))`, I get an error because `torch.arange(pred.shape[0])` is 1D while `logits` is 4D (shape [30, 30, 3, 256]).
How can I solve this error?
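A small reproduction shows where the extra dimensions come from. It uses hypothetical sizes [4, 3, 8, 8] in place of [30, 3, 256, 256]. `MultivariateNormal` treats only the last axis as the event dimension, so every other axis becomes a batch dimension and survives into `logits`.

```python
import torch
from torch.distributions import MultivariateNormal as MVN

# Hypothetical small stand-ins for the [30, 3, 256, 256] tensors in the question.
pred = torch.randn(4, 3, 8, 8)
target = torch.randn(4, 3, 8, 8)
noise_var = torch.tensor(1.0)

I = torch.eye(pred.shape[-1])              # 8 x 8: only the last axis (width) is the event dim
dist = MVN(pred.unsqueeze(1), noise_var * I)
print(dist.batch_shape, dist.event_shape)  # torch.Size([4, 1, 3, 8]) torch.Size([8])

logits = dist.log_prob(target.unsqueeze(0))
print(logits.shape)                        # torch.Size([4, 4, 3, 8])
# F.cross_entropy reads dim 1 as the class dim and therefore expects a [4, 3, 8]
# target, so torch.arange(4) (shape [4]) is rejected.
```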
You can reshape both `pred` and `target` into 2D tensors of size [30, 3x256x256].
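Here is a minimal sketch of that reshape, with random tensors standing in for real images. After flattening, `bmc_loss_md` builds `torch.eye(d)` with d = 3 * 256 * 256, so the last lines estimate how large that identity matrix alone would be in float32.

```python
import torch

# Random stand-ins with the shapes from the question.
pred = torch.randn(30, 3, 256, 256)
target = torch.randn(30, 3, 256, 256)

# Flatten everything except the batch axis: [30, 3, 256, 256] -> [30, 3*256*256]
pred_2d = pred.reshape(pred.shape[0], -1)
target_2d = target.reshape(target.shape[0], -1)

d = pred_2d.shape[-1]                           # dimensions per sample
eye_gb = d * d * pred_2d.element_size() / 1e9   # bytes of a float32 d x d identity, in GB
print(pred_2d.shape, d, round(eye_gb, 1))       # torch.Size([30, 196608]) 196608 154.6
```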
I have tried that, but it runs out of memory when creating the diagonal (identity) matrix. I ran `pred = pred.reshape(pred.shape[0], -1)`, and the output shape is [30, 3x256x256].
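One possible workaround is not something the maintainers propose in this thread; it is a derivation under the assumption that the noise stays isotropic (`noise_var * I`). For that covariance, log N(t; p, σ²I) = -||t - p||² / (2σ²) - (d/2) log(2πσ²). The second term is the same for every (i, j) pair, so it cancels in the softmax inside `F.cross_entropy`, and the logits reduce to scaled pairwise squared distances. No d x d matrix is needed. `bmc_loss_md_isotropic` is a hypothetical name. The quoted function is repeated only to check that the two agree on small inputs.

```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal as MVN


def bmc_loss_md(pred, target, noise_var):
    # Reference: the function quoted in the question (builds a d x d covariance).
    I = torch.eye(pred.shape[-1])
    logits = MVN(pred.unsqueeze(1), noise_var * I).log_prob(target.unsqueeze(0))
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))
    return loss * (2 * noise_var).detach()


def bmc_loss_md_isotropic(pred, target, noise_var):
    """Hypothetical isotropic-noise variant; pred, target: [batch, d], noise_var: scalar tensor."""
    # sq_dist[i, j] = ||pred[i] - target[j]||^2 = ||p||^2 + ||t||^2 - 2 p.t, shape [batch, batch].
    # clamp_min(0) removes tiny negative values caused by float round-off.
    sq_dist = (
        pred.pow(2).sum(dim=1, keepdim=True)
        + target.pow(2).sum(dim=1)
        - 2 * pred @ target.T
    ).clamp_min(0)
    # Equals the MVN log-prob minus a constant shared by all entries (cancels in softmax).
    logits = -sq_dist / (2 * noise_var)
    labels = torch.arange(pred.shape[0], device=pred.device)
    loss = F.cross_entropy(logits, labels)
    return loss * (2 * noise_var).detach()


torch.manual_seed(0)
pred = torch.randn(8, 16)
target = torch.randn(8, 16)
noise_var = torch.tensor(0.5)
ref = bmc_loss_md(pred, target, noise_var)
fast = bmc_loss_md_isotropic(pred, target, noise_var)
print(ref.item(), fast.item())  # the two values agree up to float32 round-off
```

Because memory now grows with batch² rather than d², the same function can be called directly on [30, 196608] inputs. This equivalence holds only while the covariance is a scalar multiple of the identity.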
I have not tested Balanced MSE on image reconstruction, but you could try downsampling the images to a smaller resolution, e.g., 32x32. Sorry for the late reply.
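Here is a minimal sketch of the downsampling suggestion. It assumes average pooling with `F.adaptive_avg_pool2d` is an acceptable way to reduce the resolution, since the reply does not name a method. `downsample_and_flatten` is a hypothetical helper. At 32x32, d = 3 * 32 * 32 = 3072, so the identity matrix is 3072 x 3072, about 38 MB in float32.

```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal as MVN


def bmc_loss_md(pred, target, noise_var):
    # The function quoted in the question; expects [batch, d] inputs.
    I = torch.eye(pred.shape[-1])
    logits = MVN(pred.unsqueeze(1), noise_var * I).log_prob(target.unsqueeze(0))
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))
    return loss * (2 * noise_var).detach()


def downsample_and_flatten(x, size=32):
    """Hypothetical helper: [B, C, H, W] -> [B, C*size*size] via average pooling."""
    x = F.adaptive_avg_pool2d(x, (size, size))
    return x.reshape(x.shape[0], -1)


# Random stand-ins with the shapes from the question.
pred = torch.randn(30, 3, 256, 256)
target = torch.randn(30, 3, 256, 256)
noise_var = torch.tensor(1.0)

pred_small = downsample_and_flatten(pred)       # [30, 3072]
target_small = downsample_and_flatten(target)   # [30, 3072]
loss = bmc_loss_md(pred_small, target_small, noise_var)  # identity is now 3072 x 3072
print(pred_small.shape, loss.item())
```

The trade-off is that this loss only compares the downsampled images, so detail finer than the pooling window does not affect it.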