Dive-into-DL-PyTorch
3.6.6 Evaluating the accuracy of net on data_iter
I copied the code as-is and got the error below. I don't know how to fix it; any help is appreciated.
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

print(evaluate_accuracy(test_iter, net))
RuntimeError Traceback (most recent call last)
RuntimeError: The size of tensor a (10) must match the size of tensor b (3) at non-singleton dimension 1

Version info
pytorch:
torchvision:
torchtext:
...
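The traceback is cut off before the failing line, so the cause cannot be read off directly. The message does say that two tensors with sizes 10 and 3 in dimension 1 were combined elementwise, which may point at either the `==` comparison or an operation inside `net`. One way to narrow it down is to check shapes before comparing. The following is only a sketch: the `net`, `W`, `X`, and `y` defined here are hypothetical stand-ins, not the tutorial's model or the Fashion-MNIST iterator, and the shape check assumes class scores of shape (batch, num_classes) and integer labels of shape (batch,).

```python
import torch


def evaluate_accuracy(data_iter, net):
    """Fraction of correctly classified examples, with an explicit shape check."""
    acc_sum, n = 0.0, 0
    with torch.no_grad():  # evaluation only, no gradients needed
        for X, y in data_iter:
            y_hat = net(X)
            # Assumed layout: scores (batch, num_classes), labels (batch,).
            if y_hat.dim() != 2 or y.dim() != 1 or y_hat.shape[0] != y.shape[0]:
                raise ValueError(
                    f"unexpected shapes: net(X) {tuple(y_hat.shape)}, "
                    f"y {tuple(y.shape)}"
                )
            acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n


# Hypothetical stand-in model: a fixed linear map that always scores class 3 highest
# for an all-ones input (only W[0, 3] is non-zero).
W = torch.zeros(784, 10)
W[0, 3] = 1.0


def net(X):
    # Flatten each image to a row vector, then apply the linear map.
    return X.view(X.shape[0], -1) @ W


# Hypothetical stand-in batch: four 1x28x28 "images" and their labels.
X = torch.ones(4, 1, 28, 28)
y = torch.tensor([3, 3, 0, 3])

acc = evaluate_accuracy([(X, y)], net)
print(acc)
```

If the check raises, the printed shapes show whether the labels or the model output have an unexpected layout. If it passes and the original error persists, the mismatch is more likely inside `net` itself, for example in a weight, bias, or softmax step.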