Non-strict checkpoint loading
Description & Motivation
Currently, I'm fine-tuning a model that is built on top of some very large pretrained models (like LLAMA and CLIP), using lightweight adapters for the fine-tuning. The current checkpointing behaviour is to save the entire state_dict in the checkpoint file. This produces a very large file (and keeping multiple checkpoint versions consumes a lot of disk space). What I want is to save only the custom part of the weights. This is possible with the on_save_checkpoint hook. My current on_save_checkpoint looks as follows:
from collections import OrderedDict

def on_save_checkpoint(self, checkpoint):
    # Keep only the adapter parameters; drop the frozen backbone weights
    # so the saved checkpoint stays small.
    trainable = OrderedDict()
    for n, p in checkpoint['state_dict'].items():
        if 'adapter' in n:
            trainable[n] = p.data
    checkpoint['state_dict'] = trainable
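For illustration, here is a minimal sketch (the checkpoint path and the model object are placeholders, not from my actual setup) of why the trimmed checkpoint can only be loaded back non-strictly with plain PyTorch:

import torch

# Hypothetical example: load the adapter-only checkpoint produced by the hook above.
checkpoint = torch.load("adapter_only.ckpt", map_location="cpu")

# Plain PyTorch allows partial loading when strict=False; the returned object
# lists the backbone keys that are absent from the trimmed state_dict.
result = model.load_state_dict(checkpoint["state_dict"], strict=False)
print(result.missing_keys)  # e.g. all of the frozen LLAMA/CLIP backbone parameters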
The problem with this is that, when resuming training, the parameters I dropped while saving are required for the restore to succeed. What I want is to load the checkpoint with strict set to False. The current Lightning Trainer does not allow this. For now, I'm manually adding strict=False in the following line.
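A workaround that avoids editing the Lightning source (not what I do in the issue above, just a sketch under the assumption that the Trainer restores weights through the module's own load_state_dict) is to override load_state_dict on the LightningModule:

import pytorch_lightning as pl
from collections import OrderedDict


class AdapterModule(pl.LightningModule):
    """Hypothetical module: large frozen backbone plus small trainable adapters."""

    def on_save_checkpoint(self, checkpoint):
        # Same trimming as above: keep only the adapter weights.
        checkpoint["state_dict"] = OrderedDict(
            (n, p) for n, p in checkpoint["state_dict"].items() if "adapter" in n
        )

    def load_state_dict(self, state_dict, strict=True):
        # Force non-strict loading so resuming from the trimmed checkpoint
        # does not fail on the missing backbone keys.
        return super().load_state_dict(state_dict, strict=False)

This lets trainer.fit(model, ckpt_path=...) restore the trimmed checkpoint, but it silently disables strictness for every load, which is why a proper option on the Trainer would be cleaner.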
Pitch
I want to have a strict parameter in the Trainer as well, which would allow loading a checkpoint while skipping some parameters.
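A rough sketch of how this could look from the user side; the strict_loading name is only illustrative and is not an existing Trainer argument:

import pytorch_lightning as pl

# Hypothetical API sketch: a flag that makes checkpoint restoring non-strict.
trainer = pl.Trainer(max_epochs=10, strict_loading=False)

# Resuming would then skip the backbone keys dropped in on_save_checkpoint
# instead of raising on missing parameters.
trainer.fit(model, ckpt_path="adapter_only.ckpt")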
Alternatives
I'm manually adding strict=False in the following line.
Additional context
No response
cc @borda