pytorch-classification-uncertainty
Function one_hot_embedding may be missing `.to(device)`
Hi guys, nice work! However, it looks like you may have forgotten to move the torch.eye tensor to the device in one_hot_embedding?
Modified one_hot_embedding in helpers.py as:
def one_hot_embedding(labels, num_classes=10):
    # Convert to One Hot Encoding
    device = get_device()
    y = torch.eye(num_classes).to(device)
    return y[labels]
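A possible variant, in case it is useful: the sketch below takes the device from `labels` itself instead of calling `get_device()`, so the identity matrix and the index tensor always live on the same device. It assumes `labels` is an integer tensor that has already been moved to the target device, which may not match how the repo calls this function. The sample labels and the comparison against `torch.nn.functional.one_hot` are only there for illustration.

```python
import torch
import torch.nn.functional as F


def one_hot_embedding(labels, num_classes=10):
    # Assumption: `labels` is an integer tensor already on the target device.
    # Creating the identity matrix directly on that device avoids indexing
    # a CPU tensor with GPU indices.
    y = torch.eye(num_classes, device=labels.device)
    return y[labels]


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels = torch.tensor([0, 2, 1], device=device)  # illustrative labels

encoded = one_hot_embedding(labels, num_classes=3)

# F.one_hot returns an integer tensor, so cast to float before comparing.
reference = F.one_hot(labels, num_classes=3).float()
```

Passing `device=` to `torch.eye` also skips the extra CPU allocation and copy that `.to(device)` performs, although for a small matrix the difference is negligible.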