
Early stopping configuration

Open jSaso opened this issue 4 years ago • 3 comments


Early stopping configuration: Specifies the various configuration options for running training with early stopping.

  • Early stopping model saver: how the model is saved (to disk, to memory, etc.). Only the last best model is kept.
  • Termination conditions:
     1. Iteration termination condition: the maximum number of epochs before training terminates.
     2. Score improvement termination condition: terminate training if the best model score does not improve for N epochs.
     3. Best expected score: terminate training once an expected score is achieved.
     4. Time termination condition: terminate training after a certain amount of time.
     5. Other termination conditions, where they make sense.
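The termination conditions listed above could be modelled as small, composable objects that are checked after every epoch. This is a minimal sketch in plain Java with no DJL dependency. EpochState, TerminationCondition and Conditions are hypothetical names, not part of DJL's API, and the sketch assumes a lower score (such as a validation loss) is better. "Other termination conditions" (item 5) would be further implementations of the same interface.

```java
import java.time.Duration;
import java.util.List;

/** Hypothetical per-epoch snapshot; assumes a lower score (e.g. validation loss) is better. */
final class EpochState {
    final int epoch;        // 0-based index of the epoch that just finished
    final double score;     // e.g. validation loss for that epoch
    final Duration elapsed; // wall-clock time since training started

    EpochState(int epoch, double score, Duration elapsed) {
        this.epoch = epoch;
        this.score = score;
        this.elapsed = elapsed;
    }
}

/** Hypothetical condition interface: true means "stop after this epoch". */
interface TerminationCondition {
    boolean shouldTerminate(EpochState state);
}

final class Conditions {
    private Conditions() {}

    /** 1. Iteration condition: stop once maxEpochs epochs have completed. */
    static TerminationCondition maxEpochs(int maxEpochs) {
        return s -> s.epoch + 1 >= maxEpochs;
    }

    /** 2. Score improvement: stop if the best score has not improved for `patience` epochs. */
    static TerminationCondition noImprovement(int patience) {
        return new TerminationCondition() {
            private double best = Double.POSITIVE_INFINITY;
            private int bestEpoch = -1;

            @Override
            public boolean shouldTerminate(EpochState s) {
                if (s.score < best) { // NaN never counts as an improvement
                    best = s.score;
                    bestEpoch = s.epoch;
                }
                return s.epoch - bestEpoch >= patience;
            }
        };
    }

    /** 3. Best expected score: stop once the score reaches the target. */
    static TerminationCondition targetScore(double target) {
        return s -> s.score <= target;
    }

    /** 4. Time limit: stop once the elapsed time reaches maxTime. */
    static TerminationCondition maxTime(Duration maxTime) {
        return s -> s.elapsed.compareTo(maxTime) >= 0;
    }

    /** Stop if any condition is met; all are evaluated so stateful ones observe every epoch. */
    static boolean anyMet(List<TerminationCondition> conditions, EpochState state) {
        boolean stop = false;
        for (TerminationCondition c : conditions) {
            stop |= c.shouldTerminate(state);
        }
        return stop;
    }
}
```

A training loop would build an EpochState after each epoch and break when Conditions.anyMet(...) returns true. The non-short-circuiting `|=` is deliberate: a stateful condition such as noImprovement must see every epoch's score even when an earlier condition already fired.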

Will this change the current API? How?

Users can configure training to stop as soon as any one of the conditions above is met. Early stopping should be implemented as a training listener: the early stopping configuration listens for the conditions above and terminates training when one of them is met.

Who will benefit from this feature?

Everybody: it becomes easy to configure when training ends.


Reference implementation: other NN frameworks already have implementations of this.

jSaso, Mar 19 '20 21:03

This is pretty important. I'd like to see three criteria:

  1. A minimum number of epochs (e.g. 2), no matter what.
  2. Stop if the validation loss doesn't improve for earlyStopPatience (e.g. 3) epochs in a row.
  3. If the user sends a SIGINT (or another) process signal, stop at the end of the current epoch. This is different from SIGKILL, which kills the process immediately.
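Criterion 3 can be approximated in plain Java with a shutdown hook: the JVM runs shutdown hooks on SIGINT (Ctrl-C) but not on SIGKILL, other threads keep running while the hooks execute, and the JVM only halts once every hook has returned. In this minimal sketch the hook sets a flag and then waits for the training thread, which checks the flag only between epochs. GracefulStop and SignalAwareLoop are hypothetical names, and the bounded wait (maxWaitMillis) is an assumption so that a stuck epoch cannot block shutdown forever.

```java
import java.util.concurrent.atomic.AtomicBoolean;

/** Hypothetical helper: turns SIGINT into "stop after the current epoch". */
final class GracefulStop {
    private final AtomicBoolean requested = new AtomicBoolean(false);

    /** On JVM shutdown (e.g. SIGINT), request a stop and wait up to maxWaitMillis for trainingThread. */
    void installShutdownHook(Thread trainingThread, long maxWaitMillis) {
        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
            requested.set(true);
            try {
                trainingThread.join(maxWaitMillis); // the JVM exits once this hook returns
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }));
    }

    /** Request a stop programmatically (also handy in tests). */
    void request() {
        requested.set(true);
    }

    boolean isRequested() {
        return requested.get();
    }
}

final class SignalAwareLoop {
    /** Runs up to maxEpochs epochs; the stop flag is only checked after an epoch completes. */
    static int fit(int maxEpochs, GracefulStop stop, Runnable trainOneEpoch) {
        int completed = 0;
        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            trainOneEpoch.run();
            completed++;
            if (stop.isRequested()) {
                System.out.printf("END: stop requested, finished epoch %d%n", epoch);
                break;
            }
        }
        return completed;
    }
}
```

In a real program, main() might call `stop.installShutdownHook(Thread.currentThread(), 30 * 60_000L)` before training starts. If an epoch takes longer than the wait, the JVM still exits mid-epoch, so the timeout should be sized to a typical epoch.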

gforman44, May 12 '22 17:05

Here's a proposal for what I'd like to see:

Parameters for flexible stopping criteria:

    static int maxEpochs = 1000;               // hard upper limit on the number of epochs
    static double objectiveSuccess = 0.5;      // done if validation loss objective (e.g. L2Loss) < threshold
    static int minEpochs = 2;                  // after the minimum # of epochs, consider stopping if:
    static int maxMinutes = 5 * 60;            //   too much time has elapsed, or
    static double earlyStopPctImprovement = 2; //   the loss did not improve by at least 2%,
    static int earlyStopPatience = 3;          //   for 3 epochs in a row
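As a worked example of how earlyStopPctImprovement and earlyStopPatience interact (using the goalImprovement formula from the fit() code below), an epoch only counts as an improvement if its validation loss is at or below:

```latex
\text{goalImprovement} = \text{prevLoss} \times \frac{100 - \text{earlyStopPctImprovement}}{100}
\qquad\text{e.g.}\quad 0.50 \times \frac{100 - 2}{100} = 0.49
```

So with prevLoss = 0.50, a new loss of 0.495 counts as a failure even though it is lower. With earlyStopPatience = 3, three such failures in a row end training, once minEpochs has passed.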

With these parameters, the training loop can then be implemented like this:

    import ai.djl.training.EasyTrain;
    import ai.djl.training.Trainer;
    import ai.djl.training.dataset.Batch;
    import ai.djl.training.dataset.RandomAccessDataset;
    import ai.djl.translate.TranslateException;
    import java.io.IOException;

    public static void fit(Trainer trainer, RandomAccessDataset trainingSet, RandomAccessDataset validateSet)
            throws TranslateException, IOException {
        final long start = System.currentTimeMillis();
        double prevLoss = Double.NaN;
        int improvementFailures = 0;
        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            for (Batch batch : trainer.iterateDataset(trainingSet)) {
                EasyTrain.trainBatch(trainer, batch); // forward + backward pass
                trainer.step();                       // apply the parameter update
                batch.close();
            }

            // After each epoch, test against the validation dataset if we have one
            if (validateSet != null) {
                EasyTrain.evaluateDataset(trainer, validateSet);
            }

            // reset training and validation evaluators at end of epoch
            trainer.notifyListeners(listener -> listener.onEpoch(trainer));

            // stopping criteria: use the validation loss, else the training loss if there is no validation set
            Float validateLoss = trainer.getTrainingResult().getValidateLoss();
            final double vloss = validateLoss != null ? validateLoss : trainer.getTrainingResult().getTrainLoss();
            if (vloss < objectiveSuccess) {
                System.out.printf("END: validation loss %s < objectiveSuccess %s%n", vloss, objectiveSuccess);
                break;
            }
            if (epoch + 1 >= minEpochs) {
                double elapsedMinutes = (System.currentTimeMillis() - start) / 60_000.0;
                if (elapsedMinutes >= maxMinutes) {
                    System.out.printf("END: %.1f minutes elapsed >= %s maxMinutes%n", elapsedMinutes, maxMinutes);
                    break;
                }
                // consider early stopping?
                if (Double.isFinite(prevLoss)) {
                    double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0;
                    boolean improved = vloss <= goalImprovement; // false if any NaNs
                    if (improved) {
                        improvementFailures = 0;
                    } else {
                        improvementFailures++;
                        if (improvementFailures >= earlyStopPatience) {
                            System.out.printf("END: failed to achieve %s%% improvement %s times in a row%n",
                                    earlyStopPctImprovement, earlyStopPatience);
                            break;
                        }
                    }
                }
            }
            if (Double.isFinite(vloss)) {
                prevLoss = vloss;
            }
        }
    }
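To check these stopping rules without a GPU or a real dataset, the checks in fit() could be pulled into a plain class. This is a hypothetical refactoring (StoppingCriteria and check() are not DJL APIs) that follows the order of the checks in fit(), increments improvementFailures on every epoch that misses the improvement goal, and adds the maxEpochs bound as one more stop reason. It returns a reason string instead of breaking, so a recorded loss sequence can be replayed in a unit test.

```java
/**
 * Hypothetical refactoring of the stopping checks in fit() into a plain class,
 * so they can be unit-tested without a Trainer. Field names follow the proposal.
 */
final class StoppingCriteria {
    private final int maxEpochs;
    private final double objectiveSuccess;
    private final int minEpochs;
    private final double maxMinutes;
    private final double earlyStopPctImprovement;
    private final int earlyStopPatience;

    private double prevLoss = Double.NaN;
    private int improvementFailures = 0;

    StoppingCriteria(int maxEpochs, double objectiveSuccess, int minEpochs,
                     double maxMinutes, double earlyStopPctImprovement, int earlyStopPatience) {
        this.maxEpochs = maxEpochs;
        this.objectiveSuccess = objectiveSuccess;
        this.minEpochs = minEpochs;
        this.maxMinutes = maxMinutes;
        this.earlyStopPctImprovement = earlyStopPctImprovement;
        this.earlyStopPatience = earlyStopPatience;
    }

    /** Call once per finished epoch (0-based). Returns why training should stop, or null to continue. */
    String check(int epoch, double vloss, double elapsedMinutes) {
        String reason = null;
        if (vloss < objectiveSuccess) {
            reason = "objectiveSuccess";
        } else if (epoch + 1 >= minEpochs) {
            if (elapsedMinutes >= maxMinutes) {
                reason = "maxMinutes";
            } else if (Double.isFinite(prevLoss)) {
                double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0;
                if (vloss <= goalImprovement) { // false if vloss is NaN
                    improvementFailures = 0;
                } else if (++improvementFailures >= earlyStopPatience) {
                    reason = "earlyStopPatience";
                }
            }
        }
        if (reason == null && epoch + 1 >= maxEpochs) {
            reason = "maxEpochs";
        }
        if (Double.isFinite(vloss)) {
            prevLoss = vloss;
        }
        return reason;
    }
}
```

Inside fit(), the stopping block would then reduce to calling `criteria.check(epoch, vloss, elapsedMinutes)` and breaking when it returns a non-null reason.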

gforman44, May 13 '22 06:05

@gforman44 That looks pretty good. One thought: we could implement early stopping as a TrainingListener. That would give a natural place for the early stopping configuration and help manage all the different pieces of functionality that users may or may not want as part of their training. The listener could throw an EarlyStopException if it decides to end training early.
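The listener-plus-exception idea can be sketched without DJL as follows. EpochListener, EarlyStopException, PatienceListener and ListenerTrainingLoop are hypothetical stand-ins (DJL's real TrainingListener.onEpoch receives a Trainer rather than a loss value), and the per-epoch losses are simulated. The listener throws when the validation loss has not beaten its best value for `patience` epochs; the loop catches that specific exception and ends training cleanly.

```java
import java.util.List;

/** Hypothetical stand-in for a per-epoch listener; not DJL's TrainingListener. */
interface EpochListener {
    void onEpoch(int epoch, double validationLoss);
}

/** Hypothetical exception a listener throws to end training early. */
class EarlyStopException extends RuntimeException {
    EarlyStopException(String reason) {
        super(reason);
    }
}

/** Stops when the validation loss has not improved on the best seen for `patience` epochs. */
class PatienceListener implements EpochListener {
    private final int patience;
    private double best = Double.POSITIVE_INFINITY;
    private int epochsWithoutImprovement = 0;

    PatienceListener(int patience) {
        this.patience = patience;
    }

    @Override
    public void onEpoch(int epoch, double validationLoss) {
        if (validationLoss < best) { // NaN is never an improvement
            best = validationLoss;
            epochsWithoutImprovement = 0;
        } else if (++epochsWithoutImprovement >= patience) {
            throw new EarlyStopException("no improvement for " + patience + " epochs (epoch " + epoch + ")");
        }
    }
}

final class ListenerTrainingLoop {
    /** Runs simulated epochs and notifies listeners; returns the number of completed epochs. */
    static int fit(double[] lossPerEpoch, int maxEpochs, List<EpochListener> listeners) {
        int completed = 0;
        try {
            for (int epoch = 0; epoch < maxEpochs && epoch < lossPerEpoch.length; epoch++) {
                // ... train on all batches, then evaluate on the validation set ...
                completed = epoch + 1;
                for (EpochListener listener : listeners) {
                    listener.onEpoch(epoch, lossPerEpoch[epoch]);
                }
            }
        } catch (EarlyStopException e) {
            System.out.println("Stopped early: " + e.getMessage());
        }
        return completed;
    }
}
```

Throwing keeps the training loop unaware of individual stopping rules, so new rules are just new listeners. The trade-off is that an exception is used for control flow, so the loop must catch that exception type specifically rather than swallowing all exceptions.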

Anyway, it sounds like you are really interested in this issue @gforman44. Do you want to implement it and submit a PR?

zachgk, May 13 '22 17:05