coral-pytorch
How can I add the importance_weights of each class to the corn loss?
Hi,
Thanks for sharing the code.
I noticed that the CORAL loss has an importance_weights argument. Could I add importance_weights for each class to the CORN loss as well?
Many thanks, Tien-Hong
Yes, they could be added. We omitted them for simplicity in the CORN paper.
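Written out with the code's 0-based task index, a weighted CORN loss might look as follows. This is a sketch, not the paper's notation: $K$ is the number of classes, $z_{ij}$ is the logit of example $i$ for binary task $j$, $\sigma$ is the logistic sigmoid, and $\lambda_j$ is an assumed per-task importance weight. Each weight scales one task's summed loss, while the normalizer stays the unweighted number of training examples (as in `return losses/num_examples`). Setting every $\lambda_j = 1$ recovers the unweighted loss.

$$
\mathcal{L} = -\frac{1}{\sum_{j=0}^{K-2} |S_j|} \sum_{j=0}^{K-2} \lambda_j \sum_{i \in S_j} \Big[ \mathbb{1}\{y_i > j\} \log \sigma(z_{ij}) + \mathbb{1}\{y_i \le j\} \log\big(1 - \sigma(z_{ij})\big) \Big],
\qquad S_j = \{\, i : y_i \ge j \,\}
$$

$$
\log\big(1 - \sigma(z)\big) = \log \sigma(z) - z
$$

The second identity is why the code writes the negative-class term as `F.logsigmoid(pred) - pred`. For integer labels, `y_train > i-1` is the same condition as $y_i \ge j$ with $j = i$.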
Thanks for your kind reply.
I haven't run the code yet. If importance_weights has shape (#NUM_CLASS, 1), is the following modified code correct? The change is right below the commented-out line.
import torch
import torch.nn.functional as F

def corn_loss(logits, y_train, num_classes, importance_weights):
    # Build the conditional training subset for each of the num_classes - 1 binary tasks
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.
    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        # Skip tasks whose conditional subset is empty in this batch
        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        # Summed binary cross-entropy for this task
        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))

        #losses += loss
        # Modified: scale this task's loss by its importance weight
        losses += importance_weights[task_index] * loss

    return losses/num_examples
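Because log(1 - sigmoid(z)) = logsigmoid(z) - z, the `loss` line above is the summed binary cross-entropy on logits for one task, so each `importance_weights[task_index]` rescales one task's total. A quick numerical check of that identity, assuming only PyTorch (the random tensors are arbitrary placeholders):

```python
import torch
import torch.nn.functional as F

# Arbitrary logits and binary targets standing in for one CORN task's subset.
torch.manual_seed(0)
pred = torch.randn(8)
train_labels = torch.randint(0, 2, (8,))

# The expression used inside corn_loss, summed over the task's examples.
loss_logsigmoid = -torch.sum(F.logsigmoid(pred) * train_labels
                             + (F.logsigmoid(pred) - pred) * (1 - train_labels))

# The same quantity via PyTorch's built-in binary cross-entropy on logits.
loss_bce = F.binary_cross_entropy_with_logits(pred, train_labels.float(), reduction="sum")

print(torch.allclose(loss_logsigmoid, loss_bce))  # should print True
```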
Yes, this looks correct to me. You could also give importance_weights a default value of None, so that the function behaves as before when no importance weights are specified:
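import torch
import torch.nn.functional as F

def corn_loss(logits, y_train, num_classes, importance_weights=None):
    # Build the conditional training subset for each of the num_classes - 1 binary tasks
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.

    # Without explicit weights every task gets weight 1, i.e. the original CORN loss
    if importance_weights is None:
        importance_weights = torch.ones(len(sets), device=logits.device)

    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        # Skip tasks whose conditional subset is empty in this batch
        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        # Summed binary cross-entropy for this task
        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))

        #losses += loss
        # Scale this task's loss by its importance weight
        losses += importance_weights[task_index] * loss

    return losses/num_examples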
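A note on the weights' shape: `task_index` runs from 0 to `num_classes - 2`, so the loop reads one weight per binary task (`num_classes - 1` values), which is what the default `torch.ones(len(sets))` provides. With a `(num_classes, 1)` tensor, the last entry is never used and each product has shape `(1,)` instead of being a scalar. The following is a hypothetical loop-free variant (`corn_loss_vectorized` is a name made up here, not part of coral-pytorch). It builds the same conditional subsets as a boolean mask and uses `F.binary_cross_entropy_with_logits`, assuming `y_train` holds integer labels on the same device as `logits`:

```python
import torch
import torch.nn.functional as F


def corn_loss_vectorized(logits, y_train, num_classes, importance_weights=None):
    """Hypothetical loop-free sketch of the weighted CORN loss.

    logits: (N, num_classes - 1) tensor, one column per binary task.
    y_train: (N,) integer class labels in [0, num_classes - 1].
    importance_weights: optional, one weight per task (num_classes - 1 values).
    """
    num_tasks = num_classes - 1
    task_ids = torch.arange(num_tasks, device=logits.device)

    # Same subsets as the loop: task j uses examples with y > j - 1 (i.e. y >= j)
    # and the binary target y > j.
    label_mask = y_train.unsqueeze(1) > (task_ids - 1)  # (N, num_tasks), bool
    labels = (y_train.unsqueeze(1) > task_ids).to(logits.dtype)

    # Per-element binary cross-entropy on logits, zeroed outside each task's subset.
    bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    bce = bce * label_mask

    if importance_weights is None:
        importance_weights = torch.ones(num_tasks, device=logits.device, dtype=logits.dtype)
    # Flatten so that (num_tasks,) and (num_tasks, 1) weights both line up with the tasks.
    weights = torch.as_tensor(importance_weights, device=logits.device,
                              dtype=logits.dtype).reshape(-1)

    # Weight each task's summed loss, then normalize by the unweighted example count.
    return (bce.sum(dim=0) * weights).sum() / label_mask.sum()


logits = torch.zeros(4, 3)      # 4 examples, num_classes = 4, so 3 binary tasks
y = torch.tensor([0, 1, 2, 3])
print(corn_loss_vectorized(logits, y, num_classes=4))                          # tensor(0.6931), i.e. log(2)
print(corn_loss_vectorized(logits, y, 4, importance_weights=[1.0, 0.0, 0.0]))  # tensor(0.3081), i.e. 4 * log(2) / 9
```

The mask replaces boolean indexing inside a Python loop. It is intended to match the loop version up to floating-point error. A weight tensor with `num_classes` entries raises a broadcasting error here instead of silently ignoring the last weight.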