
Is there any documentation on how to use the CheckpointWriter API?

Open · SibtainRazaJamali opened this issue 4 years ago · 1 comment

Is there any resource or notebook on using the CheckpointWriter API inside a training loop for any model in Swift?

SibtainRazaJamali · May 16 '20 16:05

Sorry about the lack of documentation; that's something I was hoping to get to as part of overall improvements to the use of pretrained checkpoints. For now, the best example is in the GPT2 model, which has the writeCheckpoint(to:name:) method as a prototype for what we'd like to extend to other models.

The CheckpointWriter itself has a reasonably simple interface: it takes in a dictionary that maps String names to the Float Tensors they refer to, and can then write out a checkpoint from that.
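To make the shape of that interface concrete, here is a minimal sketch of saving a name-to-tensor dictionary every few steps of a training loop. Everything in it is a stand-in: `FakeTensor` replaces a Float Tensor, and `StandInWriter` is a hypothetical type that writes JSON, not the swift-models CheckpointWriter or its file format. The exact initializer and write method of the real CheckpointWriter should be checked against the swift-models source.

```swift
import Foundation

// Stand-in for a Float tensor: a shape plus flat scalars (hypothetical).
struct FakeTensor: Codable, Equatable {
    var shape: [Int]
    var scalars: [Float]
}

// Hypothetical writer with the same overall shape as described above:
// it takes a [name: tensor] dictionary and writes it out.
// This is NOT the swift-models CheckpointWriter.
struct StandInWriter {
    let tensors: [String: FakeTensor]

    func write(to directory: URL, name: String) throws -> URL {
        try FileManager.default.createDirectory(
            at: directory, withIntermediateDirectories: true)
        let file = directory.appendingPathComponent("\(name).json")
        let data = try JSONEncoder().encode(tensors)
        try data.write(to: file)
        return file
    }
}

// Hypothetical parameters; in a real loop these would come from the model.
var weights = FakeTensor(shape: [2], scalars: [0.0, 0.0])
let directory = FileManager.default.temporaryDirectory
    .appendingPathComponent("checkpoint-sketch")
var written: [URL] = []

for step in 1...4 {
    // Placeholder for an optimizer update.
    weights.scalars = weights.scalars.map { $0 + 0.5 }

    // Save a checkpoint on every second step.
    if step % 2 == 0 {
        let writer = StandInWriter(tensors: ["dense/weights": weights])
        written.append(try writer.write(to: directory, name: "step-\(step)"))
    }
}
```

The point of the sketch is only the flow: gather the current parameters into a dictionary keyed by checkpoint name, hand that dictionary to a writer, and pick a directory and checkpoint name per save.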

The process of getting names and tensors from within a model can vary, which is what we're trying to make more consistent and easier to use. The method utilized for the GPT2 model is contained here. The ExportableLayer protocol maps the names of properties within the model to their names within the checkpoint, and the recursivelyObtainTensors() function uses Mirror to iterate over the structure of the model and sublayers to apply this name mapping to the Tensors within. This generates the dictionary that is then passed to the CheckpointWriter.
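As a rough illustration of that idea, the sketch below walks a small nested model with Mirror and builds the name-to-tensor dictionary. It is not the GPT2 code: `NameMapped`, `collectTensors`, `Dense`, `Model`, and `FakeTensor` are hypothetical names chosen for this example, and the real ExportableLayer protocol and recursivelyObtainTensors() function may differ in their signatures and in how they join names.

```swift
// Stand-in for a Float tensor (hypothetical).
struct FakeTensor: Equatable {
    var scalars: [Float]
}

// Hypothetical counterpart to the name-mapping protocol described above:
// maps Swift property names to the names used in the checkpoint.
protocol NameMapped {
    var nameMappings: [String: String] { get }
}

struct Dense: NameMapped {
    var weight = FakeTensor(scalars: [1, 2])
    var bias = FakeTensor(scalars: [0])
    var nameMappings: [String: String] { ["weight": "w", "bias": "b"] }
}

struct Model: NameMapped {
    var hidden = Dense()
    var output = Dense()
    var nameMappings: [String: String] { ["hidden": "layer0", "output": "layer1"] }
}

// Recursively visit stored properties, translating each property name through
// the mapping and joining nested names with "/" (the separator is an assumption).
func collectTensors(
    from value: Any, scope: String?, into result: inout [String: FakeTensor]
) {
    guard let mapped = value as? NameMapped else { return }
    for child in Mirror(reflecting: value).children {
        // Skip properties that have no entry in the mapping.
        guard let label = child.label,
              let name = mapped.nameMappings[label] else { continue }
        let path = scope.map { "\($0)/\(name)" } ?? name
        if let tensor = child.value as? FakeTensor {
            result[path] = tensor
        } else {
            collectTensors(from: child.value, scope: path, into: &result)
        }
    }
}

var tensors: [String: FakeTensor] = [:]
collectTensors(from: Model(), scope: nil, into: &tensors)
```

The resulting dictionary has one entry per tensor, keyed by its mapped path, which is the form a checkpoint writer would consume.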

A similar system could be configured for other models, so we're looking at building a generalized implementation of something like this to make serialization easy. Sorry it's undocumented and a little barebones right now.

BradLarson · May 18 '20 13:05