How does the function EasyTrain.evaluateDataset() work?
Description
When I run EasyTrain.fit(trainer, numEpochs, trainingSet, null);
I get the metrics in the standard output like this:
Training: 100% |████████████████████████████████████████| Accuracy: 1,00, SigmoidBinaryCrossEntropyLoss: 121,25
Validating: 100% |████████████████████████████████████████|
But when I run EasyTrain.evaluateDataset(trainer, testSet);
nothing is printed on the standard output and the function returns void.
How can I check the results of the evaluation?
My code
TrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer
        .optDevices(manager.getEngine().getDevices(0)) // CPU
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging
Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(batchSize, 57));
Metrics metrics = new Metrics();
trainer.setMetrics(metrics);
int numEpochs = 1;
EasyTrain.fit(trainer, numEpochs, trainingSet, null);
EasyTrain.evaluateDataset(trainer, testSet);
evaluateDataset() is run automatically by EasyTrain.fit() at the end of each epoch. It is designed to collect the validation metrics so you can see how they change from epoch to epoch. You can view those results using trainer.getTrainingResult().
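To make that concrete, here is a minimal sketch of pulling the validation numbers out of the result. It assumes the test set is passed as the fourth argument of EasyTrain.fit() instead of null, that trainer.getTrainingResult().getEvaluations() returns a Map<String, Float>, and that validation entries are keyed with a "validate_" prefix (for example "validate_Accuracy"). The prefix and the ValidationReport helper are assumptions for illustration, not something confirmed in this thread, and the block uses only the JDK so it can be tried without DJL on the classpath.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper, not part of DJL.
final class ValidationReport {

    // Assumed key prefix for validation entries in the evaluations map.
    private static final String VALIDATE_PREFIX = "validate_";

    private ValidationReport() {}

    /** Returns only the validation entries, prefix stripped, sorted by name. */
    static Map<String, Float> validationOnly(Map<String, Float> evaluations) {
        Map<String, Float> result = new TreeMap<>();
        for (Map.Entry<String, Float> entry : evaluations.entrySet()) {
            if (entry.getKey().startsWith(VALIDATE_PREFIX)) {
                String name = entry.getKey().substring(VALIDATE_PREFIX.length());
                result.put(name, entry.getValue());
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Made-up values standing in for the map that, under the assumptions
        // above, would come from trainer.getTrainingResult().getEvaluations().
        Map<String, Float> evaluations = new LinkedHashMap<>();
        evaluations.put("train_Accuracy", 1.00f);
        evaluations.put("validate_Accuracy", 0.93f);
        evaluations.put("validate_loss", 0.21f);

        validationOnly(evaluations)
                .forEach((name, value) -> System.out.println(name + " = " + value));
    }
}
```

With DJL available, the intended call would be along the lines of ValidationReport.validationOnly(trainer.getTrainingResult().getEvaluations()) after EasyTrain.fit(trainer, numEpochs, trainingSet, testSet) has finished.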
If the function is not used outside of EasyTrain.fit(), maybe the method should be private or protected?