
How can I add the importance_weights of each class to the corn loss?

Open teinhonglo opened this issue 1 year ago • 3 comments

Hi,

Thanks for sharing the code.

I noticed that the coral loss accepts an importance_weights argument. Could I add importance_weights for each class to the corn loss as well?

Many thanks, Tien-Hong

teinhonglo avatar Aug 04 '23 08:08 teinhonglo

Yes, they could be added. We omitted them for simplicity in the CORN paper.

rasbt avatar Aug 04 '23 11:08 rasbt

Thanks for your kind reply.

I haven't run the code yet. If the shape of importance_weights is (NUM_CLASSES-1, 1), i.e., one weight per binary task, is the following modified code (see the # comments) correct?

import torch
import torch.nn.functional as F


def corn_loss(logits, y_train, num_classes, importance_weights):
    # Build the conditional training subsets: one binary task per threshold.
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.
    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        # Numerically stable binary cross entropy computed from the logits.
        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))

        # original: losses += loss
        # modified: weight each task's loss by its importance weight
        losses += importance_weights[task_index] * loss

    return losses/num_examples
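
For illustration, a minimal hypothetical call might look like this (shapes and weight values are made up, not from the original post): 4 ordinal classes give num_classes-1 = 3 binary tasks, so importance_weights needs 3 entries.

import torch

num_classes = 4
logits = torch.randn(8, num_classes-1)              # model outputs for 8 examples
y_train = torch.randint(0, num_classes, (8,))       # ordinal labels in {0, ..., 3}
importance_weights = torch.tensor([1.0, 2.0, 0.5])  # hypothetical per-task weights

loss = corn_loss(logits, y_train, num_classes, importance_weights)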

teinhonglo avatar Aug 04 '23 14:08 teinhonglo

Yes, this looks correct to me. You can also add a default argument so that it behaves as before if someone doesn't specify the importance weights:

import torch
import torch.nn.functional as F


def corn_loss(logits, y_train, num_classes, importance_weights=None):
    # Build the conditional training subsets: one binary task per threshold.
    sets = []
    for i in range(num_classes-1):
        label_mask = y_train > i-1
        label_tensor = (y_train[label_mask] > i).to(torch.int64)
        sets.append((label_mask, label_tensor))

    num_examples = 0
    losses = 0.

    # Uniform weights by default, so the behavior matches the unweighted loss.
    if importance_weights is None:
        importance_weights = torch.ones(len(sets), device=logits.device)

    for task_index, s in enumerate(sets):
        train_examples = s[0]
        train_labels = s[1]

        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[train_examples, task_index]

        # Numerically stable binary cross entropy computed from the logits.
        loss = -torch.sum(F.logsigmoid(pred)*train_labels
                          + (F.logsigmoid(pred) - pred)*(1-train_labels))

        # Weight each task's loss by its importance weight.
        losses += importance_weights[task_index] * loss

    return losses/num_examples
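
As a quick sanity check (a sketch, not from the original thread, with made-up shapes), omitting the argument should reproduce the unweighted loss exactly:

import torch

num_classes = 4
logits = torch.randn(8, num_classes-1)
y_train = torch.randint(0, num_classes, (8,))

# Omitting the argument should match passing uniform weights explicitly.
loss_default = corn_loss(logits, y_train, num_classes)
loss_uniform = corn_loss(logits, y_train, num_classes,
                         importance_weights=torch.ones(num_classes-1))
assert torch.isclose(loss_default, loss_uniform)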

rasbt avatar Aug 07 '23 22:08 rasbt