d2l-en
Errors in train_ch3 in the TensorFlow version of softmax-regression-scratch.ipynb
Here is the function that raises an error when I call train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater) in softmax-regression-scratch.ipynb:
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Animate and train a model (defined in Chapter 3)."""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        # NameError: name 'train_epoch_ch3' is not defined
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc
Here are the error details:
NameError Traceback (most recent call last)
NameError: name 'train_epoch_ch3' is not defined
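For anyone reproducing this, the NameError is consistent with Python resolving train_epoch_ch3 only when train_ch3 is called, so it would appear if the cell defining train_epoch_ch3 had not been run (or is missing) before the call. The sketch below illustrates that behavior with hypothetical stand-in functions (train_ch3_stub and train_epoch_stub are not part of d2l); it is an illustration of the mechanism, not a diagnosis of the notebook itself.

```python
# Hypothetical stand-ins; these names are NOT part of d2l.
def train_ch3_stub(num_epochs):
    """Calls a helper that is looked up by name at call time."""
    return [train_epoch_stub(epoch) for epoch in range(num_epochs)]


# Calling before the helper exists reproduces the same kind of error.
try:
    train_ch3_stub(2)
    raised = False
except NameError as err:
    raised = True
    message = str(err)
    print("NameError:", message)


# Defining the helper (i.e. running its cell) makes the same call succeed.
def train_epoch_stub(epoch):
    """Returns a made-up (loss, accuracy) pair for illustration only."""
    return (1.0 / (epoch + 1), 0.5)


results = train_ch3_stub(2)
print(results)
```

If this is the cause, running the cell that defines train_epoch_ch3 before calling train_ch3 should remove the error.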
Here is another error, raised by the last cell of softmax-regression-concise.ipynb, d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer):
Here are the error details:
AssertionError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/d2l/tensorflow.py in train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    318         animator.add(epoch + 1, train_metrics + (test_acc,))
    319     train_loss, train_acc = train_metrics
--> 320     assert train_loss < 0.5, train_loss
    321     assert train_acc <= 1 and train_acc > 0.7, train_acc
    322     assert test_acc <= 1 and test_acc > 0.7, test_acc
AssertionError: 0.5710129758199056
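The value after AssertionError is the assertion's message, i.e. the final train_loss, which at about 0.571 does not satisfy train_loss < 0.5. As a rough sketch, the same three thresholds can be checked without aborting, so that all metrics are reported at once. The helper name check_metrics is hypothetical, and the two accuracy values passed in below are made up for illustration because the traceback does not show them.

```python
def check_metrics(train_loss, train_acc, test_acc):
    """Return a list of failed checks, using the thresholds from train_ch3."""
    failures = []
    if not train_loss < 0.5:
        failures.append(f"train_loss {train_loss:.4f} is not below 0.5")
    if not 0.7 < train_acc <= 1:
        failures.append(f"train_acc {train_acc:.4f} is not in (0.7, 1]")
    if not 0.7 < test_acc <= 1:
        failures.append(f"test_acc {test_acc:.4f} is not in (0.7, 1]")
    return failures


# train_loss is taken from the traceback; both accuracies are assumed values.
failures = check_metrics(0.5710129758199056, 0.8, 0.8)
for failure in failures:
    print(failure)
```

A loss slightly above the threshold may simply reflect run-to-run variation or too few epochs, though that is a guess rather than something the traceback confirms.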
ImportError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/IPython/core/formatters.py in __call__(self, obj)
    339                 pass
    340             else:
--> 341                 return printer(obj)
    342             # Finally look for special method names
    343             method = get_real_method(obj, self.print_method)
12 frames
/usr/local/lib/python3.10/dist-packages/matplotlib/backends/backend_svg.py in
ImportError: cannot import name '_check_savefig_extra_args' from 'matplotlib.backend_bases' (/usr/local/lib/python3.10/dist-packages/matplotlib/backend_bases.py)
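In this ImportError, backend_svg.py fails to import a name from backend_bases.py inside the same matplotlib directory, which would be consistent with a partially upgraded or mixed matplotlib installation; that is an assumption, not something the traceback proves. A minimal sketch for recording the installed versions when reporting the problem, using only the standard library (the package list is an assumption about what is relevant here):

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Assumed list of packages that may be involved in the rendering path.
report = installed_versions(["matplotlib", "matplotlib-inline", "d2l"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

Restarting the runtime after reinstalling matplotlib is a common way to rule out stale modules, although whether that resolves this particular case is not confirmed here.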