Inconsistency in loading from checkpoint in LightningCLI
Bug description
When using a checkpoint with LightningCLI, the model is first instantiated from the CLI arguments, and only afterwards is the checkpoint loaded, by passing it to the ckpt_path argument of the trainer method.
The problem is that the hyperparameters stored in the checkpoint are not used when instantiating the model, and thus not when allocating tensors, which can cause checkpoint loading to fail if tensor sizes do not match. Furthermore, if the model has more involved instantiation logic, this can lead to other silent bugs or failures.
This was first raised as a discussion in #20715
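Conceptually, the failure is independent of the CLI itself: it boils down to loading a state dict into a model that was constructed with different arguments than the ones stored in the checkpoint. A minimal standalone sketch (file name illustrative; shapes match the error output below):

import torch
from lightning.pytorch.demos.boring_classes import DemoModel

# Checkpoint produced by a model trained with out_dim=2 ...
trained = DemoModel(out_dim=2)
torch.save({"state_dict": trained.state_dict()}, "demo.ckpt")

# ... but the model is re-instantiated with the default out_dim=10, which is
# what happens when the checkpoint's hyperparameters are ignored.
fresh = DemoModel()  # out_dim defaults to 10
state = torch.load("demo.ckpt", weights_only=True)["state_dict"]
fresh.load_state_dict(state)  # RuntimeError: size mismatch for l1.weight / l1.bias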
What version are you seeing the problem on?
v2.5
How to reproduce the bug
Here is a minimal example using the predict subcommand. In fit we override out_dim so that the last layer has a different size than the default, which causes loading in predict to fail.
# cli.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule


class DemoModelWithHyperparameters(DemoModel):
    def __init__(self, *args, **kwargs):
        self.save_hyperparameters()
        super().__init__(*args, **kwargs)


def cli_main():
    cli = LightningCLI(DemoModelWithHyperparameters, BoringDataModule)


if __name__ == "__main__":
    cli_main()
and then run
$ python src/lightning_cli_load_checkpoint/cli.py fit --trainer.max_epochs 1 --model.out_dim 2
$ python src/lightning_cli_load_checkpoint/cli.py predict --ckpt_path <path_to_checkpoint>
Error messages and logs
Restoring states from the checkpoint path at lightning_logs/version_23/checkpoints/epoch=0-step=64.ckpt
Traceback (most recent call last):
File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 20, in <module>
cli_main()
File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 16, in cli_main
cli = MyLightningCLI(DemoModelWithHyperparameters, datamodule_class=BoringDataModule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 398, in __init__
self._run_subcommand(self.subcommand)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 708, in _run_subcommand
fn(**fn_kwargs)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 887, in predict
return call._call_and_handle_interrupt(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 928, in _predict_impl
results = self._run(model, ckpt_path=ckpt_path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 409, in _restore_modules_and_callbacks
self.restore_model()
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 286, in restore_model
self.trainer.strategy.load_model_state_dict(
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 372, in load_model_state_dict
self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for DemoModelWithHyperparameters:
size mismatch for l1.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([10, 32]).
size mismatch for l1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([10]).
Expected behavior
I'd expect the loading to respect the checkpoint's arguments. In other words, while the current implementation roughly follows this logic:
model = Model(**cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)
I'd expect it to be closer to
model = Model.load_from_checkpoint(ckpt_path, **cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)
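For context, when save_hyperparameters() is used, the constructor arguments are already stored in the checkpoint under the hyper_parameters key, which is what load_from_checkpoint reads, so all the information needed to rebuild the model is available. A quick way to inspect it (checkpoint path illustrative):

import torch

ckpt = torch.load(
    "lightning_logs/version_23/checkpoints/epoch=0-step=64.ckpt",
    weights_only=True,
    map_location="cpu",
)
print(ckpt["hyper_parameters"])  # e.g. {'out_dim': 2, ...}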
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning: 2.5.1
- lightning-cli-load-checkpoint: 0.1.0
- lightning-utilities: 0.14.3
- pytorch-lightning: 2.5.1
- torch: 2.6.0
- torchmetrics: 1.7.1
- Packages:
- aiohappyeyeballs: 2.6.1
- aiohttp: 3.11.16
- aiosignal: 1.3.2
- antlr4-python3-runtime: 4.9.3
- attrs: 25.3.0
- autocommand: 2.2.2
- backports.tarfile: 1.2.0
- contourpy: 1.3.1
- cycler: 0.12.1
- docstring-parser: 0.16
- filelock: 3.18.0
- fonttools: 4.57.0
- frozenlist: 1.5.0
- fsspec: 2025.3.2
- hydra-core: 1.3.2
- idna: 3.10
- importlib-metadata: 8.0.0
- importlib-resources: 6.5.2
- inflect: 7.3.1
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jinja2: 3.1.6
- jsonargparse: 4.38.0
- kiwisolver: 1.4.8
- lightning: 2.5.1
- lightning-cli-load-checkpoint: 0.1.0
- lightning-utilities: 0.14.3
- markdown-it-py: 3.0.0
- markupsafe: 3.0.2
- matplotlib: 3.10.1
- mdurl: 0.1.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.4.3
- networkx: 3.4.2
- numpy: 2.2.4
- omegaconf: 2.3.0
- packaging: 24.2
- pillow: 11.2.1
- platformdirs: 4.2.2
- propcache: 0.3.1
- protobuf: 6.30.2
- pygments: 2.19.1
- pyparsing: 3.2.3
- python-dateutil: 2.9.0.post0
- pytorch-lightning: 2.5.1
- pyyaml: 6.0.2
- rich: 13.9.4
- setuptools: 78.1.0
- six: 1.17.0
- sympy: 1.13.1
- tensorboardx: 2.6.2.2
- tomli: 2.0.1
- torch: 2.6.0
- torchmetrics: 1.7.1
- tqdm: 4.67.1
- typeguard: 4.3.0
- typeshed-client: 2.7.0
- typing-extensions: 4.13.2
- wheel: 0.45.1
- yarl: 1.19.0
- zipp: 3.19.2
- System:
- OS: Darwin
- architecture:
- 64bit
- processor: arm
- python: 3.12.7
- release: 24.4.0
- version: Darwin Kernel Version 24.4.0: Fri Apr 11 18:33:47 PDT 2025; root:xnu-11417.101.15~117/RELEASE_ARM64_T6000
After thinking about it a bit, I think that implementing this is not trivial. The main difficulty comes from argument links. Regardless of how the code is changed, it should keep working when argument links are added to the parser. I am not sure I am aware of all the details that need to be considered, and some may only be noticed while implementing. For now I would propose to do the following:
- Behavior only changes when LightningCLI is working in subcommands mode, i.e. run=True.
- Before instantiation and running, check whether ckpt_path is passed to the subcommand. Note that in principle ckpt_path could be defined as a command line argument, in a config file or as an environment variable. So a simple solution could be to parse, check if ckpt_path is set, and if set, a second parse would be needed.
- If ckpt_path is set, use torch.load to read the checkpoint, and check if hyperparameters are included (i.e. save_hyperparameters was used).
- If hyperparameters are included, remove the keys that correspond to link targets (both applied on parse and on instantiate). Unfortunately, right now there is no official way (jsonargparse public API) to know which keys are link targets.
- After removing link targets, parse again, but modifying the args such that right after the subcommand (e.g. predict) and before all other arguments, there is a new --config option with value the modified hyperparameters from the checkpoint (a rough sketch of this args modification is shown below).
- Continue the normal flow, which would instantiate the classes and then run the trainer method.
I need to figure out what to do about the fourth point (link targets). Most likely a new feature in jsonargparse is needed.
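As a rough illustration of the "parse again with a --config right after the subcommand" step (the helper name is hypothetical, link-target filtering is omitted, and passing inline YAML to --config assumes jsonargparse accepts config content as a string):

import yaml


def splice_checkpoint_hparams(argv: list[str], hparams: dict) -> list[str]:
    # argv is e.g. ["predict", "--ckpt_path", "last.ckpt", "--model.out_dim", "2"]
    subcommand, rest = argv[0], argv[1:]
    # Put the checkpoint hyperparameters right after the subcommand so that any
    # explicitly passed arguments that follow still take precedence over them.
    return [subcommand, "--config", yaml.safe_dump({"model": hparams})] + rest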
Hi, is there any progress on this? I think this is a very big issue. How can the CLI even be useful if we can't run a simple fit and test?
python main.py fit --model.out_dim 2
python main.py test --ckpt_path path/to/ckpt
Hey @Borda, can you confirm this bug and assign me to work on it? Thanks 😊
I was looking into "remove the keys that correspond to link targets" from my comment above, and it turns out it isn't an issue after all. This means that the implementation is easy and there is no need for changes in jsonargparse. I have created #21116 with the fix.
Since review+merge+release can take a while, below is a temporary patch for people who don't want to wait.
from pathlib import Path

import torch
from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def parse_arguments(self, *args, **kwargs):
        if hasattr(LightningCLI, "_parse_ckpt_path"):
            raise RuntimeError("_parse_ckpt_path already available in lightning, this patch is no longer needed")
        super().parse_arguments(*args, **kwargs)
        self._parse_ckpt_path()

    def _parse_ckpt_path(self):
        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
        if not self.config.get("subcommand"):
            return
        ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
        if ckpt_path and Path(ckpt_path).is_file():
            ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu")
            hparams = ckpt.get("hyper_parameters", {})
            hparams.pop("_instantiator", None)
            if hparams:
                hparams = {self.config.subcommand: {"model": hparams}}
                self.config = self.parser.parse_object(hparams, self.config)
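Using it is just a matter of swapping the class in the reproduction script above (cli_main and the model/datamodule names are the ones from that script, not part of the patch):

def cli_main():
    cli = MyCLI(DemoModelWithHyperparameters, BoringDataModule)

With that in place, the predict command from the reproduction should pick up the out_dim stored in the checkpoint instead of the default.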