
Non-strict checkpoint loading

anshkumar opened this issue 5 months ago • 1 comment

Description & Motivation

Currently, I'm fine-tuning a model that is built on top of some very large pretrained models (such as LLAMA and CLIP), using lightweight adapters for the fine-tuning. The current checkpointing behaviour is to save the entire state_dict in the checkpoint file. This produces very large files (and keeping multiple checkpoint versions consumes a lot of disk space). What I want is to save only the custom (adapter) part of the weights. This is possible with the on_save_checkpoint hook; mine currently looks as follows:

from collections import OrderedDict

def on_save_checkpoint(self, checkpoint):
    # Keep only the lightweight adapter weights and drop the frozen base-model
    # weights so the checkpoint file stays small.
    trainable = OrderedDict()
    for n, p in checkpoint['state_dict'].items():
        if 'adapter' in n:
            trainable[n] = p.data
    checkpoint['state_dict'] = trainable
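
For reference, a quick sanity check that the resulting checkpoint really contains only the adapter weights (the file name below is just a placeholder):

import torch

ckpt = torch.load("epoch=0-step=100.ckpt", map_location="cpu")
# After the hook above, only keys containing 'adapter' should be listed.
print(list(ckpt["state_dict"].keys()))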

The problem with this is that when resuming training, the parameters I dropped while saving are still expected, so restoring the checkpoint fails. What I want is to load the checkpoint with strict set to False. The current Lightning Trainer does not allow this.
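
One model-side workaround I could imagine (a minimal sketch, assuming the base weights are already initialized when the LightningModule is constructed and that on_load_checkpoint runs before the state dict is restored) is to rebuild a full state_dict inside on_load_checkpoint by merging the module's current weights with the adapter-only weights from the checkpoint, so the strict load that follows does not fail:

def on_load_checkpoint(self, checkpoint):
    # The checkpoint contains only adapter weights; fill in the missing keys from
    # the module's current state so the subsequent strict load succeeds.
    full_state = self.state_dict()
    full_state.update(checkpoint['state_dict'])
    checkpoint['state_dict'] = full_state

This only papers over the problem, though; a proper strict flag would be cleaner.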

Currently, I'm manually adding strict=False in the following line.

Pitch

I would like the Trainer to expose a strict parameter as well, so that a checkpoint can be loaded while skipping some parameters.
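
Purely to illustrate the pitch (the argument name and placement are hypothetical; nothing like it exists in the current API), usage could look roughly like this:

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=10)
# Hypothetical 'strict' argument: skip parameters missing from the checkpoint.
trainer.fit(model, ckpt_path="adapters-only.ckpt", strict=False)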

Alternatives

Manually patching strict=False into the line mentioned above.
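
A sketch of another alternative that avoids patching the Lightning source, assuming the Trainer ultimately restores weights through the module's own load_state_dict, is to override that method on the LightningModule and force non-strict loading:

def load_state_dict(self, state_dict, strict=True):
    # Always load non-strictly so the base-model keys that were dropped from the
    # checkpoint are simply skipped instead of raising an error.
    return super().load_state_dict(state_dict, strict=False)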

Additional context

No response

cc @borda

anshkumar • Jan 14 '24