
3.6.6 Evaluating the accuracy of net on data_iter

Open CKing111 opened this issue 3 years ago • 0 comments

I copied the code exactly as given and got the error below. I don't know how to fix it; any help would be appreciated.

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

print(evaluate_accuracy(test_iter, net))


RuntimeError                              Traceback (most recent call last)
      6         n += y.shape[0]
      7     return acc_sum / n
----> 8 print(evaluate_accuracy(test_iter, net))

in evaluate_accuracy(data_iter, net)
      3     acc_sum, n = 0.0, 0
      4     for X, y in data_iter:
----> 5         acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
      6         n += y.shape[0]
      7     return acc_sum / n

in net(X)
      1 def net(X):
----> 2     return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

RuntimeError: The size of tensor a (10) must match the size of tensor b (3) at non-singleton dimension 1

Version info
pytorch:
torchvision:
torchtext:
...
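For anyone hitting the same message: this wording is what elementwise broadcasting typically produces when two shapes cannot be aligned. If the traceback is complete, the last frame is the line in net, which would point at the `+ b` step rather than at softmax, so one plausible reading is that `b` has 3 entries while `torch.mm(...)` yields 10 columns (for example, `b` or `num_outputs` redefined in an earlier notebook cell). That is a guess, not something the post confirms. The sketch below is a plain-Python model of the broadcasting rule, not PyTorch's implementation; `broadcast_shape` and the batch size of 256 are hypothetical names and values chosen for illustration.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b, or raise ValueError.

    Simplified model of the usual rule: shapes are aligned from the last
    dimension, and each aligned pair must be equal or contain a 1.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have ndim entries.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x != y and x != 1 and y != 1:
            raise ValueError(
                f"size {x} does not match size {y} at non-singleton dimension {dim}"
            )
        # A size-1 entry stretches to match the other side.
        out.append(y if x == 1 else x)
    return tuple(out)


# Hypothetical shape: a batch of 256 rows with 10 class scores each.
scores_shape = (256, 10)

print(broadcast_shape(scores_shape, (10,)))    # a bias with 10 entries broadcasts
print(broadcast_shape(scores_shape, (1, 10)))  # so does a (1, 10) bias

try:
    broadcast_shape(scores_shape, (3,))        # a bias with 3 entries, as in the error
except ValueError as err:
    print(err)
```

In the notebook itself, printing `W.shape` and `b.shape` right before calling `net` should show whether the two agree on the number of outputs. If they do not, re-running the cell that initializes `W` and `b` is a reasonable first thing to try.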

CKing111, Nov 28 '20 05:11