
Checkpoint callbacks with n_jobs > 1

Open · andreapi87 opened this issue on Mar 25, 2022 · 6 comments

I would like to use GridSearchCV in parallel mode. However, I think the checkpoints written during training in the different processes could overwrite each other, since they share the same filenames defined by the Checkpoint attributes f_params etc. My first attempt was to subclass the Checkpoint class and implement a semaphore in the on_train_begin method that changed the filenames (via fn_prefix) using a global variable as a job counter. However, the jobs run as separate processes rather than threads, so that solution did not work. My current attempt stores the counter in a file protected by a file lock. Is there a better way?

Here is my solution:

import os

import skorch
from filelock import FileLock


class CheckPointAndRestore(skorch.callbacks.Checkpoint):
    objectCounter = 0
    TMP_PATH     = './checkpoints'
    COUNTER_FILE = f'{TMP_PATH}/counter.txt'
    LOCK_FILE    = f'{TMP_PATH}/counter.txt.lock'

    def __init__(self, *args, best_net=None,
                 epoch_counter=-1,
                 dirname='',
                 load_best=True,
                 sink=print,
                 **kwargs):
        super().__init__(*args,
                         dirname=CheckPointAndRestore.TMP_PATH,
                         load_best=load_best,
                         sink=sink,
                         **kwargs)
        self.epoch_counter = epoch_counter

        # The counter and lock files live inside TMP_PATH, so the directory
        # must exist before the lock can be created or acquired.
        os.makedirs(CheckPointAndRestore.TMP_PATH, exist_ok=True)
        if not os.path.exists(CheckPointAndRestore.LOCK_FILE):
            open(CheckPointAndRestore.LOCK_FILE, 'a').close()

    def initialize(self):
        self.epoch_counter = -1
        self.best_net = None

        # Create the shared counter file once, guarded by the file lock.
        # The lock object is created on demand so the callback stays
        # picklable when the grid search spawns worker processes.
        with FileLock(CheckPointAndRestore.LOCK_FILE):
            if not os.path.exists(CheckPointAndRestore.COUNTER_FILE):
                with open(CheckPointAndRestore.COUNTER_FILE, 'w') as f:
                    f.write('0')
        return super().initialize()

    def on_epoch_begin(self, net, dataset_train=None, dataset_valid=None, **kwargs):
        super().on_epoch_begin(net, dataset_train, dataset_valid, **kwargs)
        self.epoch_counter += 1

    def on_train_begin(self, net, **kwargs):
        ret = super().on_train_begin(net, **kwargs)
        self.initialize()

        # Read, increment, and write back the job counter under the file
        # lock, and use it as a unique filename prefix for this training run.
        with FileLock(CheckPointAndRestore.LOCK_FILE):
            with open(CheckPointAndRestore.COUNTER_FILE, 'r') as f:
                CheckPointAndRestore.objectCounter = int(f.read())
            CheckPointAndRestore.objectCounter += 1
            self.fn_prefix = f'{CheckPointAndRestore.objectCounter}'
            with open(CheckPointAndRestore.COUNTER_FILE, 'w') as f:
                f.write(str(CheckPointAndRestore.objectCounter))

        return ret
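
For completeness, a rough sketch of how the callback is meant to be attached (the toy module, data, and parameter grid below are just placeholders for illustration):

import numpy as np
import torch.nn as nn
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetClassifier

# Toy data and module, placeholders for illustration only.
X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, 100).astype(np.int64)
module = nn.Sequential(nn.Linear(20, 10), nn.ReLU(),
                       nn.Linear(10, 2), nn.Softmax(dim=-1))

net = NeuralNetClassifier(
    module,
    max_epochs=10,
    callbacks=[CheckPointAndRestore()],
)

# Several processes fit in parallel; without unique filename prefixes
# their checkpoints would overwrite each other.
gs = GridSearchCV(net, {'lr': [0.01, 0.1]}, cv=2, n_jobs=2)
gs.fit(X, y)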


andreapi87 commented on Mar 25, 2022

Without having looked at your solution: Are you sure that you need a checkpoint for grid search (or any hyper-parameter search) at all? In general, the process looks like this:

  1. Define the hyper parameter range to test
  2. Run a bunch of combinations on splits of data, record the results, discard the trained models
  3. Once the best hyper parameters are discovered, train on the whole dataset using these parameters

Only for this last step would you typically want to have checkpoints. You can either not use checkpointing at all during the grid search and perform this last step manually, or you can set refit=True in the grid search, which will run this step automatically and, in the process, overwrite the checkpointed models with the one trained on the best hyper-parameters.
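
As a rough sketch of the refit=True route (the toy module, data, and grid below are placeholders):

import os

import numpy as np
import torch.nn as nn
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetClassifier
from skorch.callbacks import Checkpoint

# Toy data and module, placeholders for illustration only.
X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, 100).astype(np.int64)
module = nn.Sequential(nn.Linear(20, 10), nn.ReLU(),
                       nn.Linear(10, 2), nn.Softmax(dim=-1))

os.makedirs('best_model', exist_ok=True)  # make sure the checkpoint dir exists
net = NeuralNetClassifier(
    module,
    max_epochs=10,
    callbacks=[Checkpoint(dirname='best_model')],
)

# refit=True (the default): the CV models are discarded and the net is
# retrained once on the full data with the best parameters, so the
# checkpoint left on disk belongs to that final model.
gs = GridSearchCV(net, {'lr': [0.01, 0.1]}, cv=3, refit=True)
gs.fit(X, y)
print(gs.best_params_)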

Maybe you have a different use case in mind that would actually require one checkpoint per hyper-parameter combination, but I wanted to check first whether it's even necessary.

BenjaminBossan commented on Mar 25, 2022

I understand your point. But consider the following case with grid search + k-fold cross-validation, say a 2-fold CV (just as an example) running in parallel:

  • model 1, training on fold 1, makes a checkpoint saving its best weights;
  • in parallel, model 2, training on fold 2, makes a new checkpoint saving its weights, overwriting the checkpoint made by model 1, since they have the same filenames;
  • model 1 stops training early because its validation loss has not improved within the patience window and tries to restore its best weights, but it actually restores the wrong ones (they have been overwritten by model 2), so its test accuracy on fold 2 is computed with the wrong weights;
  • model 2 stops training for the same reason, restores its own best weights, and its test accuracy is computed on fold 1;
  • the mean accuracy is then computed from the test accuracy of model 2 (real) and the test accuracy of model 1 (wrong).

Is that right?

andreapi87 commented on Mar 26, 2022

I have had similar issues. What works for me is, instead of using GridSearchCV directly, to set the parameters through sklearn's ParameterGrid, enumerate the parameter/fold combinations, and then use the Checkpoint fn_prefix variable with that enumeration index to make sure the checkpoint files are unique.

In a bit more detail:

  • Make as many `NeuralNetClassifier`s as you want folds and store them in a list `cvestimators` (for reference below). For other reasons where GridSearchCV doesn't work out of the box, I have to set an individual `train_split = predefined_split(special_split_data_for_each_fold_here)` on each of them.
  • Use joblib `Parallel` to iterate over `(gind, g) for gind, g in enumerate(list(itertools.product(cvestimators, ParameterGrid(param_grid))))`, where `param_grid` is the dict you would feed to `GridSearchCV`, and pass `(gind, g)` to the function you parallelize over.
  • Before you run `g[0].fit` in the function you are parallelizing over, do `g[1].update({"callbacks__Checkpoint__fn_prefix": str(gind)})` followed by `g[0].set_params(**g[1])`.
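
A condensed, self-contained sketch of this approach (the toy module, random data, and parameter grid are made-up placeholders, and the per-fold splits are simplified to plain KFold validation sets):

import itertools
import os

import numpy as np
import torch.nn as nn
from joblib import Parallel, delayed
from sklearn.model_selection import KFold, ParameterGrid
from skorch import NeuralNetClassifier
from skorch.callbacks import Checkpoint
from skorch.dataset import Dataset
from skorch.helper import predefined_split

# Toy data and module, placeholders for illustration only.
X = np.random.randn(200, 20).astype(np.float32)
y = np.random.randint(0, 2, 200).astype(np.int64)

class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(20, 10), nn.ReLU(),
                                 nn.Linear(10, 2), nn.Softmax(dim=-1))

    def forward(self, X):
        return self.seq(X)

param_grid = {'lr': [0.01, 0.1], 'max_epochs': [5, 10]}
os.makedirs('checkpoints', exist_ok=True)

# One estimator per fold, each with its own predefined validation split.
cvestimators = []
for train_idx, valid_idx in KFold(n_splits=2).split(X):
    valid_ds = Dataset(X[valid_idx], y[valid_idx])
    net = NeuralNetClassifier(
        ToyModule,
        train_split=predefined_split(valid_ds),
        callbacks=[('Checkpoint', Checkpoint(dirname='checkpoints'))],
    )
    cvestimators.append((net, train_idx, valid_idx))

def fit_one(gind, net, params, X_tr, y_tr, X_va, y_va):
    # The enumeration index gives every (fold, parameter) job its own
    # checkpoint filename prefix, so parallel jobs do not collide.
    net.set_params(**params, callbacks__Checkpoint__fn_prefix=f'{gind}_')
    net.fit(X_tr, y_tr)
    return gind, params, net.score(X_va, y_va)

results = Parallel(n_jobs=2)(
    delayed(fit_one)(gind, net, params, X[tr], y[tr], X[va], y[va])
    for gind, ((net, tr, va), params)
    in enumerate(itertools.product(cvestimators, ParameterGrid(param_grid)))
)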

tarjebargheer commented on Mar 28, 2022

Thanks for your suggestion. It seems that your solution does not use GridSearchCV, though. I would prefer to avoid setting up all the loops and models manually, so I would like to keep using scikit-learn's GridSearchCV + k-fold. Is the solution I adopted in my previous post (i.e., a callback that uses a file to store the filename prefix) valid, or is there a better one?

andreapi87 commented on Mar 28, 2022

Okay, the use case of having Checkpoint + grid search with n_jobs>1 is an occasion where this indeed becomes relevant. Depending on the system you're using, you might still see no benefit with n_jobs>1 when you're already parallelizing the neural net through PyTorch, but it depends.

I took a look at the code for CheckPointAndRestore and even though it's a bit hacky, it looks like it should work. However, I think it would be better to have a built-in solution for this issue. The first thing that comes to my mind would be this:

Right now, fn_prefix can only be a string. If we also allowed it to be a callable that returns a string, that could solve the problem. A user could then pass a function that uses a counter with a lock, similar to what was suggested above, or simply one that returns a string containing a random element (rng, timestamp).
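
To make the idea concrete, here is a sketch of callables a user might pass; note that this is a hypothetical API sketch, since fn_prefix only accepts plain strings today:

import os
import time

from filelock import FileLock

COUNTER_DIR = './checkpoints'
COUNTER_FILE = os.path.join(COUNTER_DIR, 'counter.txt')

def counting_prefix():
    # File-locked counter, similar to the approach above: every call
    # hands out the next integer as a unique filename prefix.
    os.makedirs(COUNTER_DIR, exist_ok=True)
    with FileLock(COUNTER_FILE + '.lock'):
        count = 0
        if os.path.exists(COUNTER_FILE):
            with open(COUNTER_FILE) as f:
                count = int(f.read() or 0)
        count += 1
        with open(COUNTER_FILE, 'w') as f:
            f.write(str(count))
    return f'{count}_'

def random_prefix():
    # No locking at all: rely on pid + timestamp being unique enough.
    return f'{os.getpid()}_{time.time_ns()}_'

# Hypothetical usage, if fn_prefix ever accepts callables:
# Checkpoint(fn_prefix=counting_prefix)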

@ottonemo WDYT? I vaguely remember that we discussed something like this at some point.

BenjaminBossan commented on Mar 29, 2022

I have run into the same issue with conflicting checkpoint files while using cross-validation with n_jobs>1. I created a subclass of Checkpoint that uses the unique object id as the filename prefix so there is no name conflict; it seems to make GridSearchCV and cross_validate operate as I expected.

import skorch

class UniquePrefixCheckpoint(skorch.callbacks.Checkpoint):
    def initialize(self):
        # Use the callback object's id as the filename prefix so that
        # concurrent fits write to different checkpoint files.
        self.fn_prefix = str(id(self))
        return super().initialize()
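
A minimal usage sketch (the toy module and data below are placeholders):

import os

import numpy as np
import torch.nn as nn
from sklearn.model_selection import cross_validate
from skorch import NeuralNetClassifier

# Toy data and module, placeholders for illustration only.
X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, 100).astype(np.int64)
module = nn.Sequential(nn.Linear(20, 10), nn.ReLU(),
                       nn.Linear(10, 2), nn.Softmax(dim=-1))

os.makedirs('checkpoints', exist_ok=True)
net = NeuralNetClassifier(
    module,
    max_epochs=10,
    callbacks=[UniquePrefixCheckpoint(dirname='checkpoints')],
)

# Each parallel fit initializes its own callback object, so each fold
# writes its checkpoint files under a different prefix.
scores = cross_validate(net, X, y, cv=3, n_jobs=3)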

kylepin commented on Jul 11, 2022