RuntimeError: weight tensor should be defined either for all 4 classes or no classes but got weight tensor of shape: [3]
When I set train_batch_size=1, I got this error. It happens in:
n_loss = compute_crossentropy_loss(n_scores.to(device), tg.ndata['label'].to(device))
==>
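import numpy as np
import torch
from sklearn.utils import class_weight
def compute_crossentropy_loss(scores: torch.Tensor, labels: torch.Tensor):
    # classes only covers labels present in this batch, so w can be shorter than scores.shape[1]
    w = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(labels.cpu().numpy()),
                                          y=labels.cpu().numpy())
    return torch.nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32).to('cuda:0'))(scores, labels)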
I think it may be because one image contains only three of the four labels. Could you tell me the reason and how to fix it? Thank you so much.
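For reference, here is a minimal NumPy sketch of what happens (the helper name `balanced_weights_like_sklearn` and the labels are made up for illustration). scikit-learn's 'balanced' mode assigns n_samples / (n_classes * count) to each entry of `classes`, and since the function passes `classes=np.unique(labels)`, a class missing from the image gets no weight at all, while the model still outputs 4 score columns.

```python
import numpy as np


def balanced_weights_like_sklearn(labels: np.ndarray) -> np.ndarray:
    """Mimic class_weight='balanced' with classes=np.unique(labels).

    Only classes that actually occur in `labels` get a weight, so the
    result can be shorter than the number of score columns.
    """
    classes, counts = np.unique(labels, return_counts=True)
    return len(labels) / (len(classes) * counts)


# One image whose nodes only use classes 0, 1 and 2 (class 3 is absent).
labels = np.array([0, 0, 1, 2, 2, 2])
num_classes = 4  # the model still produces 4 score columns per node
w = balanced_weights_like_sklearn(labels)
print(w.shape, num_classes)  # prints (3,) 4
```

A 3-element weight paired with 4-column scores is exactly the mismatch reported in the RuntimeError.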
I have the same problem as you, have you fixed this error yet?
Sorry, I couldn't find a solution. And you?
Hi, I fixed this bug by changing some code.
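First, in utils compute_crossentropy_loss, add the following after w is computed and before the return:
    # Pad w so it has one entry per class. The inserted value is arbitrary: a class with
    # no samples in this batch never contributes to the loss with default CrossEntropyLoss settings.
    new_labels = labels.cpu().numpy().tolist()
    if len(set(new_labels)) == 3:  # assumes 4 classes with exactly one missing
        for i in range(4):
            if i not in new_labels:
                w = w.tolist()
                w.insert(i, len(new_labels))
                break
    if len(set(new_labels)) == 1:  # assumes 2 classes with exactly one missing
        for i in range(2):
            if i not in new_labels:
                w = w.tolist()
                w.insert(i, len(new_labels))
                break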
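The patch above handles exactly one missing class. Simply passing classes=np.arange(4) to compute_class_weight is not an option, since scikit-learn's 'balanced' mode rejects classes that do not appear in y. A more general alternative is to compute the balanced weights over all classes yourself. The sketch below is a hypothetical helper (`balanced_weights_all_classes` is not part of the repo or of scikit-learn). It assumes integer labels in [0, num_classes) and gives absent classes a placeholder weight of 1.0.

```python
import numpy as np


def balanced_weights_all_classes(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical helper: one 'balanced' weight per class, present or not.

    Classes that occur get n_samples / (n_present_classes * count), the same
    values as sklearn's class_weight='balanced' over the classes in `labels`.
    Absent classes get a placeholder of 1.0; with no targets of that class,
    CrossEntropyLoss (default settings, no label smoothing) never reads it.
    """
    labels = np.asarray(labels, dtype=np.int64)
    if labels.size == 0 or labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("labels must be non-empty and lie in [0, num_classes)")
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    present = counts > 0
    weights = np.ones(num_classes, dtype=np.float64)
    weights[present] = labels.size / (present.sum() * counts[present])
    return weights


# One image where class 3 never appears among the node labels.
labels = np.array([0, 0, 1, 2, 2, 2])
print(balanced_weights_all_classes(labels, num_classes=4))  # length 4, not 3
```

Inside compute_crossentropy_loss this could replace the compute_class_weight call, e.g. w = balanced_weights_all_classes(labels.cpu().numpy(), scores.shape[1]), assuming scores has shape (N, C). The weight tensor could then be built with torch.tensor(w, dtype=torch.float32, device=scores.device) instead of the hard-coded 'cuda:0'.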
Second, in utils compute_auc_mc, change one line:
# from
labels = F.one_hot(labels).cpu().numpy()
# to
labels = F.one_hot(labels, num_classes=2).cpu().numpy()
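Why the second change matters: when num_classes is not given, F.one_hot infers it as one more than the largest label in the input. With train_batch_size=1 a batch can contain only class 0, so the one-hot matrix gets a single column and presumably no longer matches the shape the rest of compute_auc_mc expects. The function below is a hypothetical NumPy stand-in that mimics that inference rule; it is not the PyTorch function itself.

```python
import numpy as np


def one_hot(labels: np.ndarray, num_classes: int = -1) -> np.ndarray:
    """NumPy stand-in for torch.nn.functional.one_hot.

    Like the PyTorch function, num_classes=-1 means "infer it as
    labels.max() + 1", which shrinks the output when high classes are absent.
    """
    if num_classes == -1:
        num_classes = int(labels.max()) + 1
    return np.eye(num_classes, dtype=np.int64)[labels]


batch = np.array([0, 0, 0])  # every label in this batch happens to be 0
print(one_hot(batch).shape)                 # prints (3, 1)
print(one_hot(batch, num_classes=2).shape)  # prints (3, 2)
```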
I hope this helps.
Thank you so much for your help! I really appreciate it.