Adds functionality to save best model in CheckpointingCallback and fixes a typo

Open dmsm opened this issue 5 years ago • 1 comments

dmsm avatar Jun 04 '20 19:06 dmsm

Hello everyone. I have been using torch tools and love it for its simplicity and readability. For example, I once tried pytorch-lightning instead, but I grew nervous because it was harder to see what it was doing behind the scenes, so I quickly came back to ttools. Anyway, I needed a solution for the above problem. Mine is pretty similar to the above proposal, with a couple of differences:

  1. This class saves the best N models.
  2. The sense in which to interpret the val_data metric (maximize or minimize) is set at class instantiation.

The current limitation is that if two models have exactly the same score, only one of them is kept, because checkpoints are keyed by score. Hope this is useful.
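
The tie behaviour follows from keying checkpoints by score in a dict. Here is a minimal sketch of the effect in plain Python with no ttools dependency; the `record` helper and the epoch labels are hypothetical, for illustration only:

```python
# Bookkeeping keyed by score, mirroring the dict used by the callback.
ckpt_by_score = {}


def record(score, label):
    """Store a checkpoint name under its score (hypothetical helper)."""
    path = "best_{:.3f}".format(score).replace(".", "-")
    ckpt_by_score[score] = (path, label)
    return path


first = record(0.875, "epoch 3")
second = record(0.875, "epoch 7")  # same score: same key and same name

print(first == second)     # the two checkpoint names collide
print(len(ckpt_by_score))  # only one entry survives
```

Since both the dict key and the file name derive from the score alone, the later checkpoint silently replaces the earlier one.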

import os

# Assumes Callback lives in ttools.callbacks, alongside CheckpointingCallback.
from ttools.callbacks import Callback


class CheckpointingBestNCallback(Callback):
    """A callback which saves the best N models.

    Args:
        checkpointer (Checkpointer): actual checkpointer responsible for the I/O.
        key: key into accumulated validation data to define the metric.
        N (int, optional): number of models to save.
        sense (string, optional): one of "maximize"/"minimize",
            denoting whether we want to maximize/minimize the val metric.
    """

    BEST_PREFIX = "best_"

    def __init__(self, checkpointer, key, N=3, sense="maximize"):
        super(CheckpointingBestNCallback, self).__init__()
        if sense not in ("maximize", "minimize"):
            raise ValueError('sense should be "maximize" or "minimize"')
        self.checkpointer = checkpointer
        self.key = key
        self.N = N
        self.sense = sense
        # True when a higher score is better.
        self.default = sense == "maximize"
        # cmp(x, y) is True when score x is strictly better than score y.
        self.cmp = lambda x, y: x > y if self.default else y > x
        # Maps score -> checkpoint name; equal scores share one entry.
        self.ckptDict = dict()

    def validation_end(self, val_data):
        super(CheckpointingBestNCallback, self).validation_end(val_data)
        score = val_data[self.key]
        isBetter = any(self.cmp(score, y) for y in self.ckptDict)
        # Save while there is room, or when the score beats a kept checkpoint.
        if len(self.ckptDict) < self.N or isBetter:
            path = "{}{:.3f}".format(CheckpointingBestNCallback.BEST_PREFIX, score)
            path = path.replace('.', '-')
            self.checkpointer.save(path, extras=dict(score=score))
            self.ckptDict[score] = path
            self.__purge_old_files()

    def __purge_old_files(self):
        """Delete checkpoints that are beyond the max to keep."""
        chkpts = os.listdir(self.checkpointer.root)
        # Sort best-first and drop everything after the first N.
        toBeRemoved = sorted(self.ckptDict.keys(), reverse=self.default)[self.N:]
        for s in toBeRemoved:
            cpref = self.ckptDict.pop(s)
            matches = [fname for fname in chkpts if cpref in fname]
            if matches:  # nothing to delete if the file is already gone
                self.checkpointer.delete(matches.pop())
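
One way the tie limitation could be lifted is to give each checkpoint a unique counter as a tie-breaker. The sketch below is standalone Python with no ttools dependency; `BestNTracker`, its `update` method and the name format with a trailing counter are hypothetical, not part of ttools. It only decides which names to save and delete and leaves the actual I/O to the caller:

```python
import heapq
import itertools


class BestNTracker:
    """Track the best N scores; equal scores are kept as separate entries.

    Hypothetical helper for illustration, not part of ttools.
    """

    def __init__(self, n=3, sense="maximize"):
        if sense not in ("maximize", "minimize"):
            raise ValueError('sense should be "maximize" or "minimize"')
        self.n = n
        # Flip the sign when minimizing so a larger key always means "better".
        self.sign = 1.0 if sense == "maximize" else -1.0
        self.counter = itertools.count()  # unique tie-breaker per entry
        self.heap = []  # min-heap: the worst kept entry sits at index 0

    def update(self, score):
        """Return (name_to_save, name_to_delete); either may be None."""
        idx = next(self.counter)
        name = "best_{:.3f}_{}".format(score, idx).replace(".", "-")
        entry = (self.sign * score, idx, name)
        if len(self.heap) < self.n:
            heapq.heappush(self.heap, entry)
            return name, None
        if entry[0] > self.heap[0][0]:
            # The new entry strictly beats the worst kept one: swap them.
            evicted = heapq.heapreplace(self.heap, entry)
            return name, evicted[2]
        return None, None


tracker = BestNTracker(n=2, sense="maximize")
results = [tracker.update(s) for s in [0.5, 0.5, 0.7, 0.4]]
for save_name, delete_name in results:
    print(save_name, delete_name)
```

In this run the two checkpoints scoring 0.5 get distinct names and are both kept until the 0.7 checkpoint evicts the older of the two. The comparison stays strict, as in the callback, so a new score that merely equals the worst kept one is not saved.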

Vrroom avatar Jun 28 '21 06:06 Vrroom