pytorch-classification-uncertainty

Function `one_hot_embedding` may be missing `.to(device)`

Open LSTM-Kirigaya opened this issue 1 year ago • 0 comments

Hi guys, nice work! However, it looks like the `torch.eye` tensor in `one_hot_embedding` is never moved to the device. Maybe a `.to(device)` call was forgotten?

I modified `one_hot_embedding` in `helpers.py` as follows:

def one_hot_embedding(labels, num_classes=10):
    # Convert to One Hot Encoding
    device = get_device()
    y = torch.eye(num_classes).to(device)
    return y[labels]
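
A possible variant, as a minimal sketch: instead of calling `get_device()`, the identity matrix could be created directly on whatever device `labels` already lives on. This assumes `labels` is an integer tensor; the example inputs below are made up for illustration and are not from the repository.

```python
import torch


def one_hot_embedding(labels, num_classes=10):
    # Convert to One Hot Encoding.
    # Assumption: `labels` is an integer tensor. The identity matrix is
    # created on the same device as `labels`, so no extra transfer is needed.
    y = torch.eye(num_classes, device=labels.device)
    return y[labels]


# Illustrative usage with made-up labels (runs on CPU as written).
labels = torch.tensor([0, 2, 1])
one_hot = one_hot_embedding(labels, num_classes=3)
print(one_hot)
print(one_hot.device == labels.device)
```

`torch.nn.functional.one_hot(labels, num_classes)` is another option, though it returns an integer tensor, so a `.float()` cast would be needed to match the output above.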

LSTM-Kirigaya, Jan 01 '24 17:01