
why sum weights?

Open mmxuan18 opened this issue 4 years ago • 2 comments

Why does adding this sum make all the class weights the same?

mmxuan18 · Jan 27 '21 08:01

The construct above, up to the highlighted line, selects the weight for each label. For example, if your target label is class 1 (one-hot encoded), those lines pick the weight corresponding to that index. One could simplify the lines

weights = torch.tensor(weights).float()
weights = weights.unsqueeze(0)
weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
weights = weights.sum(1)
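To see what those four lines do, here is a minimal sketch with made-up numbers (3 classes, 2 samples): the repeat broadcasts the class weights to one row per sample, the multiplication by the one-hot labels zeroes out every column except the true class, and the sum over dim 1 just picks out the single surviving entry. Nothing is actually accumulated across classes.

```python
import torch

# Hypothetical per-class weights and two one-hot labels (illustrative values only)
weights = torch.tensor([0.2, 0.3, 0.5])
labels_one_hot = torch.tensor([[0., 1., 0.],   # sample 0 -> class 1
                               [0., 0., 1.]])  # sample 1 -> class 2

# Broadcast the class weights to one row per sample: shape (2, 3)
w = weights.unsqueeze(0).repeat(labels_one_hot.shape[0], 1)
# Zero out every column except the true class
w = w * labels_one_hot
# Summing over classes just selects the one nonzero entry per row
per_sample = w.sum(1)
print(per_sample)  # tensor([0.3000, 0.5000])
```

So each sample ends up with exactly its own class weight; the sum does not average or equalize the weights.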

by simply saying

weights = weights[labels_one_hot.argmax(1)]  # pick the correct weight for each label
# weights = weights[labels.long()]  # this would also do

Here's a snippet you can use to verify:

# gen dummy labels
num_classes = 5
labels = torch.randint(num_classes, (100, ))
labels_one_hot = torch.eye(num_classes)[labels]
# gen dummy weights
weights = torch.randn(num_classes)
# original method
weights0 = weights.unsqueeze(0)
weights0 = weights0.repeat(labels_one_hot.shape[0],1) * labels_one_hot
weights0 = weights0.sum(1)
# method 1
weights1 = weights[labels_one_hot.argmax(1)]
# method 2
weights2 = weights[labels.long()]
print(torch.equal(weights0, weights1), torch.equal(weights0, weights2))
>> True True

mjkvaak · Nov 18 '22 12:11