
How does EasyTrain.evaluateDataset() work?

Open • MohamedLEGH opened this issue • 2 comments

Description

When I run EasyTrain.fit(trainer, numEpochs, trainingSet, null), the metrics are printed to standard output like this:

Training:    100% |████████████████████████████████████████| Accuracy: 1,00, SigmoidBinaryCrossEntropyLoss: 121,25
Validating:  100% |████████████████████████████████████████|

But when I run EasyTrain.evaluateDataset(trainer, testSet), nothing is printed to standard output, and the method returns void. How can I check the results of the evaluation?

My code

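import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;

TrainingConfig config = new DefaultTrainingConfig(loss)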
    .optOptimizer(sgd) // Optimizer
    .optDevices(manager.getEngine().getDevices(0)) // CPU
    .addEvaluator(new Accuracy()) // Model Accuracy
    .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);

trainer.initialize(new Shape(batchSize, 57)); 
Metrics metrics = new Metrics();
trainer.setMetrics(metrics);

int numEpochs = 1;

EasyTrain.fit(trainer, numEpochs, trainingSet, null);
EasyTrain.evaluateDataset(trainer, testSet);

MohamedLEGH, Apr 23 '24 01:04

evaluateDataset() is run automatically by EasyTrain.fit() at the end of each epoch. It is designed to collect metrics so you can see how the validation evaluation changes over training. You can view those results using trainer.getTrainingResult().

zachgk, Apr 23 '24 16:04
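To illustrate the answer above, here is a simplified, hypothetical model in plain Java (no DJL classes) of why calling an evaluation step on its own appears to do nothing: it returns void, prints nothing, and only updates accumulated metric state, which has to be read back afterwards. The names AccuracyAccumulator, VALIDATE, and this evaluateDataset are illustrative assumptions, not DJL's API. Skipping a null dataset is likewise an assumed analogue of passing null as the validation set.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class EvaluateDatasetSketch {

    // Hypothetical stand-in for an evaluator: it keeps running totals per key
    // instead of returning a value from each evaluation call.
    static final class AccuracyAccumulator {
        private final Map<String, int[]> totals = new HashMap<>(); // key -> {correct, seen}

        void reset(String key) {
            totals.put(key, new int[2]);
        }

        void update(String key, int label, int prediction) {
            int[] t = totals.computeIfAbsent(key, k -> new int[2]);
            if (label == prediction) {
                t[0]++;
            }
            t[1]++;
        }

        // Returns NaN when nothing has been accumulated under this key yet.
        double get(String key) {
            int[] t = totals.get(key);
            return (t == null || t[1] == 0) ? Double.NaN : (double) t[0] / t[1];
        }
    }

    static final String VALIDATE = "validate";

    // Hypothetical analogue of evaluateDataset: returns void, prints nothing,
    // and only updates the accumulator. A null dataset is skipped.
    static void evaluateDataset(AccuracyAccumulator acc, List<int[]> dataset) {
        if (dataset == null) {
            return;
        }
        for (int[] pair : dataset) { // pair = {label, prediction}
            acc.update(VALIDATE, pair[0], pair[1]);
        }
    }

    public static void main(String[] args) {
        AccuracyAccumulator acc = new AccuracyAccumulator();
        acc.reset(VALIDATE);

        List<int[]> testSet = List.of(
                new int[] {1, 1}, new int[] {0, 1}, new int[] {0, 0}, new int[] {1, 1});
        evaluateDataset(acc, testSet); // nothing printed, nothing returned

        // The result has to be read back from the accumulator afterwards.
        System.out.println("validate accuracy: " + acc.get(VALIDATE));
    }
}
```

Running main prints validate accuracy: 0.75 (3 of the 4 label/prediction pairs match). In DJL itself, per the answer above, the collected results are read through trainer.getTrainingResult() rather than from a return value.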

If the method is not used outside of EasyTrain.fit(), maybe it should be private or protected?

MohamedLEGH, Apr 24 '24 02:04