early-stopping-pytorch
EarlyStopping for KFold Cross-validation
I've been using this EarlyStopping class for my master's thesis project, but I had to modify it slightly so that it saves the model with the lowest val_loss for each fold during k-fold cross-validation. It works the same way as the original class; the difference is that it monitors the val_loss separately for each fold and saves k different models (one per fold) to a given path. At the start of each fold, the object is reset so that the best val_loss starts again from inf and decreases toward that fold's val_loss, and the checkpoint filename is created in the following format: "../checkpoint_fold_{fold_number}.pt"
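To make the per-fold behaviour concrete, here is a minimal sketch of the idea, not the exact code in this PR. The names `FoldEarlyStopping`, `start_fold`, and `save_fn` are hypothetical and chosen only for illustration; the real class may differ in its arguments and in how the reset is triggered. To keep the sketch free of dependencies, saving is delegated to a callback; in a PyTorch training loop that callback would typically wrap something like `torch.save(model.state_dict(), path)`.

```python
import math
from typing import Callable


class FoldEarlyStopping:
    """Hypothetical sketch: early stopping that is reset for every fold."""

    def __init__(
        self,
        patience: int,
        save_fn: Callable[[str], None],
        path_template: str = "../checkpoint_fold_{fold_number}.pt",
    ) -> None:
        self.patience = patience
        self.save_fn = save_fn  # assumed callback that writes the checkpoint
        self.path_template = path_template
        self.start_fold(0)

    def start_fold(self, fold_number: int) -> None:
        # Reset the monitor so the best val_loss starts again from inf.
        self.path = self.path_template.format(fold_number=fold_number)
        self.best_loss = math.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss: float) -> None:
        if val_loss < self.best_loss:
            # New best for this fold: overwrite this fold's checkpoint.
            self.best_loss = val_loss
            self.counter = 0
            self.save_fn(self.path)
        else:
            # No improvement: stop once patience is exhausted.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Usage with made-up validation losses for two folds.
saved = []  # records the checkpoint paths instead of writing files
stopper = FoldEarlyStopping(patience=2, save_fn=saved.append)
fold_losses = [[0.9, 0.7, 0.8, 0.75], [0.6, 0.65, 0.5]]

for fold_number, losses in enumerate(fold_losses):
    stopper.start_fold(fold_number)
    for val_loss in losses:
        stopper(val_loss)
        if stopper.early_stop:
            break
```

Because each fold writes to its own filename, a later fold never overwrites the best model of an earlier one, which is the main difference from monitoring a single run.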
I also added comments to the original class and formatted the script with black.
Finally, I updated the .gitignore file to ignore the .DS_Store files that macOS creates automatically.
I hope you find this update as helpful as your work has been to me. :D