@delveintodetail Is there something wrong in the predict.py file?

Open delveintodetail opened this issue 4 years ago • 2 comments

@delveintodetail Is there something wrong in the predict.py file?

Originally posted by @li10141110 in https://github.com/fengxinjie/Transformer-OCR/issues/4#issuecomment-607083558

    for epoch in range(10000):
        model.train()
        run_epoch(train_dataloader, model,
                  SimpleLossCompute(model.generator, criterion, model_opt))
        model.eval()
        test_loss = run_epoch(val_dataloader, model,
                              SimpleLossCompute(model.generator, criterion, None))
        print("test_loss", test_loss)
        torch.save(model.state_dict(), 'checkpoint/%08d_%f.pth' % (epoch, test_loss))

the evaluation should not be different from training, but in this implementation, he uses the same method.

delveintodetail, Apr 01 '20 08:04

@delveintodetail It is not clear what you are pointing out.

In predict.py there is a model.eval() call, and the same call appears in train.py:

    for epoch in range(10000):
        model.train()
        run_epoch(train_dataloader, model,
                  SimpleLossCompute(model.generator, criterion, model_opt))
        model.eval()
        test_loss = run_epoch(val_dataloader, model,
                              SimpleLossCompute(model.generator, criterion, None))
        print("test_loss", test_loss)
        torch.save(model.state_dict(), 'checkpoint/%08d_%f.pth' % (epoch, test_loss))

I guess you are saying that evaluation should be different (rather than "should not be different"?) in train.py and predict.py, but that here the two are the same.

Why should they be different?

gussmith, Apr 09 '20 20:04

I guess @delveintodetail is referring to teacher forcing. However, predict.py does not use teacher forcing, so I don't think there is a bug in predict.py.
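For anyone unfamiliar with the term, here is a minimal sketch of the distinction. It is not code from this repository: `toy_next_token`, `teacher_forced_predictions` and `greedy_decode` are hypothetical names, and the "model" is just a lookup table with one deliberate mistake. With teacher forcing, every step sees the ground-truth prefix, so a wrong prediction does not affect later steps. With free-running (greedy) decoding, every step sees the model's own previous output, so the mistake propagates.

```python
from typing import Callable, List

BOS, EOS = "<s>", "</s>"


def toy_next_token(prefix: List[str]) -> str:
    """Hypothetical stand-in for one decoder step.

    Looks only at the last token of the prefix. The entry 'a' -> 'x'
    is a deliberate mistake (the target continues with 'b').
    """
    table = {BOS: "a", "a": "x", "b": "c", "c": EOS, "x": EOS}
    return table.get(prefix[-1], EOS)


def teacher_forced_predictions(step: Callable[[List[str]], str],
                               target: List[str]) -> List[str]:
    """Training-style pass: condition each step on the ground-truth prefix."""
    inputs = [BOS] + target[:-1]  # target shifted right by one position
    return [step(inputs[: t + 1]) for t in range(len(target))]


def greedy_decode(step: Callable[[List[str]], str],
                  max_len: int = 10) -> List[str]:
    """Inference-style pass: condition each step on the model's own output."""
    prefix = [BOS]
    for _ in range(max_len):
        token = step(prefix)
        prefix.append(token)
        if token == EOS:
            break
    return prefix[1:]  # drop the BOS marker


target = ["a", "b", "c", EOS]
forced = teacher_forced_predictions(toy_next_token, target)
free = greedy_decode(toy_next_token)

print("teacher forced:", forced)  # one wrong token, later steps still correct
print("greedy decode: ", free)    # the wrong token changes everything after it
```

A loss computed in the teacher-forced way can therefore look better than what free-running decoding would actually produce, which may be the concern behind the original question.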

Pay20Y, Apr 11 '20 08:04