pytorch-lightning
Parsing issue with the `save_last` parameter of `ModelCheckpoint`
Bug description
Cannot pass a boolean to the `save_last` parameter of the `ModelCheckpoint` callback using `LightningCLI`.
All other parameters work fine; only `save_last` fails. I think `jsonargparse` has trouble validating the annotation of `save_last`, which is currently `Optional[Literal[True, False, 'link']]`.
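A quick stdlib-only way to see what the parser has to match against is to inspect the annotation itself. This is a simplified illustration: it assumes the failure comes down to comparing the raw command-line string with the literal values, which models jsonargparse's behaviour here rather than quoting its code.

```python
from typing import Literal, Optional, get_args

# Annotation of ModelCheckpoint's save_last parameter, as quoted in this issue.
SAVE_LAST_ANNOT = Optional[Literal[True, False, "link"]]

# Optional[X] is Union[X, None], so there are two subtypes to validate against,
# matching the "Subtypes:" line of the reported error.
union_members = get_args(SAVE_LAST_ANNOT)
literal_values = get_args(union_members[0])

print(union_members)   # (typing.Literal[True, False, 'link'], <class 'NoneType'>)
print(literal_values)  # (True, False, 'link')

# A command-line value arrives as the string "true", which equals none of the
# literal values: a str never compares equal to the bool True.
print("true" in literal_values)  # False
print("link" in literal_values)  # True
```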
Ideally, this should work like any other boolean flag, e.g. `--my_model_checkpoint.verbose=false`.
I have already forked the project and implemented a proposed solution, together with tests in the relevant directory. I am ready to submit a PR if you think it would be useful.
What version are you seeing the problem on?
master
How to reproduce the bug
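```python
import inspect

import jsonargparse
from lightning.pytorch.callbacks import ModelCheckpoint

# A boolean as it would arrive from the command line (a string).
val = 'true'

# Reuse the exact annotation of ModelCheckpoint's save_last parameter.
annot = inspect.signature(ModelCheckpoint).parameters["save_last"].annotation

parser = jsonargparse.ArgumentParser()
parser.add_argument("--a", type=annot)

# Fails with the parser error shown below.
args = parser.parse_args(["--a", val])
```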
Error messages and logs
```
error: Parser key "a":
  Does not validate against any of the Union subtypes
  Subtypes: (typing.Literal[True, False, 'link'], <class 'NoneType'>)
  Errors:
    - Expected a typing.Literal[True, False, 'link']
    - Expected a <class 'NoneType'>. Got value: True
  Given value type: <class 'str'>
  Given value: true
```
Environment
Current environment
```
* CUDA:
    - GPU: None
    - available: False
    - version: None
* Lightning:
    - lightning-utilities: 0.11.2
    - torch: 2.2.2+cpu
    - torchmetrics: 1.2.1
    - torchvision: 0.17.2+cpu
* Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - antlr4-python3-runtime: 4.9.3
    - attrs: 23.2.0
    - certifi: 2024.2.2
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.2
    - cloudpickle: 2.2.1
    - distlib: 0.3.8
    - filelock: 3.13.4
    - frozenlist: 1.4.1
    - fsspec: 2023.4.0
    - identify: 2.5.36
    - idna: 3.7
    - iniconfig: 2.0.0
    - jinja2: 3.1.2
    - jsonargparse: 4.28.0
    - lightning-utilities: 0.11.2
    - markupsafe: 2.1.3
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - networkx: 3.2.1
    - nodeenv: 1.8.0
    - numpy: 1.26.3
    - omegaconf: 2.3.0
    - packaging: 23.1
    - pillow: 10.2.0
    - pip: 24.0
    - platformdirs: 4.2.0
    - pluggy: 1.5.0
    - pre-commit: 3.7.0
    - pytest: 7.4.0
    - pyyaml: 6.0.1
    - requests: 2.31.0
    - setuptools: 68.2.2
    - sympy: 1.12
    - torch: 2.2.2+cpu
    - torchmetrics: 1.2.1
    - torchvision: 0.17.2+cpu
    - tqdm: 4.66.2
    - typing-extensions: 4.8.0
    - urllib3: 2.2.1
    - virtualenv: 20.25.3
    - wheel: 0.41.2
    - yarl: 1.9.4
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.11.7
    - release: 5.15.0-102-generic
    - version: #112~20.04.1-Ubuntu SMP Thu Mar 14 14:28:24 UTC 2024
```
More info
The solution is to change the type annotation of the `save_last` parameter in the `ModelCheckpoint` constructor.
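To illustrate why changing the annotation can help, here is a toy, stdlib-only converter. `toy_parse` is hypothetical and only models the idea of trying each Union member in order; it is not jsonargparse's algorithm. The replacement annotation `Optional[Union[bool, Literal["link"]]]` is likewise an assumption about what such a fix could look like, not necessarily what the draft PR uses.

```python
from typing import Literal, Optional, Union, get_args, get_origin

# Annotation quoted in this issue, and one candidate replacement (assumed).
OLD_ANNOT = Optional[Literal[True, False, "link"]]
NEW_ANNOT = Optional[Union[bool, Literal["link"]]]

_BOOL_STRINGS = {"true": True, "false": False}


def toy_parse(annot, raw: str):
    """Hypothetical, simplified CLI converter (not jsonargparse's algorithm).

    Tries each member of a Union in order and returns the first match.
    """
    members = get_args(annot) if get_origin(annot) is Union else (annot,)
    for member in members:
        # A plain bool member converts "true"/"false" strings to booleans.
        if member is bool and raw.lower() in _BOOL_STRINGS:
            return _BOOL_STRINGS[raw.lower()]
        # A Literal member only accepts a raw string equal to one of its values.
        if get_origin(member) is Literal and raw in get_args(member):
            return raw
        # The None member of Optional accepts an explicit null.
        if member is type(None) and raw.lower() in ("null", "none"):
            return None
    raise ValueError(f"{raw!r} does not validate against {annot}")


print(toy_parse(NEW_ANNOT, "true"))  # True
print(toy_parse(NEW_ANNOT, "link"))  # link
```

With a plain `bool` member in the union, the string `"true"` can go through boolean conversion instead of having to equal the literal `True`, while `"link"` is still matched by the `Literal` member. Calling `toy_parse(OLD_ANNOT, "true")` raises `ValueError`, mirroring the reported error.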
I have opened a draft PR, linked here for reference, and added a test that checks the bug is fixed.
The tests are passing: