
parsing issue with `save_last` parameter of `ModelCheckpoint`

Open mariovas3 opened this issue 10 months ago • 0 comments

Bug description

Cannot pass a boolean to the save_last parameter of the ModelCheckpoint callback using LightningCLI.

All other parameters work fine. I think jsonargparse has trouble validating the annotation of save_last, which is currently Optional[Literal[True, False, 'link']].

Ideally, this should work like any other boolean flag e.g., like --my_model_checkpoint.verbose=false.
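To make the expected behaviour concrete, here is a minimal sketch of the string-to-value mapping I would expect for save_last. The helper name `parse_save_last` is hypothetical and exists only for illustration; it is not part of Lightning or jsonargparse, and accepting `null` for None is an assumption based on how YAML-style values are usually written.

```python
from typing import Optional, Union


def parse_save_last(text: str) -> Optional[Union[bool, str]]:
    """Hypothetical helper: map a command-line string to a save_last value."""
    lowered = text.strip().lower()
    # Boolean spellings should behave like any other boolean flag.
    if lowered in ("true", "false"):
        return lowered == "true"
    # Assumption: "null" stands for None, as in YAML.
    if lowered == "null":
        return None
    # The only string value the parameter accepts.
    if text == "link":
        return "link"
    raise ValueError(f"Expected true, false, link or null, got {text!r}")


if __name__ == "__main__":
    for raw in ("true", "false", "link", "null"):
        print(raw, "->", repr(parse_save_last(raw)))
```

Under this sketch, `--my_model_checkpoint.save_last=true` would yield the boolean True, in the same way `--my_model_checkpoint.verbose=false` yields False.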

I have already forked the project and prepared a fix, together with tests in the relevant directory. I am ready to submit a PR if you think this would be useful.

What version are you seeing the problem on?

master

How to reproduce the bug

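import inspect

import jsonargparse
from lightning.pytorch.callbacks import ModelCheckpoint

# The same string the CLI receives for a boolean flag
val = "true"
# Reuse the exact annotation of ModelCheckpoint's save_last parameter
annot = inspect.signature(ModelCheckpoint).parameters["save_last"].annotation
parser = jsonargparse.ArgumentParser()
parser.add_argument("--a", type=annot)
# Exits with the validation error reported under "Error messages and logs"
args = parser.parse_args(["--a", val])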

Error messages and logs

error: Parser key "a":
  Does not validate against any of the Union subtypes
  Subtypes: (typing.Literal[True, False, 'link'], <class 'NoneType'>)
  Errors:
    - Expected a typing.Literal[True, False, 'link']
    - Expected a <class 'NoneType'>. Got value: True
  Given value type: <class 'str'>
  Given value: true

Environment

Current environment

```
* CUDA:
    - GPU: None
    - available: False
    - version: None
* Lightning:
    - lightning-utilities: 0.11.2
    - torch: 2.2.2+cpu
    - torchmetrics: 1.2.1
    - torchvision: 0.17.2+cpu
* Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - antlr4-python3-runtime: 4.9.3
    - attrs: 23.2.0
    - certifi: 2024.2.2
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.2
    - cloudpickle: 2.2.1
    - distlib: 0.3.8
    - filelock: 3.13.4
    - frozenlist: 1.4.1
    - fsspec: 2023.4.0
    - identify: 2.5.36
    - idna: 3.7
    - iniconfig: 2.0.0
    - jinja2: 3.1.2
    - jsonargparse: 4.28.0
    - lightning-utilities: 0.11.2
    - markupsafe: 2.1.3
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - networkx: 3.2.1
    - nodeenv: 1.8.0
    - numpy: 1.26.3
    - omegaconf: 2.3.0
    - packaging: 23.1
    - pillow: 10.2.0
    - pip: 24.0
    - platformdirs: 4.2.0
    - pluggy: 1.5.0
    - pre-commit: 3.7.0
    - pytest: 7.4.0
    - pyyaml: 6.0.1
    - requests: 2.31.0
    - setuptools: 68.2.2
    - sympy: 1.12
    - torch: 2.2.2+cpu
    - torchmetrics: 1.2.1
    - torchvision: 0.17.2+cpu
    - tqdm: 4.66.2
    - typing-extensions: 4.8.0
    - urllib3: 2.2.1
    - virtualenv: 20.25.3
    - wheel: 0.41.2
    - yarl: 1.9.4
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.11.7
    - release: 5.15.0-102-generic
    - version: #112~20.04.1-Ubuntu SMP Thu Mar 14 14:28:24 UTC 2024
```

More info

The solution is to change the typing annotation of the save_last parameter in the constructor of ModelCheckpoint.
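As a rough illustration of what such an annotation change could look like, the sketch below compares the current annotation with one candidate replacement using only the standard `typing` module. The candidate `Optional[Union[bool, Literal["link"]]]` is an assumption made for illustration, not necessarily the exact annotation used in the draft PR. The names `current` and `candidate` are likewise only illustrative.

```python
from typing import Literal, Optional, Union, get_args

# Annotation quoted in this issue: booleans appear only as Literal members.
current = Optional[Literal[True, False, "link"]]

# Candidate replacement (assumption): bool is a first-class Union member.
candidate = Optional[Union[bool, Literal["link"]]]

# get_args exposes the Union subtypes a parser has to try one by one.
print(get_args(current))
print(get_args(candidate))

# With the candidate, a parser can match the plain bool subtype directly.
print(bool in get_args(current))
print(bool in get_args(candidate))
```

With the current annotation, the boolean values are hidden inside a single Literal subtype, which is the subtype the error message reports as failing. With the candidate, `bool` is its own subtype, so the value can be handled by the same path as any other boolean parameter.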

I have made a draft PR, with a test that checks the bug is fixed, here for reference.

The tests are passing:

[Screenshot: fixing-save-last-bug]

mariovas3, Apr 22 '24 15:04