ttools
Adds functionality to save best model in CheckpointingCallback and fixes a typo
Hello everyone. I have been using torch tools and love it for its simplicity and readability. I once tried pytorch-lightning instead, but it made me uneasy because it was harder to see what it was doing behind the scenes, so I quickly came back to ttools. Anyway, I needed a solution for the problem above, and what I came up with is quite similar to the proposal above, with a couple of differences:
- This class keeps the best N checkpoints (N is configurable and defaults to 3).
- Whether the validation metric in val_data should be maximized or minimized is specified when the class is instantiated, via the sense argument.
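Ignoring file I/O, the retention rule boils down to "keep the N best scores seen so far", with sense choosing the direction. A minimal stand-alone sketch of just that rule (best_n is a hypothetical helper for illustration, not part of ttools):

```python
import heapq


def best_n(scores, n, sense="maximize"):
    """Return the n best scores, best first (hypothetical illustration)."""
    if sense not in ("maximize", "minimize"):
        raise ValueError("sense must be 'maximize' or 'minimize'")
    # heapq.nlargest / nsmallest both return their results sorted best-first.
    pick = heapq.nlargest if sense == "maximize" else heapq.nsmallest
    return pick(n, scores)


# Validation scores arriving over five epochs.
history = [0.71, 0.80, 0.75, 0.62, 0.83]
print(best_n(history, 3))              # [0.83, 0.8, 0.75]
print(best_n(history, 3, "minimize"))  # [0.62, 0.71, 0.75]
```

The callback reaches the same set incrementally: after each validation pass it saves the new model if there is room or if it beats a kept one, then deletes whatever falls outside the top N.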
One current limitation: if two models have the same score, only one of them is kept, because the checkpoint filename is derived from the score and the later save overwrites the earlier file. Hope this is useful!
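To make the limitation concrete: the name is built from the score formatted to three decimals with "." replaced by "-", so any two scores that agree to three decimals, not only exactly equal ones, map to the same file. A small sketch of that naming rule (checkpoint_name is a hypothetical helper that copies the format string used in the class below):

```python
def checkpoint_name(score, prefix="best_"):
    # Same naming as the callback: three decimals, '.' replaced by '-'.
    return "{}{:.3f}".format(prefix, score).replace(".", "-")


print(checkpoint_name(0.9121))  # best_0-912
print(checkpoint_name(0.9124))  # best_0-912 (same file as above)
print(checkpoint_name(-0.5))    # best_-0-500
```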
```python
import os

from ttools.callbacks import Callback  # ttools' callback base class


class CheckpointingBestNCallback(Callback):
    """A callback which saves the best N models.

    Args:
        checkpointer (Checkpointer): actual checkpointer responsible for the I/O.
        key (str): key into the accumulated validation data that defines the metric.
        N (int, optional): number of models to keep.
        sense (str, optional): one of "maximize"/"minimize", denoting whether
            we want to maximize or minimize the validation metric.
    """

    BEST_PREFIX = "best_"

    def __init__(self, checkpointer, key, N=3, sense="maximize"):
        super(CheckpointingBestNCallback, self).__init__()
        if sense not in ("maximize", "minimize"):
            raise ValueError('sense must be "maximize" or "minimize", got {!r}'.format(sense))
        self.checkpointer = checkpointer
        self.key = key
        self.N = N
        self.sense = sense
        self.default = sense == "maximize"
        # cmp(x, y) is True when score x is strictly better than score y.
        self.cmp = lambda x, y: x > y if self.default else y > x
        # Maps score -> checkpoint name (relative to checkpointer.root).
        self.ckptDict = dict()

    def validation_end(self, val_data):
        super(CheckpointingBestNCallback, self).validation_end(val_data)
        score = val_data[self.key]
        isBetter = any(self.cmp(score, y) for y in self.ckptDict)
        if len(self.ckptDict) < self.N or isBetter:
            path = "{}{:.3f}".format(CheckpointingBestNCallback.BEST_PREFIX, score)
            path = path.replace('.', '-')
            # Scores that agree to 3 decimals share a filename and the save below
            # overwrites it, so drop the stale entry to track each file only once.
            stale = [s for s, p in self.ckptDict.items() if p == path]
            for s in stale:
                self.ckptDict.pop(s)
            self.checkpointer.save(path, extras=dict(score=score))
            self.ckptDict[score] = path
            self.__purge_old_files()

    def __purge_old_files(self):
        """Delete checkpoints that are beyond the max to keep."""
        chkpts = os.listdir(self.checkpointer.root)
        # Sort best-first; everything after the first N is discarded.
        toBeRemoved = sorted(self.ckptDict.keys(), reverse=self.default)[self.N:]
        for s in toBeRemoved:
            cpref = self.ckptDict.pop(s)
            # The saved file name may carry a prefix/extension, so match by substring.
            matches = [fname for fname in chkpts if cpref in fname]
            if matches:
                self.checkpointer.delete(matches.pop())
```