doc2graph

RuntimeError: weight tensor should be defined either for all 4 classes or no classes but got weight tensor of shape: [3]

Open · HonLZL opened this issue 1 year ago · 4 comments

When I set train_batch_size=1, I got this error. It happens at this line:

    n_loss = compute_crossentropy_loss(n_scores.to(device), tg.ndata['label'].to(device))

which calls this function in utils:

    import numpy as np
    import torch
    from sklearn.utils import class_weight

    def compute_crossentropy_loss(scores: torch.Tensor, labels: torch.Tensor):
        # 'balanced' weights are computed only for the labels present in this batch
        w = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(labels.cpu().numpy()),
                                              y=labels.cpu().numpy())
        return torch.nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32).to('cuda:0'))(scores, labels)

RuntimeError: weight tensor should be defined either for all 4 classes or no classes but got weight tensor of shape: [3]

I suspect the error occurs because a single image contains only three of the four labels. Could you explain the cause and how to fix it? Thank you so much.
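The suspicion above matches how the weights are built: `classes=np.unique(...)` lists only the labels that actually occur in the current graph, and `compute_class_weight` returns one weight per entry of `classes`, so a graph missing one node label yields 3 weights for a 4-class `CrossEntropyLoss`. The sketch below reimplements the 'balanced' formula documented by scikit-learn (n_samples / (n_classes * count)) in plain NumPy instead of calling scikit-learn; the helper name and the example labels are hypothetical.

```python
import numpy as np

def balanced_weights_present_only(labels):
    """Hypothetical helper mimicking
    compute_class_weight('balanced', classes=np.unique(y), y=y)."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    # One weight per class that occurs in this graph, not per model class
    weights = labels.size / (len(classes) * counts)
    return classes, weights

num_classes = 4  # assumption: the node classifier scores 4 classes, per the error message
labels = np.array([0, 0, 1, 3, 3, 3])  # hypothetical graph with no node labelled 2
classes, weights = balanced_weights_present_only(labels)
print(classes, weights.shape)  # [0 1 3] (3,)
print(weights.shape[0] == num_classes)  # False: CrossEntropyLoss rejects this weight
```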

HonLZL, Jan 25 '24

I have the same problem as you. Have you fixed this error yet?

TranQuocDat0405, Feb 07 '24

> I have the same problem as you, have you fixed this error yet?

Sorry, I couldn't find a solution. And you?

HonLZL, Feb 29 '24

> I have the same problem as you, have you fixed this error yet?

Hi, I fixed this bug by changing two functions in utils.

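First, in compute_crossentropy_loss in utils, add the following after w is computed and before the loss is returned:

    new_labels = labels.cpu().numpy().tolist()

    # 4-class case (node labels): if exactly one class is absent from this graph,
    # insert a weight for it at its index. The value never affects the loss,
    # because no target belongs to that class. Two absent classes are not handled.
    if len(set(new_labels)) == 3:
        for i in range(4):
            if i not in new_labels:
                w = w.tolist()
                w.insert(i, len(new_labels))
                break

    # 2-class case (presumably edge labels): if only one class occurs,
    # insert a weight for the other one.
    if len(set(new_labels)) == 1:
        for i in range(2):
            if i not in new_labels:
                w = w.tolist()
                w.insert(i, len(new_labels))
                break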
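A more general variant of the first change is to pad the 'balanced' weights to the model's full class count in one step, so any number of absent classes is covered. This is a minimal sketch, not the repository's code: `balanced_class_weights` and its `missing_weight` parameter are hypothetical names, and the padding value is arbitrary because, with PyTorch's default `CrossEntropyLoss` settings, the weighted loss only uses the weights of classes that appear as targets.

```python
import numpy as np

def balanced_class_weights(labels, num_classes, missing_weight=1.0):
    """Hypothetical helper: 'balanced' class weights padded to num_classes.

    Present classes get n_samples / (n_present_classes * count), the same
    values compute_class_weight('balanced', classes=np.unique(y), y=y) gives.
    Absent classes get missing_weight, which no target ever indexes.
    """
    labels = np.asarray(labels, dtype=np.int64).ravel()
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError("labels must lie in [0, num_classes)")
    # Count every class, including ones with zero occurrences
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    present = counts > 0
    weights = np.full(num_classes, missing_weight, dtype=np.float64)
    weights[present] = labels.size / (present.sum() * counts[present])
    return weights

# Hypothetical node labels with class 2 missing, and edge labels with only class 0
print(balanced_class_weights([0, 0, 1, 3, 3, 3], num_classes=4))
print(balanced_class_weights([0, 0, 0], num_classes=2))
```

Inside compute_crossentropy_loss, the call could then be something like `w = balanced_class_weights(labels.cpu().numpy(), num_classes=scores.shape[1])`, assuming `scores` is shaped `[num_items, num_classes]`. Taking the class count from the scores avoids guessing between the 4-class and 2-class cases from the number of unique labels.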

Second, in compute_auc_mc in utils, change this line:

    # from
    labels = F.one_hot(labels).cpu().numpy()

    # to
    labels = F.one_hot(labels, num_classes=2).cpu().numpy()
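The second change matters for a related reason: when `num_classes` is left at its default, `F.one_hot` infers the number of columns as one more than the largest label in the tensor, so a batch whose edges are all labelled 0 produces a single column instead of two. The sketch below only imitates that shape behaviour in NumPy (it is not PyTorch itself); `one_hot_like_torch` and the all-zero batch are hypothetical.

```python
import numpy as np

def one_hot_like_torch(labels, num_classes=-1):
    """Hypothetical NumPy stand-in for the shape behaviour of F.one_hot.

    With num_classes=-1 the width is inferred as max(label) + 1.
    """
    labels = np.asarray(labels, dtype=np.int64)
    width = int(labels.max()) + 1 if num_classes == -1 else num_classes
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(width, dtype=np.int64)[labels]

edge_labels = np.array([0, 0, 0])  # hypothetical batch with no positive edges
print(one_hot_like_torch(edge_labels).shape)                 # (3, 1)
print(one_hot_like_torch(edge_labels, num_classes=2).shape)  # (3, 2)
```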

I hope this helps.

HonLZL, Mar 02 '24


Thank you so much for your help! I really appreciate it.

TranQuocDat0405, Mar 06 '24