DistributionLoss
Question about distribution loss
Hi Ruizhou,
Thanks for sharing your code! While reading it, a few things puzzled me. I cannot follow the code for the distribution loss, because it is inconsistent with the description in the paper. In addition, the hyper-parameters also differ from those in the paper.
Could you help me and explain the distribution loss in the code?
Thanks!
Hi,
Thanks for the question. Quick answer: the code is mainly for ImageNet. For historical reasons, we used a variant of the distribution loss formulation. The underlying intuition is the same: to reduce the three training-induced issues described in the paper.
If you experiment on CIFAR-10/CIFAR-100/SVHN, please use the same settings as described in Sec. 4.1, Training Configuration.
# # For ImageNet
# distrloss1 = (torch.min(2 - mean - std,
#                         2 + mean - std).clamp(min=0) ** 2).mean() + \
#              ((std - 4).clamp(min=0) ** 2).mean()
# distrloss2 = (mean ** 2 - std ** 2).clamp(min=0).mean()
# # For CIFAR-10/CIFAR-100/SVHN
# distrloss1 = (torch.min(1 - mean - 0.25*std,
#                         1 + mean - 0.25*std).clamp(min=0) ** 2).mean() + \
#              ((std - 4).clamp(min=0) ** 2).mean()
# distrloss2 = (mean ** 2 - std ** 2).clamp(min=0).mean()
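For clarity, here is a self-contained, runnable sketch of the variant above. The distribution_loss wrapper, the preact input tensor, and the per-channel reduction over batch and spatial dimensions are illustrative assumptions for this snippet, not necessarily the exact interface in the repo:

    import torch

    def distribution_loss(preact, variant="imagenet"):
        # Per-channel mean/std of the pre-activations, reduced over the
        # batch and spatial dimensions of an NCHW tensor (an assumption;
        # the released code may aggregate differently).
        flat = preact.transpose(0, 1).reshape(preact.size(1), -1)
        mean = flat.mean(dim=1)
        std = flat.std(dim=1)

        if variant == "imagenet":
            # Penalize channels that are too one-sided or too spread out.
            distrloss1 = (torch.min(2 - mean - std,
                                    2 + mean - std).clamp(min=0) ** 2).mean() + \
                         ((std - 4).clamp(min=0) ** 2).mean()
        else:  # CIFAR-10 / CIFAR-100 / SVHN
            distrloss1 = (torch.min(1 - mean - 0.25*std,
                                    1 + mean - 0.25*std).clamp(min=0) ** 2).mean() + \
                         ((std - 4).clamp(min=0) ** 2).mean()

        # Penalize channels whose mean dominates their spread.
        distrloss2 = (mean ** 2 - std ** 2).clamp(min=0).mean()
        return distrloss1, distrloss2

    # Example: random NCHW pre-activations
    x = 3.0 * torch.randn(8, 16, 32, 32)
    l1, l2 = distribution_loss(x, variant="imagenet")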
Btw, can you leave the issue open until I find time to update the code? I guess others may have similar questions as well.
Thanks, Ruizhou
Shouldn't ((std - 4).clamp(min=0) ** 2).mean() be ((0.25*std - 1).clamp(min=0) ** 2).mean()? The former is roughly 16 times the latter, because the term being squared is 4 times larger in the former. Any clarification would be helpful.
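A quick numerical check of that factor of 16 (the std values here are hypothetical, just to illustrate that 0.25*std - 1 = 0.25*(std - 4), so squaring scales the penalty by 1/16):

    import torch

    std = torch.tensor([5.0, 6.0, 8.0])  # hypothetical values above the threshold
    former = ((std - 4).clamp(min=0) ** 2).mean()
    latter = ((0.25*std - 1).clamp(min=0) ** 2).mean()
    print(former / latter)  # tensor(16.)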
Thanks!