pytorch-lightning

`link_arguments` does not work in lightning 2.3

Open • peacekurella opened this issue 1 year ago • 7 comments

Bug description

When `parser.link_arguments` is used to link fields a and b with `apply_on="instantiate"`, field b is not populated when it is accessed later. This was not a problem in lightning 2.2.5, which we currently use, but after upgrading to 2.3.x, field b is no longer populated.
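For context, with `apply_on="instantiate"` the LightningCLI docs describe the link source as an attribute of the already-instantiated source object (here the datamodule), whose value is then passed to the target's `__init__`. The plain-Python sketch below models only that ordering under simplifying assumptions; it is not jsonargparse's implementation, and `DataStub`, `ModelStub` and `instantiate_with_link` are hypothetical names.

```python
from typing import Any, Dict, Tuple


class DataStub:
    """Hypothetical stand-in for a LightningDataModule (link source)."""

    def __init__(self, vocab_size: int = 10) -> None:
        self.vocab_size = vocab_size


class ModelStub:
    """Hypothetical stand-in for a LightningModule (link target)."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size


def instantiate_with_link(
    config: Dict[str, Dict[str, Any]], source_attr: str, target_arg: str
) -> Tuple[DataStub, ModelStub]:
    # 1. Instantiate the link source first.
    data = DataStub(**config["data"])
    # 2. Read the value from the *instance*, not from the raw config.
    model_kwargs = dict(config.get("model", {}))
    model_kwargs[target_arg] = getattr(data, source_attr)
    # 3. Instantiate the link target with the injected value.
    model = ModelStub(**model_kwargs)
    return data, model


data, model = instantiate_with_link({"data": {"vocab_size": 15}}, "vocab_size", "vocab_size")
print(model.vocab_size)  # 15: taken from the instantiated datamodule

data, model = instantiate_with_link({"data": {}}, "vocab_size", "vocab_size")
print(model.vocab_size)  # 10: the datamodule default flows through the link
```

The second call corresponds to a run where the source value is not passed on the command line, so the source's default is what the link forwards.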

What version are you seeing the problem on?

2.3.3

How to reproduce the bug

https://github.com/Lightning-AI/pytorch-lightning/issues/20147#issuecomment-2266215234

Error messages and logs

Environment

Current environment
  • CUDA: - GPU: None - available: False - version: 12.1
  • Lightning: - lightning: 2.2.5 - lightning-utilities: 0.11.2 - pytorch-lightning: 2.3.0 - torch: 2.3.1 - torchmetrics: 1.4.0.post0
  • Packages: - aiobotocore: 2.7.0 - aiohttp: 3.9.5 - aioitertools: 0.7.1 - aiosignal: 1.2.0 - alabaster: 0.7.16 - altair: 5.0.1 - anyio: 4.2.0 - appdirs: 1.4.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - astroid: 2.14.2 - astropy: 6.1.0 - astropy-iers-data: 0.2024.6.3.0.31.14 - asttokens: 2.0.5 - async-lru: 2.0.4 - async-timeout: 4.0.3 - atomicwrites: 1.4.0 - attrs: 23.1.0 - automat: 20.2.0 - autopep8: 2.0.4 - babel: 2.11.0 - bcrypt: 3.2.0 - beautifulsoup4: 4.12.3 - binaryornot: 0.4.4 - black: 24.4.2 - bleach: 4.1.0 - blinker: 1.6.2 - bokeh: 3.4.1 - boto3: 1.34.131 - botocore: 1.34.131 - bottleneck: 1.3.7 - brotli: 1.0.9 - cachetools: 5.3.3 - cattrs: 23.2.3 - certifi: 2024.6.2 - cffi: 1.16.0 - chardet: 4.0.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - cloudpickle: 2.2.1 - colorama: 0.4.6 - colorcet: 3.1.0 - comm: 0.2.1 - constantly: 23.10.4 - contourpy: 1.2.0 - cookiecutter: 2.6.0 - cryptography: 42.0.5 - cssselect: 1.2.0 - cycler: 0.11.0 - cytoolz: 0.12.2 - dask: 2024.5.0 - dask-expr: 1.1.0 - datasets: 2.14.6 - datashader: 0.16.2 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - diff-match-patch: 20200713 - dill: 0.3.7 - distributed: 2024.5.0 - docker: 7.1.0 - docstring-parser: 0.16 - docstring-to-markdown: 0.11 - docutils: 0.18.1 - entrypoints: 0.4 - et-xmlfile: 1.1.0 - exceptiongroup: 1.2.0 - executing: 0.8.3 - fastjsonschema: 2.16.2 - filelock: 3.13.1 - flake8: 7.0.0 - flask: 3.0.3 - fonttools: 4.51.0 - frozenlist: 1.4.0 - fsspec: 2023.10.0 - gensim: 4.3.2 - gitdb: 4.0.7 - gitpython: 3.1.37 - gmpy2: 2.1.2 - google-pasta: 0.2.0 - greenlet: 3.0.1 - h5py: 3.11.0 - heapdict: 1.0.1 - holoviews: 1.19.0 - huggingface-hub: 0.23.4 - hvplot: 0.10.0 - hyperlink: 21.0.0 - idna: 3.7 - imagecodecs: 2023.1.23 - imageio: 2.33.1 - imagesize: 1.4.1 - imbalanced-learn: 0.12.3 - importlib-metadata: 6.11.0 - importlib-resources: 6.4.0 - incremental: 22.10.0 - inflection: 0.5.1 - iniconfig: 1.1.1 - intake: 0.7.0 - intervaltree: 3.1.0 - ipykernel: 6.28.0 - ipython: 8.25.0 - ipython-genutils: 0.2.0 - ipywidgets: 7.6.5 - isort: 5.13.2 - itemadapter: 0.3.0 - itemloaders: 1.1.0 - itsdangerous: 2.2.0 - jaraco.classes: 3.2.1 - jedi: 0.18.1 - jeepney: 0.7.1 - jellyfish: 1.0.1 - jinja2: 3.1.4 - jmespath: 1.0.1 - joblib: 1.4.2 - json5: 0.9.6 - jsonargparse: 4.30.0 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.7.1 - jupyter: 1.0.0 - jupyter-client: 8.6.0 - jupyter-console: 6.6.3 - jupyter-core: 5.5.0 - jupyter-events: 0.10.0 - jupyter-lsp: 2.2.0 - jupyter-server: 2.10.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.11 - jupyterlab-pygments: 0.1.2 - jupyterlab-server: 2.25.1 - jupyterlab-widgets: 3.0.10 - keyring: 24.3.1 - kiwisolver: 1.4.4 - klon: 2.3.0 - lazy-loader: 0.4 - lazy-object-proxy: 1.10.0 - lckr-jupyterlab-variableinspector: 3.1.0 - lightning: 2.2.5 - lightning-utilities: 0.11.2 - linkify-it-py: 2.0.0 - llvmlite: 0.42.0 - lmdb: 1.4.1 - locket: 1.0.0 - lsprotocol: 2023.0.1 - lxml: 4.9.4 - lxml-stubs: 0.1.1 - lz4: 4.3.2 - markdown: 3.4.1 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.8.4 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdit-py-plugins: 0.3.0 - mdurl: 0.1.0 - mistune: 2.0.4 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - more-itertools: 10.1.0 - mpmath: 1.3.0 - msgpack: 1.0.3 - multidict: 6.0.4 - multipledispatch: 0.6.0 - multiprocess: 0.70.15 - mypy: 1.10.0 - mypy-extensions: 1.0.0 - nbclient: 0.8.0 - nbconvert: 7.10.0 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - nltk: 3.8.1 - notebook: 7.0.8 - notebook-shim: 0.2.3 - numba: 0.59.1 - numexpr: 2.8.7 - numpy: 1.26.4 - numpydoc: 1.7.0 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 8.9.2.26 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-nccl-cu12: 2.20.5 - nvidia-nvjitlink-cu12: 12.5.40 - nvidia-nvtx-cu12: 12.1.105 - openpyxl: 3.1.2 - overrides: 7.4.0 - packaging: 23.2 - pandas: 2.2.2 - pandocfilters: 1.5.0 - panel: 1.4.4 - param: 2.1.0 - parsel: 1.8.1 - parso: 0.8.3 - partd: 1.4.1 - pathos: 0.3.1 - pathspec: 0.10.3 - patsy: 0.5.6 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.3.0 - pip: 24.0 - platformdirs: 3.10.0 - plotly: 5.22.0 - pluggy: 1.5.0 - ply: 3.11 - pox: 0.3.4 - ppft: 1.7.6.8 - prometheus-client: 0.14.1 - prompt-toolkit: 3.0.43 - protego: 0.1.16 - protobuf: 3.20.3 - psutil: 5.9.0 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py-cpuinfo: 9.0.0 - pyarrow: 14.0.2 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycodestyle: 2.11.1 - pycparser: 2.21 - pyct: 0.5.0 - pycurl: 7.45.2 - pydeck: 0.8.0 - pydispatcher: 2.0.5 - pydocstyle: 6.3.0 - pyerfa: 2.0.1.4 - pyflakes: 3.2.0 - pygls: 1.3.1 - pygments: 2.15.1 - pylint: 2.16.2 - pylint-venv: 3.0.3 - pyls-spyder: 0.4.0 - pyodbc: 5.0.1 - pyopenssl: 24.0.0 - pyparsing: 3.0.9 - pyproj: 3.6.1 - pyqt5: 5.15.10 - pyqt5-sip: 12.13.0 - pyqtwebengine: 5.15.6 - pysocks: 1.7.1 - pytest: 8.2.2 - python-dateutil: 2.9.0.post0 - python-json-logger: 2.0.7 - python-lsp-black: 2.0.0 - python-lsp-jsonrpc: 1.1.2 - python-lsp-server: 1.10.0 - python-slugify: 5.0.2 - python-snappy: 0.6.1 - pytoolconfig: 1.2.6 - pytorch-lightning: 2.3.0 - pytz: 2024.1 - pyviz-comms: 3.0.2 - pywavelets: 1.5.0 - pyxdg: 0.27 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - qdarkstyle: 3.2.3 - qstylizer: 0.2.2 - qtawesome: 1.2.2 - qtconsole: 5.5.1 - qtpy: 2.4.1 - queuelib: 1.6.2 - referencing: 0.30.2 - regex: 2023.10.3 - requests: 2.32.3 - requests-file: 1.5.1 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.3.5 - rope: 1.12.0 - rpds-py: 0.10.6 - rtree: 1.0.1 - ruff: 0.4.9 - ruff-lsp: 0.0.53 - s3fs: 2023.10.0 - s3transfer: 0.10.1 - sagemaker: 2.224.0 - schema: 0.7.7 - scikit-image: 0.23.2 - scikit-learn: 1.4.2 - scipy: 1.11.4 - scrapy: 2.11.1 - seaborn: 0.13.2 - secretstorage: 3.3.1 - send2trash: 1.8.2 - service-identity: 18.1.0 - setuptools: 69.5.1 - sip: 6.7.12 - six: 1.16.0 - smart-open: 5.2.1 - smdebug-rulesconfig: 1.0.1 - smmap: 4.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soupsieve: 2.5 - sphinx: 7.3.7 - sphinxcontrib-applehelp: 1.0.2 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.0 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.10 - spyder: 5.5.1 - spyder-kernels: 2.5.0 - sqlalchemy: 2.0.30 - stack-data: 0.2.0 - statsmodels: 0.14.2 - streamlit: 1.32.0 - sympy: 1.12 - tables: 3.9.2 - tabulate: 0.9.0 - tblib: 1.7.0 - tenacity: 8.2.2 - tensorboardx: 2.6.2.2 - terminado: 0.17.1 - text-unidecode: 1.3 - textdistance: 4.2.1 - threadpoolctl: 2.2.0 - three-merge: 0.1.1 - tifffile: 2023.4.12 - tinycss2: 1.2.1 - tldextract: 3.2.0 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.1 - toolz: 0.12.0 - torch: 2.3.1 - torchmetrics: 1.4.0.post0 - tornado: 6.4.1 - tqdm: 4.66.4 - traitlets: 5.14.3 - triton: 2.3.1 - twisted: 23.10.0 - typeshed-client: 2.5.1 - typing-extensions: 4.11.0 - tzdata: 2023.3 - uc-micro-py: 1.0.1 - ujson: 5.10.0 - unicodedata2: 15.1.0 - unidecode: 1.2.0 - urllib3: 2.0.7 - w3lib: 2.1.2 - watchdog: 4.0.1 - wcwidth: 0.2.5 - webencodings: 0.5.1 - websocket-client: 1.8.0 - werkzeug: 3.0.3 - whatthepatch: 1.0.2 - wheel: 0.43.0 - widgetsnbextension: 3.5.2 - wrapt: 1.14.1 - wurlitzer: 3.0.2 - xarray: 2023.6.0 - xxhash: 3.4.1 - xyzservices: 2022.9.0 - yapf: 0.40.2 - yarl: 1.9.3 - zict: 3.0.0 - zipp: 3.17.0 - zope.interface: 5.4.0
  • System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.13 - release: 5.10.220-188.869.amzn2int.x86_64 - version: #1 SMP Wed Jul 17 14:39:49 UTC 2024

More info

No response

cc @carmocca @mauvilsa

peacekurella, Aug 01 '24 02:08

I noticed that the version drop-down menu does not include 2.3.x as an option.

peacekurella, Aug 01 '24 02:08

Hey @peacekurella, can you please provide a code example based on https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py so we can verify that it's not working?

awaelchli, Aug 01 '24 12:08

I can do that.

peacekurella, Aug 02 '24 01:08

```python
import torch
from typing import Type, TypeVar
from lightning.pytorch import LightningModule
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.cli import LightningCLI
from lightning import LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_argument("data.destinationaddressid_vocab_size", default=10)
        parser.add_argument("model.destinationaddressid_vocab_size")
        parser.add_argument("--ckpt_path_ex", type=str, default=None)

        parser.link_arguments(
            "data.destinationaddressid_vocab_size",
            "model.destinationaddressid_vocab_size",
            apply_on="instantiate",
        )

    def before_instantiate_classes(self) -> None:
        if self.config.ckpt_path_ex:
            print("restoring from checkpoint")
            # we are restoring from a checkpoint
            CheckpointModuleInstantiatiorCLI.before_instantiate_classes(self)


class MyDataModule(LightningDataModule):
    def __init__(self, destinationaddressid_vocab_size: int = None):
        super().__init__()
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        print(f"The value of destinationaddressid_vocab_size in data module is {destinationaddressid_vocab_size}")

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


ModuleType = TypeVar("ModuleType")


class CheckpointModuleInstantiatiorCLI:
    def __init__(self, cli: LightningCLI):
        self.cli = cli

    def class_instantiator(self, class_type: Type[ModuleType], *args, **kwargs) -> ModuleType:
        if args:
            raise ValueError("Unexpected args")

        map_location = None if torch.cuda.is_available() else "cpu"
        defaults = self.cli.parser.get_defaults()
        if class_type == BoringModel:
            non_default_kwargs = {k: v for k, v in kwargs.items() if defaults.model.get(k) != v}
            return BoringModel.load_from_checkpoint(
                self.cli.config.ckpt_path_ex,
                map_location=map_location,
                **non_default_kwargs,
            )
        elif class_type == MyDataModule:
            non_default_kwargs = {k: v for k, v in kwargs.items() if defaults.data.get(k) != v}
            return MyDataModule.load_from_checkpoint(
                self.cli.config.ckpt_path_ex,
                map_location=map_location,
                **non_default_kwargs,
            )
        else:
            raise ValueError("Unexpected class")

    @staticmethod
    def before_instantiate_classes(cli: LightningCLI) -> None:
        instantiator = CheckpointModuleInstantiatiorCLI(cli)
        cli.parser.add_instantiator(instantiator.class_instantiator, BoringModel)
        cli.parser.add_instantiator(instantiator.class_instantiator, MyDataModule)


class BoringModel(LightningModule):
    def __init__(self, destinationaddressid_vocab_size):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        self.save_hyperparameters()
        print(f"The value of destinationaddressid_vocab_size in model module is {self.destinationaddressid_vocab_size}")

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(args):
    cli = MyLightningCLI(
        BoringModel,
        MyDataModule,
        args=args,
        trainer_defaults={"callbacks": [ModelCheckpoint(dirpath="ckpts")]},
        run=False,
    )

    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
        ckpt_path=cli.config.ckpt_path_ex if cli.config.ckpt_path_ex else None,
    )


if __name__ == "__main__":
    run(args=None)
```

Running with lightning 2.2.5

  1. Generate checkpoints with `python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epochs=1`. This prints:

     ```
     The value of destinationaddressid_vocab_size in data module is 15
     The value of destinationaddressid_vocab_size in model module is 15
     ```

  2. Load the model from the checkpoint with `python bug_report.py --trainer.max_epochs=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt`. This prints:

     ```
     The value of destinationaddressid_vocab_size in data module is None
     The value of destinationaddressid_vocab_size in model module is 15
     ```

Running with lightning 2.3.3

  1. Generate checkpoints with `python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epochs=1`. This prints:

     ```
     The value of destinationaddressid_vocab_size in data module is 15
     The value of destinationaddressid_vocab_size in model module is 15
     ```

  2. Load the model from the checkpoint with `python bug_report.py --trainer.max_epochs=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt`. This prints:

     ```
     The value of destinationaddressid_vocab_size in data module is 10
     The value of destinationaddressid_vocab_size in model module is 10
     ```

peacekurella, Aug 02 '24 22:08
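One way to reason about the 15-versus-10 outcome of the restore run is to look at which keyword arguments the repro's `class_instantiator` forwards to `load_from_checkpoint`, whose documentation states that extra keyword arguments override saved hyperparameter values. The plain-Python sketch below models that precedence: `filter_non_defaults` copies the dict comprehension from `class_instantiator` (the repro calls `defaults.model.get(k)` on a jsonargparse namespace; a plain dict stands in here), while `restore_init_args` is a hypothetical simplification of the merge, not Lightning's code. It only shows how each outcome could arise; it does not establish which case actually changed between 2.2.5 and 2.3.3.

```python
from typing import Any, Dict


def filter_non_defaults(kwargs: Dict[str, Any], defaults: Dict[str, Any]) -> Dict[str, Any]:
    # Same filter as in class_instantiator: keep only values that differ from the parser default.
    return {k: v for k, v in kwargs.items() if defaults.get(k) != v}


def restore_init_args(saved_hparams: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    # Simplified precedence: explicit keyword arguments win over hyperparameters saved in the checkpoint.
    return {**saved_hparams, **overrides}


saved = {"destinationaddressid_vocab_size": 15}   # value stored during the first run
linked = {"destinationaddressid_vocab_size": 10}  # value forwarded through the link on restore

# Case A: the recorded default for the target is also 10, so the linked value is filtered out.
case_a = restore_init_args(saved, filter_non_defaults(linked, {"destinationaddressid_vocab_size": 10}))
# Case B: no default is recorded for the target (get() returns None), so 10 survives and overrides 15.
case_b = restore_init_args(saved, filter_non_defaults(linked, {}))

print(case_a)  # {'destinationaddressid_vocab_size': 15}
print(case_b)  # {'destinationaddressid_vocab_size': 10}
```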

@awaelchli I added the repro code and the scenarios with their outputs.

peacekurella, Aug 02 '24 22:08

OK, thanks @peacekurella. But the default value is 10, and the second command does not pass `--data.destinationaddressid_vocab_size 15`. When you resume training, you certainly need to pass the same configuration. We can't expect the output to be 15 in the second example if `data.destinationaddressid_vocab_size` is not passed.

awaelchli, Aug 03 '24 15:08
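Following awaelchli's suggestion, one option is to derive the resume arguments from the original training arguments, so that link sources such as `--data.destinationaddressid_vocab_size` are passed again. A minimal sketch: `build_resume_args` is a hypothetical helper, and it assumes every option is written in `--key=value` form. The resulting list could be passed to the repro's `run(args=...)`, since `LightningCLI` accepts an argument list through its `args` parameter.

```python
from typing import List


def build_resume_args(train_args: List[str], ckpt_path: str, max_epochs: int) -> List[str]:
    """Reuse the training arguments, bump max_epochs and point at a checkpoint (hypothetical helper)."""
    # Assumes "--key=value" form; options that will be replaced are dropped first.
    replaced = ("--trainer.max_epochs=", "--ckpt_path_ex=")
    kept = [arg for arg in train_args if not arg.startswith(replaced)]
    return kept + [f"--trainer.max_epochs={max_epochs}", f"--ckpt_path_ex={ckpt_path}"]


train_args = ["--data.destinationaddressid_vocab_size=15", "--trainer.max_epochs=1"]
resume_args = build_resume_args(train_args, "ckpts/epoch=0-step=32.ckpt", max_epochs=2)
print(resume_args)
```

Keeping the original arguments as the single source of truth avoids silently falling back to the parser default of 10 on resume.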

@awaelchli As I understand it, `save_hyperparameters()` is not storing the values of parameters that were populated through a link. This was not the case in lightning 2.2.5. It is a problem when restoring from checkpoint files for inference, because we typically try to get all the hyperparameters required for inference from the checkpoint file itself.

peacekurella, Aug 05 '24 09:08
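The inference workflow described in the previous comment relies on the checkpoint carrying every constructor argument, including linked ones. Below is a toy, plain-Python model of that round trip, assuming a dict checkpoint with a `"hyper_parameters"` entry (the key Lightning checkpoints use for saved hyperparameters). `capture_init_args` and `ToyModel` are hypothetical and only loosely imitate what `save_hyperparameters()` records; this is not Lightning's implementation.

```python
import inspect
from typing import Any, Dict


def capture_init_args(cls: type, *args: Any, **kwargs: Any) -> Dict[str, Any]:
    # Bind constructor arguments to parameter names, roughly what save_hyperparameters() records.
    bound = inspect.signature(cls.__init__).bind(None, *args, **kwargs)
    bound.apply_defaults()
    params = dict(bound.arguments)
    params.pop("self")  # the placeholder bound to `self` is not a hyperparameter
    return params


class ToyModel:
    def __init__(self, destinationaddressid_vocab_size: int) -> None:
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size


# Training run: the linked value (15) reaches __init__ and is recorded.
hparams = capture_init_args(ToyModel, destinationaddressid_vocab_size=15)
checkpoint = {"hyper_parameters": hparams}

# Inference: rebuild the model from the checkpoint alone, without re-passing CLI arguments.
restored = ToyModel(**checkpoint["hyper_parameters"])
print(restored.destinationaddressid_vocab_size)  # 15
```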

I noticed that this is a duplicate of #20311. Even though this issue is older, there is a temporary workaround in https://github.com/Lightning-AI/pytorch-lightning/issues/20311#issuecomment-2442602029.

In addition to the workaround, I just created pull request #20777 with a potential fix for this. It would be nice if those of you who are affected could review and test it.

mauvilsa, Apr 30 '25 05:04