
save_hyperparameters no longer respects linked arguments.

Open Erotemic opened this issue 1 year ago • 2 comments

Bug description

As of lightning 2.3.0, save_hyperparameters no longer seems to respect linked arguments.

Based on my investigation, this seems to be due to https://github.com/Lightning-AI/pytorch-lightning/pull/18105. That PR also seems to have caused other errors, which have since been resolved, but as far as I can tell this one persists in the latest release (2.4.0) and on the master branch at 66508ff4b7d49264e37d3e8926fa6e39bcb1217c.

What version are you seeing the problem on?

v2.3, v2.4, master

How to reproduce the bug

Save the following script as: lightning_cli_save_hyperaparams_error_on_link_args.py

import torch
import torch.nn
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from typing import List, Dict


class MWE_Model(pl.LightningModule):
    """
    Example:
        >>> dataset = MWE_Dataset()
        >>> self = MWE_Model(dataset_stats=dataset.dataset_stats)
        >>> batch = [dataset[i] for i in range(2)]
        >>> self.forward(batch)
    """
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        self.save_hyperparameters()

        if dataset_stats is None:
            raise ValueError('must be given dataset stats')

        self.d_model = d_model
        self.dataset_stats = dataset_stats

        self.known_sensorchan = {
            (mode['sensor'], mode['channels'], mode['num_bands'])
            for mode in self.dataset_stats['known_modalities']
        }
        self.known_tasks = self.dataset_stats['known_tasks']
        if sorting:
            self.known_sensorchan = sorted(self.known_sensorchan)
            self.known_tasks = sorted(self.known_tasks, key=lambda t: t['name'])

        # Construct stems based on the dataset
        self.stems = torch.nn.ModuleDict()
        for sensor, channels, num_bands in self.known_sensorchan:
            if sensor not in self.stems:
                self.stems[sensor] = torch.nn.ModuleDict()
            self.stems[sensor][channels] = torch.nn.Conv2d(num_bands, self.d_model, kernel_size=1)

        # Backbone is small generic transformer
        self.backbone = torch.nn.Transformer(
            d_model=self.d_model,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            batch_first=True
        )

        # Construct heads based on the dataset
        self.heads = torch.nn.ModuleDict()
        for head_info in self.known_tasks:
            head_name = head_info['name']
            head_classes = head_info['classes']
            num_classes = len(head_classes)
            self.heads[head_name] = torch.nn.Conv2d(
                self.d_model, num_classes, kernel_size=1)

    @property
    def main_device(self):
        """ Helper to get a device for the model. """
        for key, item in self.state_dict().items():
            return item.device

    def tokenize_inputs(self, item: Dict):
        """
        Process a single batch item's heterogeneous sequence into a flat list
        of tokens for the encoder and decoder.
        """
        device = self.device

        input_sequence = []
        for input_item in item['inputs']:
            stem = self.stems[input_item['sensor_code']][input_item['channel_code']]
            out = stem(input_item['data'])
            tokens = out.view(self.d_model, -1).T
            input_sequence.append(tokens)

        output_sequence = []
        for output_item in item['outputs']:
            shape = tuple(output_item['dims']) + (self.d_model,)
            tokens = torch.rand(shape, device=device).view(-1, self.d_model)
            output_sequence.append(tokens)
        if len(input_sequence) == 0 or len(output_sequence) == 0:
            return None, None
        in_tokens = torch.concat(input_sequence, dim=0)
        out_tokens = torch.concat(output_sequence, dim=0)
        return in_tokens, out_tokens

    def forward(self, batch: List[Dict]) -> List[Dict]:
        """
        Runs prediction on multiple batch items. The input is assumed to be an
        uncollated list of dictionaries, each containing information about some
        heterogeneous sequence. The output is a corresponding list of
        dictionaries containing the logits for each head.
        """
        batch_in_tokens = []
        batch_out_tokens = []

        given_batch_size = len(batch)
        valid_batch_indexes = []

        # Prepopulate an output for each input
        batch_logits = [{} for _ in range(given_batch_size)]

        # Handle heterogeneous style inputs on a per-item level
        for batch_idx, item in enumerate(batch):
            in_tokens, out_tokens = self.tokenize_inputs(item)
            if in_tokens is not None:
                valid_batch_indexes.append(batch_idx)
                batch_in_tokens.append(in_tokens)
                batch_out_tokens.append(out_tokens)

        # Some batch items might not be valid
        valid_batch_size = len(valid_batch_indexes)
        if not valid_batch_size:
            # No inputs were valid
            return batch_logits

        # Pad everything into a batch to be more efficient
        padding_value = -9999.0
        input_seqs = nn.utils.rnn.pad_sequence(
            batch_in_tokens,
            batch_first=True,
            padding_value=padding_value,
        )
        output_seqs = nn.utils.rnn.pad_sequence(
            batch_out_tokens,
            batch_first=True,
            padding_value=padding_value,
        )

        input_masks = input_seqs[..., 0] > padding_value
        output_masks = output_seqs[..., 0] > padding_value
        input_seqs[~input_masks] = 0.
        output_seqs[~output_masks] = 0.

        decoded = self.backbone(
            src=input_seqs,
            tgt=output_seqs,
            src_key_padding_mask=~input_masks,
            tgt_key_padding_mask=~output_masks,
        )
        B = valid_batch_size
        # Note output h/w is hardcoded here and uses the fact that the mwe only
        # has one task; could be generalized.
        oh, ow = 3, 3
        decoded_features = decoded.view(B, -1, oh, ow, self.d_model)
        decoded_masks = output_masks.view(B, -1, oh, ow)

        # Reconstruct outputs corresponding to the inputs
        for batch_idx, feat, mask in zip(valid_batch_indexes, decoded_features, decoded_masks):
            item_feat = feat[mask].view(-1, oh, ow, self.d_model).permute(0, 3, 1, 2)
            item_logits = batch_logits[batch_idx]
            for head_name, head_layer in self.heads.items():
                head_logits = head_layer(item_feat)
                item_logits[head_name] = head_logits
        return batch_logits

    def forward_step(self, batch: List[Dict], with_loss=False, stage='unspecified'):
        """
        Generic forward step used for test / train / validation
        """
        batch_logits : List[Dict] = self.forward(batch)
        outputs = {}
        outputs['logits'] = batch_logits

        if with_loss:
            losses = []
            valid_batch_size = 0
            for item, item_logits in zip(batch, batch_logits):
                if len(item_logits):
                    valid_batch_size += 1
                for head_name, head_logits in item_logits.items():
                    head_target = torch.stack([label['data'] for label in item['labels'] if label['head'] == head_name], dim=0)
                    # dummy loss function
                    head_loss = torch.nn.functional.mse_loss(head_logits, head_target)
                    losses.append(head_loss)
            total_loss = sum(losses) if len(losses) > 0 else None
            if total_loss is not None:
                self.log(f'{stage}_loss', total_loss, prog_bar=True, batch_size=valid_batch_size, sync_dist=True)
            outputs['loss'] = total_loss

        return outputs

    def training_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='train')
        if outputs['loss'] is None:
            return None
        return outputs

    def validation_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='val')
        return outputs

    def test_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='test')
        return outputs


class MWE_Dataset(Dataset):
    """
    A dataset that produces heterogeneous outputs

    Example:
        >>> self = MWE_Dataset()
        >>> self[0]
    """
    def __init__(self, max_items_per_epoch=100):
        super().__init__()
        self.max_items_per_epoch = max_items_per_epoch
        self.rng = np.random
        self.dataset_stats =  {
            'known_modalities': [
                {'sensor': 'sensor1', 'channels': 'rgb', 'num_bands': 3, 'dims': (23, 23)},
            ],
            'known_tasks': [
                {'name': 'class', 'classes': ['a', 'b', 'c', 'd', 'e'], 'dims': (3, 3)},
            ]
        }

    def __len__(self):
        return self.max_items_per_epoch

    def __getitem__(self, index) -> Dict:
        """
        Returns:
            Dict: containing
                * inputs - a list of observations
                * outputs - a list of what we want to predict
                * labels - ground truth if we have it
        """
        inputs = []
        outputs = []
        labels = []
        max_timesteps_per_item = 5
        num_frames = max_timesteps_per_item
        p_drop_input = 0

        for frame_index in range(num_frames):
            had_input = 0
            # In general we may have any number of observations per frame
            for modality in self.dataset_stats['known_modalities']:
                sensor = modality['sensor']
                channels = modality['channels']
                c = modality['num_bands']
                h, w = modality['dims']

                # Randomly include each sensorchan on each frame
                if self.rng.rand() >= p_drop_input:
                    had_input = 1
                    inputs.append({
                        'type': 'input',
                        'channel_code': channels,
                        'sensor_code': sensor,
                        'frame_index': frame_index,
                        'data': torch.rand(c, h, w),
                    })
            if had_input:
                for task_info in self.dataset_stats['known_tasks']:
                    task = task_info['name']
                    oh, ow = task_info['dims']
                    oc = len(task_info['classes'])
                    outputs.append({
                        'type': 'output',
                        'head': task,
                        'frame_index': frame_index,
                        'dims': (oh, ow),
                    })
                    labels.append({
                        'type': 'label',
                        'head': task,
                        'frame_index': frame_index,
                        'data': torch.rand(oc, oh, ow),
                    })
        item = {
            'inputs': inputs,
            'outputs': outputs,
            'labels': labels,
        }
        return item

    def make_loader(self, batch_size=1, num_workers=0, shuffle=False,
                    pin_memory=False):
        """
        Create a dataloader object with sensible defaults for the problem
        """
        loader = torch.utils.data.DataLoader(
            self, batch_size=batch_size, num_workers=num_workers,
            shuffle=shuffle, pin_memory=pin_memory,
            collate_fn=lambda x: x
        )
        return loader


class MWE_Datamodule(pl.LightningDataModule):
    def __init__(self, batch_size=1, num_workers=0, max_items_per_epoch=100):
        super().__init__()
        self.save_hyperparameters()
        self.torch_datasets = {}
        self.dataset_stats = None
        self.dataset_kwargs = {
            'max_items_per_epoch': max_items_per_epoch,
        }
        self._did_setup = False

    def setup(self, stage):
        if self._did_setup:
            return
        self.torch_datasets['train'] = MWE_Dataset(**self.dataset_kwargs)
        self.torch_datasets['test'] = MWE_Dataset(**self.dataset_kwargs)
        self.torch_datasets['vali'] = MWE_Dataset(**self.dataset_kwargs)
        self.dataset_stats = self.torch_datasets['train'].dataset_stats
        self._did_setup = True
        print('Setup MWE_Datamodule')
        print(self.__dict__)

    def train_dataloader(self):
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader('vali', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader('test', shuffle=False)

    @property
    def train_dataset(self):
        return self.torch_datasets.get('train', None)

    @property
    def test_dataset(self):
        return self.torch_datasets.get('test', None)

    @property
    def vali_dataset(self):
        return self.torch_datasets.get('vali', None)

    def _make_dataloader(self, stage, shuffle=False):
        loader = self.torch_datasets[stage].make_loader(
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )
        return loader


class MWE_LightningCLI(LightningCLI):
    """
    Customized LightningCLI to ensure the expected model inputs / outputs are
    coupled with what the dataset is able to provide.
    """

    def add_arguments_to_parser(self, parser):
        def data_value_getter(key):
            # Hack to call setup on the datamodule before linking args
            def get_value(data):
                if not data._did_setup:
                    data.setup('fit')
                return getattr(data, key)
            return get_value
        # pass dataset stats to model after datamodule initialization
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=data_value_getter('dataset_stats'),
            apply_on="instantiate")
        super().add_arguments_to_parser(parser)


def main():
    MWE_LightningCLI(
        model_class=MWE_Model,
        datamodule_class=MWE_Datamodule,
    )


if __name__ == '__main__':
    """
    CommandLine:
        cd ~/code/geowatch/dev/mwe/

    """
    main()

Apologies for the length of the MWE; it could probably be a few hundred lines shorter, but I had it on hand and it demonstrates the issue well enough. The important parts are the link_arguments call and the model's __init__:

class MWE_LightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        def data_value_getter(key):
            # Hack to call setup on the datamodule before linking args
            def get_value(data):
                if not data._did_setup:
                    data.setup('fit')
                return getattr(data, key)
            return get_value
        # pass dataset stats to model after datamodule initialization
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=data_value_getter('dataset_stats'),
            apply_on="instantiate")
        super().add_arguments_to_parser(parser)

class MWE_Model(pl.LightningModule):
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        self.save_hyperparameters()
    ...
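
To make the linking step easier to follow in isolation, here is a minimal sketch of what the compute_fn does, with no Lightning dependency. StubDataModule is a hypothetical stand-in for MWE_Datamodule, and the direct call to compute_fn only imitates what the parser is expected to do at instantiation time; it is not how LightningCLI is implemented.

```python
class StubDataModule:
    """Hypothetical stand-in for MWE_Datamodule (not a Lightning class)."""

    def __init__(self):
        self.dataset_stats = None
        self._did_setup = False

    def setup(self, stage):
        # dataset_stats only becomes available once setup has run
        self.dataset_stats = {'known_tasks': [{'name': 'class'}]}
        self._did_setup = True


def data_value_getter(key):
    """Build a compute_fn that reads ``key`` from a datamodule, running setup first."""
    def get_value(data):
        if not data._did_setup:
            data.setup('fit')
        return getattr(data, key)
    return get_value


compute_fn = data_value_getter('dataset_stats')
data = StubDataModule()

# Imitates the link: the instantiated "data" object goes in, and the value
# that would be passed to the model as dataset_stats comes out.
linked_value = compute_fn(data)
print(linked_value)
```

The value returned here is what the model receives as dataset_stats, and it is this argument that goes missing from the saved hyperparameters.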

Given the above script saved as lightning_cli_save_hyperaparams_error_on_link_args.py, I invoke it as:

DEFAULT_ROOT_DIR=./mwe_train_dir

python lightning_cli_save_hyperaparams_error_on_link_args.py fit --config "
    model:
        sorting: True
    data:
        num_workers: 8
        batch_size: 2
        max_items_per_epoch: 200
    optimizer:
      class_path: torch.optim.Adam
      init_args:
        lr: 1e-7
    trainer:
      default_root_dir     : $DEFAULT_ROOT_DIR
      accelerator          : gpu
      devices              : 1
      max_epochs: 100
"

CKPT_FPATH=$(python -c "import pathlib; print(list(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/checkpoints/*.ckpt'))[0])")
HPARAM_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'))[-1])")
cat "$HPARAM_FPATH"
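
For readability, the inline `python -c` lookup of the hparams file can also be written as a small helper. This is a sketch only: latest_hparams_path is a hypothetical name, and the temporary directory below merely imitates the lightning_logs layout assumed by the commands above.

```python
import pathlib
import tempfile


def latest_hparams_path(root_dir):
    """Return the last hparams.yaml (in sorted path order) under root_dir/lightning_logs.

    Returns None if the log directory or the file does not exist.
    """
    log_dir = pathlib.Path(root_dir) / 'lightning_logs'
    if not log_dir.is_dir():
        return None
    candidates = sorted(log_dir.glob('*/hparams.yaml'))
    return candidates[-1] if candidates else None


# Demo on a throwaway directory that mimics two training runs.
with tempfile.TemporaryDirectory() as tmp:
    empty_result = latest_hparams_path(tmp)  # no lightning_logs yet
    for version in ('version_0', 'version_1'):
        run_dir = pathlib.Path(tmp) / 'lightning_logs' / version
        run_dir.mkdir(parents=True)
        (run_dir / 'hparams.yaml').write_text('d_model: 16\n')
    found_version = latest_hparams_path(tmp).parent.name

print(found_version)
```

Note that, like the one-liner, this sorts paths lexicographically, so a directory named version_10 would sort before version_2.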

Error messages and logs

When using pytorch_lightning 2.2.5, running:

        HPARAM_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'))[-1])")
        cat "$HPARAM_FPATH"

correctly prints hyperparameters that include the linked dataset_stats argument:

sorting: true
dataset_stats:
  known_modalities:
  - sensor: sensor1
    channels: rgb
    num_bands: 3
    dims:
    - 23
    - 23
  known_tasks:
  - name: class
    classes:
    - a
    - b
    - c
    - d
    - e
    dims:
    - 3
    - 3
d_model: 16
batch_size: 2
num_workers: 8
max_items_per_epoch: 200

But on 2.4.0 and the latest master branch, dataset_stats is missing and an extra _instantiator key appears:

sorting: true
d_model: 16
_instantiator: pytorch_lightning.cli.instantiate_module
batch_size: 2
num_workers: 8
max_items_per_epoch: 200
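
The difference between the two files can be spelled out with a short key comparison. This is a sketch that uses dictionaries transcribed by hand from the two YAML outputs above rather than loading the files, so it runs without Lightning or a YAML parser.

```python
# hparams.yaml contents under pytorch_lightning 2.2.5 (transcribed from above)
expected = {
    'sorting': True,
    'dataset_stats': {
        'known_modalities': [
            {'sensor': 'sensor1', 'channels': 'rgb', 'num_bands': 3, 'dims': [23, 23]},
        ],
        'known_tasks': [
            {'name': 'class', 'classes': ['a', 'b', 'c', 'd', 'e'], 'dims': [3, 3]},
        ],
    },
    'd_model': 16,
    'batch_size': 2,
    'num_workers': 8,
    'max_items_per_epoch': 200,
}

# hparams.yaml contents under 2.4.0 / master (transcribed from above)
actual = {
    'sorting': True,
    'd_model': 16,
    '_instantiator': 'pytorch_lightning.cli.instantiate_module',
    'batch_size': 2,
    'num_workers': 8,
    'max_items_per_epoch': 200,
}

missing = sorted(set(expected) - set(actual))  # keys lost in the newer version
extra = sorted(set(actual) - set(expected))    # keys that newly appear
print('missing:', missing)
print('extra:', extra)
```

A check like this could be turned into a regression test by loading the real hparams.yaml in place of the hand-written dictionaries.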

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3090
      • NVIDIA GeForce RTX 3090
    • available: True
    • version: 12.4
  • Lightning:
    • lightning: 2.4.0
    • lightning-utilities: 0.11.2
    • perceiver-pytorch: 0.8.3
    • performer-pytorch: 1.0.11
    • pytorch-lightning: 2.4.0
    • pytorch-msssim: 0.1.5
    • pytorch-ranger: 0.1.1
    • reformer-pytorch: 1.4.3
    • torch: 2.4.0+cu124
    • torch-liberator: 0.2.2
    • torch-optimizer: 0.1.0
    • torchaudio: 2.4.0+cu124
    • torchmetrics: 0.11.0
    • torchvision: 0.19.0
  • Packages:
    • absl-py: 1.4.0
    • accelerate: 0.30.1
    • addict: 2.4.0
    • affine: 2.3.0
    • aiobotocore: 2.5.4
    • aiohttp: 3.9.5
    • aiohttp-retry: 2.8.3
    • aioitertools: 0.11.0
    • aiosignal: 1.3.1
    • alabaster: 0.7.16
    • albumentations: 1.0.0
    • amqp: 5.2.0
    • annotated-types: 0.7.0
    • antlr4-python3-runtime: 4.9.3
    • anyio: 4.6.0
    • anytree: 2.12.1
    • appdirs: 1.4.4
    • argcomplete: 3.5.0
    • argo-workflows: 6.5.6
    • arrow: 1.3.0
    • asciitree: 0.3.3
    • astor: 0.8.1
    • astroid: 3.2.2
    • asttokens: 2.4.1
    • astunparse: 1.6.3
    • asyncssh: 2.14.2
    • atomicwrites: 1.4.0
    • atpublic: 4.1.0
    • attrs: 23.2.0
    • auditwheel: 6.1.0
    • autobahn: 24.4.2
    • autodocsumm: 0.2.13
    • automat: 22.10.0
    • autopep8: 2.0.0
    • axial-positional-embedding: 0.2.1
    • babel: 2.15.0
    • backports.tarfile: 1.2.0
    • baron: 0.10.1
    • bashlex: 0.18
    • bcrypt: 4.1.3
    • beautifulsoup4: 4.12.3
    • bidict: 0.23.1
    • billiard: 4.2.0
    • black: 24.4.2
    • blake3: 0.3.1
    • bleach: 6.1.0
    • blinker: 1.8.2
    • boto: 2.49.0
    • boto3: 1.28.17
    • botocore: 1.31.17
    • bpytop: 1.0.68
    • bracex: 2.4
    • brotli: 1.1.0
    • build: 1.2.2
    • cachecontrol: 0.14.0
    • cachetools: 5.4.0
    • celery: 5.4.0
    • certifi: 2024.2.2
    • cffi: 1.16.0
    • cfgv: 3.4.0
    • chardet: 5.2.0
    • charset-normalizer: 2.0.12
    • chromecontroller: 0.3.26
    • cibuildwheel: 2.21.0
    • cleo: 2.1.0
    • click: 8.1.7
    • click-didyoumean: 0.3.1
    • click-plugins: 1.1.1
    • click-repl: 0.3.0
    • cligj: 0.7.2
    • cloudpickle: 3.0.0
    • cmake: 3.29.3
    • cmd-queue: 0.1.21
    • codecarbon: 2.2.4
    • colorama: 0.4.6
    • colormath: 3.0.0
    • colt5-attention: 0.10.20
    • comm: 0.2.2
    • commonmark: 0.9.1
    • configargparse: 1.7
    • configobj: 5.0.8
    • constantly: 23.10.4
    • contourpy: 1.2.1
    • coverage: 7.4.3
    • crashtest: 0.4.1
    • cryptography: 42.0.7
    • cssutils: 2.10.2
    • cycler: 0.12.1
    • cython: 0.29.34
    • dask: 2023.8.1
    • dataframe-image: 0.1.13
    • dataproperty: 1.0.1
    • dbus-python: 1.3.2
    • debugpy: 1.8.2
    • decorator: 5.1.1
    • defusedxml: 0.7.1
    • delayed-image: 0.3.2
    • delorean: 1.0.0
    • detectron2: 0.6
    • diceware: 0.10
    • dictdiffer: 0.9.0
    • diskcache: 5.6.3
    • distinctipy: 1.2.1
    • distlib: 0.3.8
    • distro: 1.9.0
    • docopt: 0.6.2
    • docstring-parser: 0.16
    • docutils: 0.20.1
    • dominate: 2.9.1
    • dpath: 2.1.6
    • dtool-ibeis: 1.1.2
    • dulwich: 0.22.1
    • dvc: 3.51.2
    • dvc-data: 3.15.1
    • dvc-http: 2.32.0
    • dvc-objects: 5.1.0
    • dvc-render: 1.0.2
    • dvc-s3: 3.2.0
    • dvc-ssh: 4.1.1
    • dvc-studio-client: 0.20.0
    • dvc-task: 0.4.0
    • einops: 0.6.0
    • entrypoints: 0.4
    • et-xmlfile: 1.1.0
    • executing: 2.0.1
    • faiss-cpu: 1.8.0
    • fasteners: 0.17.3
    • fastjsonschema: 2.19.1
    • filelock: 3.15.4
    • filterpy: 1.4.5
    • fiona: 1.8.22
    • fire: 0.4.0
    • flake8: 7.0.0
    • flask: 3.0.3
    • flask-basicauth: 0.2.0
    • flask-cors: 3.0.10
    • flask-socketio: 5.3.6
    • flatten-dict: 0.4.2
    • flexcache: 0.3
    • flexparser: 0.3.1
    • flufl.lock: 7.1.1
    • fonttools: 4.51.0
    • frozenlist: 1.4.1
    • fsspec: 2024.6.0
    • funcy: 2.0
    • futures-actors: 0.0.5
    • fuzzywuzzy: 0.18.0
    • fvcore: 0.1.5.post20221221
    • gdal: 3.5.2
    • geodatasets: 2023.12.0
    • geographiclib: 2.0
    • geojson: 3.0.1
    • geomet: 1.1.0
    • geopandas: 0.14.4
    • geopy: 2.4.1
    • geowatch: 0.18.4
    • gevent: 24.2.1
    • girder-client: 3.2.4.dev30+gcacd0e706
    • git-of-theseus: 0.3.4
    • git-python: 1.0.3
    • git-well: 0.2.1
    • gitdb: 4.0.11
    • gitpython: 3.1.43
    • google-api-core: 2.19.0
    • google-api-python-client: 2.130.0
    • google-auth: 2.29.0
    • google-auth-httplib2: 0.2.0
    • google-auth-oauthlib: 1.0.0
    • googleapis-common-protos: 1.63.0
    • grandalf: 0.8
    • graphid: 0.1.0
    • greenlet: 3.0.3
    • grpcio: 1.63.0
    • gto: 1.7.1
    • guitool-ibeis: 2.2.0
    • h11: 0.14.0
    • h3: 3.7.7
    • hardware: 0.31.0
    • hkdf: 0.0.3
    • html2image: 2.0.4.3
    • httpcore: 0.16.3
    • httplib2: 0.22.0
    • httpx: 0.23.3
    • huggingface-hub: 0.23.0
    • humanize: 4.8.0
    • hydra-core: 1.3.2
    • hyperlink: 21.0.0
    • ibeis: 2.3.2
    • identify: 2.6.0
    • idna: 3.7
    • ijson: 3.2.1
    • imageio: 2.34.1
    • imagesize: 1.4.1
    • importlib-metadata: 7.2.1
    • importlib-resources: 6.4.0
    • incremental: 24.7.2
    • iniconfig: 2.0.0
    • installer: 0.7.0
    • instant-rst: 0.9.9.1
    • iopath: 0.1.9
    • ipykernel: 6.29.5
    • ipython: 8.18.1
    • ipython-genutils: 0.2.0
    • isort: 5.13.2
    • iterable-io: 1.0.0
    • iterative-telemetry: 0.0.8
    • itk: 5.4.0
    • itk-core: 5.4.0
    • itk-filtering: 5.4.0
    • itk-io: 5.4.0
    • itk-numerics: 5.4.0
    • itk-registration: 5.4.0
    • itk-segmentation: 5.4.0
    • itsdangerous: 2.2.0
    • jaraco.classes: 3.4.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.2
    • jedi: 0.19.1
    • jeepney: 0.8.0
    • jellyfin-apiclient-python: 1.9.2
    • jellyfin-migrator: 0.0.0
    • jinja2: 3.1.4
    • jmespath: 1.0.1
    • joblib: 1.4.2
    • johnnydep: 1.20.4
    • jq: 1.7.0
    • jsonargparse: 4.32.1
    • jsonnet: 0.20.0
    • jsonpath: 0.82.2
    • jsonschema: 4.19.2
    • jsonschema-specifications: 2023.12.1
    • jupyter-client: 8.6.1
    • jupyter-core: 5.7.2
    • jupyterlab-pygments: 0.3.0
    • kafka-python: 2.0.2
    • keyring: 24.3.1
    • kiwisolver: 1.4.5
    • kombu: 5.3.7
    • kornia: 0.6.8
    • kornia-rs: 0.1.3
    • kubernetes: 29.0.0
    • kwalop: 0.1.0
    • kwarray: 0.6.19
    • kwcoco: 0.8.5
    • kwcoco-explorer: 0.0.1
    • kwgis: 0.1.1
    • kwimage: 0.10.1
    • kwimage-ext: 0.2.1
    • kwplot: 0.5.2
    • kwutil: 0.3.3
    • lark: 1.1.7
    • lark-cython: 0.0.15
    • lazy-loader: 0.3
    • levenshtein: 0.25.1
    • liberator: 0.1.0
    • lightning: 2.4.0
    • lightning-utilities: 0.11.2
    • line-profiler: 4.1.3
    • linkify-it-py: 2.0.3
    • lit: 18.1.4
    • livereload: 2.7.0
    • llvmlite: 0.42.0
    • local-attention: 1.9.1
    • locket: 1.0.0
    • lockfile: 0.12.2
    • logmatic-python: 0.1.7
    • lxml: 4.9.2
    • magic-wormhole: 0.14.0
    • markdown: 3.6
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.5
    • mathutf: 0.1.0
    • matplotlib: 3.8.2
    • matplotlib-inline: 0.1.7
    • maturin: 1.7.4
    • mbstrdecoder: 1.1.3
    • mccabe: 0.7.0
    • mdit-py-plugins: 0.4.1
    • mdurl: 0.1.2
    • mgrs: 1.4.6
    • mistune: 3.0.2
    • mkinit: 1.1.0
    • mmcv: 2.0.0
    • mmengine: 0.10.4
    • monai: 0.8.0
    • more-itertools: 8.12.0
    • mpmath: 1.3.0
    • msgpack: 1.0.8
    • multidict: 6.0.5
    • munch: 4.0.0
    • mutagen: 1.47.0
    • mypy: 1.10.0
    • mypy-extensions: 1.0.0
    • myst-parser: 3.0.1
    • nbclient: 0.10.0
    • nbconvert: 7.16.4
    • nbformat: 5.10.4
    • ndsampler: 0.7.9
    • nest-asyncio: 1.6.0
    • netharn: 0.6.2
    • networkx: 3.3
    • networkx-algo-common-subtree: 0.2.1
    • nh3: 0.2.18
    • nodeenv: 1.9.1
    • nrtk: 0.11.0
    • nrtk-explorer: 0.3.0
    • numba: 0.59.1
    • numcodecs: 0.13.0
    • numexpr: 2.8.4
    • numpy: 1.25.2
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cublas-cu12: 12.4.2.65
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-cupti-cu12: 12.4.99
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-nvrtc-cu12: 12.4.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cuda-runtime-cu12: 12.4.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cudnn-cu12: 9.1.0.70
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-cufft-cu12: 11.2.0.44
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-curand-cu12: 10.3.5.119
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusolver-cu12: 11.6.0.99
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-cusparse-cu12: 12.3.0.142
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nccl-cu12: 2.20.5
    • nvidia-nvjitlink-cu12: 12.4.99
    • nvidia-nvtx-cu11: 11.7.91
    • nvidia-nvtx-cu12: 12.4.99
    • oauthlib: 3.2.2
    • omegaconf: 2.3.0
    • openapi-python-client: 0.20.0
    • openapi-python-generator: 0.5.0
    • openapi-schema-pydantic: 1.2.4
    • opencv-python-headless: 4.10.0.84
    • openpyxl: 3.0.9
    • opentimestamps: 0.4.5
    • opentimestamps-client: 0.7.1
    • ordered-set: 4.1.0
    • orjson: 3.10.3
    • osmnx: 1.9.4
    • oyaml: 1.0
    • packaging: 24.1
    • pandas: 1.5.3
    • pandocfilters: 1.5.1
    • parse: 1.19.0
    • parso: 0.8.4
    • partd: 1.4.2
    • pathspec: 0.12.1
    • pathvalidate: 3.2.1
    • patsy: 0.5.6
    • pbr: 6.0.0
    • pendulum: 3.0.0
    • perceiver-pytorch: 0.8.3
    • performer-pytorch: 1.0.11
    • pexpect: 4.9.0
    • pillow: 10.3.0
    • pint: 0.24.3
    • pip: 24.2
    • pkginfo: 1.10.0
    • platformdirs: 3.11.0
    • plotly: 5.24.0
    • plottool-ibeis: 2.3.0
    • pls-dont-shadow-me: 1.0.0
    • pluggy: 1.5.0
    • pockets: 0.9.1
    • poetry: 1.8.3
    • poetry-core: 1.9.0
    • poetry-plugin-export: 1.8.0
    • pooch: 1.8.2
    • portalocker: 2.10.1
    • portion: 2.4.1
    • pre-commit: 3.8.0
    • prettytable: 3.11.0
    • product-key-memory: 0.2.2
    • progiter: 2.0.0
    • prometheus-client: 0.20.0
    • prompt-toolkit: 3.0.43
    • proto-plus: 1.23.0
    • protobuf: 4.25.3
    • psutil: 5.9.6
    • psycopg2-binary: 2.9.5
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • purepy-root-demo-pkg-lzcsutvo: 1.0.0
    • purepy-src-demo-pkg: 1.0.0
    • purepy-src-demo-pkg-dbrmcjpb: 1.0.0
    • purepy-src-demo-pkg-lzcsutvo: 1.0.0
    • py-cpuinfo: 9.0.0
    • pyasn1: 0.6.0
    • pyasn1-modules: 0.4.0
    • pybsm: 0.5.1
    • pycocotools: 2.0.7
    • pycodestyle: 2.11.1
    • pycparser: 2.22
    • pycryptodomex: 3.20.0
    • pydantic: 2.7.1
    • pydantic-core: 2.18.2
    • pydot: 2.0.0
    • pyelftools: 0.31
    • pyfiglet: 1.0.2
    • pyflakes: 3.2.0
    • pyflann-ibeis: 2.4.0
    • pygame: 2.6.0
    • pygit2: 1.15.0
    • pygments: 2.18.0
    • pygraphviz: 1.13
    • pygtrie: 2.5.0
    • pyhesaff: 2.1.1
    • pylatex: 0+untagged.769.gb48e8ec
    • pylatexenc: 3.0a29
    • pymongo: 3.13.0
    • pynacl: 1.5.0
    • pynmea2: 1.19.0
    • pynndescent: 0.5.12
    • pynvim: 0.5.0
    • pynvml: 11.5.0
    • pyo3-example: 0.1.0
    • pyopenssl: 24.1.0
    • pyparsing: 3.1.2
    • pyperclip: 1.8.2
    • pypistats: 1.6.0
    • pypng: 0.20220715.0
    • pypogo: 0.1.0
    • pyproj: 3.4.1
    • pyproject-api: 1.7.1
    • pyproject-hooks: 1.1.0
    • pyqrcode: 1.2.1
    • pyqt5: 5.15.10
    • pyqt5-qt5: 5.15.2
    • pyqt5-sip: 12.13.0
    • pyqtree: 1.0.0
    • pysocks: 1.7.1
    • pystac: 1.10.1
    • pystac-client: 0.8.1
    • pytablewriter: 1.2.0
    • pytest: 8.0.2
    • pytest-cov: 5.0.0
    • pytest-subtests: 0.13.1
    • python-bitcoinlib: 0.12.2
    • python-dateutil: 2.9.0.post1.dev3+g9eaa5de
    • python-engineio: 4.9.1
    • python-gitlab: 4.6.0
    • python-json-logger: 2.0.7
    • python-levenshtein: 0.25.1
    • python-slugify: 8.0.4
    • python-socketio: 5.11.3
    • pytimeparse: 1.1.8
    • pytorch-lightning: 2.4.0
    • pytorch-msssim: 0.1.5
    • pytorch-ranger: 0.1.1
    • pytz: 2024.1
    • pywavelets: 1.6.0
    • pyyaml: 6.0.1
    • pyzmq: 26.0.3
    • quantities: 0.15.0
    • rapidfuzz: 3.9.1
    • rasterio: 1.3.10
    • readme-renderer: 44.0
    • reconplogger: 4.16.1
    • redbaron: 0.9.2
    • referencing: 0.35.1
    • reformer-pytorch: 1.4.3
    • regex: 2024.5.10
    • requests: 2.32.2
    • requests-oauthlib: 2.0.0
    • requests-toolbelt: 1.0.0
    • responses: 0.25.3
    • rfc3986: 1.5.0
    • rgd-client: 0.2.7
    • rgd-imagery-client: 0.2.7
    • rich: 12.5.1
    • rich-argparse: 1.1.0
    • rpds-py: 0.18.1
    • rply: 0.7.8
    • rsa: 4.9
    • rtree: 1.0.1
    • ruamel.yaml: 0.17.32
    • ruamel.yaml.clib: 0.2.8
    • ruff: 0.4.5
    • ruyaml: 0.91.0
    • s3fs: 2024.6.0
    • s3transfer: 0.6.2
    • s5cmd: 0.2.0
    • safer: 4.12.3
    • safetensors: 0.4.3
    • scikit-build: 0.17.6
    • scikit-image: 0.21.0
    • scikit-learn: 1.5.1
    • scipy: 1.14.0
    • scmrepo: 3.3.5
    • scriptconfig: 0.7.16
    • seaborn: 0.13.2
    • secretstorage: 3.3.3
    • semver: 3.0.2
    • service-identity: 24.1.0
    • setuptools: 67.7.2
    • shapely: 2.0.1
    • shellingham: 1.5.4
    • shitspotter: 0.0.1
    • shortuuid: 1.0.13
    • shtab: 1.7.1
    • simple-dvc: 0.2.2
    • simple-websocket: 1.0.0
    • simpleitk: 2.3.1
    • simplejson: 3.19.2
    • simplekml: 1.3.3
    • six: 1.16.0
    • smartflow: 3.1.3
    • smmap: 5.0.1
    • smqtk-classifier: 0.19.0
    • smqtk-core: 0.19.0
    • smqtk-dataprovider: 0.18.0
    • smqtk-descriptors: 0.19.0
    • smqtk-detection: 0.20.1
    • smqtk-image-io: 0.17.1
    • smqtk-indexing: 0.18.0
    • smqtk-iqr: 0.15.1
    • smqtk-relevancy: 0.17.0
    • sniffio: 1.3.1
    • snowballstemmer: 2.2.0
    • snuggs: 1.4.7
    • sortedcontainers: 2.4.0
    • soupsieve: 2.5
    • spake2: 0.8
    • sphinx: 7.3.7
    • sphinx-autoapi: 3.1.1
    • sphinx-autobuild: 2024.4.16
    • sphinx-autodoc-typehints: 2.3.0
    • sphinx-reredirects: 0.1.3
    • sphinx-rtd-theme: 2.0.0
    • sphinxcontrib-applehelp: 1.0.8
    • sphinxcontrib-devhelp: 1.0.6
    • sphinxcontrib-htmlhelp: 2.0.5
    • sphinxcontrib-jquery: 4.1
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-napoleon: 0.7
    • sphinxcontrib-qthelp: 1.0.7
    • sphinxcontrib-serializinghtml: 1.1.10
    • sqlalchemy: 1.4.50
    • sqlalchemy-utils: 0.41.2
    • sqltrie: 0.11.0
    • sshfs: 2024.4.1
    • stack-data: 0.6.3
    • starlette: 0.37.2
    • statsmodels: 0.14.2
    • structlog: 24.2.0
    • sympy: 1.12
    • tabledata: 1.3.3
    • tabulate: 0.9.0
    • tcolorpy: 0.1.6
    • tempenv: 0.2.0
    • tenacity: 9.0.0
    • tensorboard: 2.14.0
    • tensorboard-data-server: 0.7.2
    • tensorrt-bindings: 8.6.1
    • tensorrt-cu12: 10.0.1
    • tensorrt-cu12-bindings: 10.0.1
    • tensorrt-cu12-libs: 10.0.1
    • tensorrt-libs: 8.6.1
    • termcolor: 2.4.0
    • text-unidecode: 1.3
    • textual: 0.1.18
    • threadpoolctl: 3.5.0
    • tifffile: 2024.5.22
    • timerit: 1.1.0
    • timezonefinder: 6.5.2
    • timm: 0.6.13
    • tinycss2: 1.3.0
    • tokenizers: 0.15.2
    • toml: 0.10.2
    • tomli: 2.0.1
    • tomlkit: 0.12.5
    • toolz: 0.12.1
    • torch: 2.4.0+cu124
    • torch-liberator: 0.2.2
    • torch-optimizer: 0.1.0
    • torchaudio: 2.4.0+cu124
    • torchmetrics: 0.11.0
    • torchvision: 0.19.0
    • tornado: 6.4
    • tox: 4.17.1
    • tqdm: 4.64.1
    • traitlets: 5.14.3
    • trame: 3.6.1
    • trame-client: 3.0.3
    • trame-plotly: 3.0.2
    • trame-quasar: 0.2.1
    • trame-server: 3.0.1
    • trame-vuetify: 2.7.0
    • transformers: 4.37.2
    • triton: 3.0.0
    • trove-classifiers: 2024.9.12
    • twine: 5.1.1
    • twisted: 24.3.0
    • txaio: 23.1.1
    • txtorcon: 23.11.0
    • typepy: 1.3.2
    • typer: 0.12.3
    • types-python-dateutil: 2.9.0.20240316
    • types-pyyaml: 6.0.12.20240808
    • types-requests: 2.32.0.20240907
    • types-setuptools: 70.0.0.20240524
    • typeshed-client: 2.5.1
    • typing-extensions: 4.11.0
    • tzdata: 2024.1
    • tzlocal: 5.2
    • ubelt: 1.3.6
    • uc-micro-py: 1.0.3
    • ujson: 5.6.0
    • umap-learn: 0.5.6
    • uncertainties: 3.2.2
    • uritemplate: 4.1.1
    • uritools: 4.0.2
    • urllib3: 1.26.20
    • utm: 0.7.0
    • utool: 2.2.0
    • uv: 0.3.4
    • uvicorn: 0.29.0
    • validators: 0.28.1
    • vimtk: 0.5.0
    • vine: 5.1.0
    • virtualenv: 20.26.3
    • voluptuous: 0.14.2
    • vtool-ibeis: 2.3.0
    • vtool-ibeis-ext: 0.1.1
    • watchfiles: 0.21.0
    • wcmatch: 8.5.2
    • wcwidth: 0.2.13
    • webencodings: 0.5.1
    • websocket-client: 1.8.0
    • websockets: 12.0
    • werkzeug: 3.0.4
    • wheel: 0.40.0
    • wimpy: 0.6
    • wrapt: 1.14.1
    • wslink: 2.0.4
    • wsproto: 1.2.0
    • xarray: 0.17.0
    • xcookie: 0.2.2
    • xdev: 1.5.2
    • xdoctest: 1.1.5
    • xinspect: 0.2.0
    • xmltodict: 0.12.0
    • xxhash: 3.4.1
    • yacs: 0.1.8
    • yapf: 0.40.2
    • yarl: 1.9.4
    • yt-dlp: 2024.8.6
    • zarr: 2.18.2
    • zc.lockfile: 3.0.post1
    • zipp: 3.18.1
    • zipstream-ng: 1.7.1
    • zope.event: 5.0
    • zope.interface: 6.4.post2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.11.9
    • release: 6.8.0-45-generic
    • version: #45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024

More info

No response

Erotemic avatar Sep 30 '24 20:09 Erotemic

Was the reproduction script supposed to be attached, but it isn't?

mauvilsa avatar Oct 05 '24 11:10 mauvilsa

The MWE is in the details. You can click the arrow to expand it. For convenience here it is in a gist as well: https://gist.github.com/Erotemic/dfdbf192004486e9f108b0334dd7fdcd

Erotemic avatar Oct 05 '24 16:10 Erotemic

I am also affected by this issue... :(

noamsgl avatar Oct 27 '24 14:10 noamsgl

Now I understand what the problem is. Later I will think about a proper solution and create a pull request. For the time being to disable the current behavior you can implement the following in your LightningCLI subclass:

    def _add_instantiators(self):
        pass

mauvilsa avatar Oct 28 '24 20:10 mauvilsa

This behavior might have been introduced with the given_hparams mechanism in the new implementation of save_hyperparameters(): https://github.com/Lightning-AI/pytorch-lightning/blob/1129d4cecf2bcbf78d9340655c3950e744a019e1/src/lightning/pytorch/core/mixins/hparams_mixin.py#L125-L131

huangyxi avatar Dec 12 '24 03:12 huangyxi

This behavior might have been introduced with the given_hparams

@huangyxi yes, it is because of that, which is a change needed for the load_from_checkpoint support from #18105 as mentioned in the bug description. Note that the issue only affects links applied on instantiate.

When I have the time I will work on a fix for this. I mentioned a workaround in my comment above.

mauvilsa avatar Dec 17 '24 15:12 mauvilsa

@mauvilsa, thanks for the quick workaround. Could it be, though, that this slows down the logging substantially? I am using the standard CSVLogger. After adding the workaround, training easily hangs for 20s in the steps where logs are written. Usually I wouldn't even notice the logging. I struggle to see a direct connection.

HenrikAsmuth avatar Jan 06 '25 16:01 HenrikAsmuth

I don't think that would slow down logging. There is no relation to it.

mauvilsa avatar Jan 06 '25 16:01 mauvilsa

Is this issue fixed in the newest version? I am also affected by this issue... :(

WhenMelancholy avatar Apr 10 '25 10:04 WhenMelancholy

No, this issue has not been fixed yet. You have the workaround in https://github.com/Lightning-AI/pytorch-lightning/issues/20311#issuecomment-2442602029

mauvilsa avatar Apr 10 '25 14:04 mauvilsa

Now I understand what the problem is. Later I will think about a proper solution and create a pull request. For the time being to disable the current behavior you can implement the following in your LightningCLI subclass:

def _add_instantiators(self):
    pass

I had a similar problem as well, and your fix finally worked, thank you @mauvilsa! In the meantime, maybe you could add a warning or reference about it in the documentation? (I spent a lot of time looking through the documentation, debugging the code line by line, searching the forum, etc. until I found this thread.)

OrianeN avatar Apr 11 '25 14:04 OrianeN

The workaround mentioned in https://github.com/Lightning-AI/pytorch-lightning/issues/20311#issuecomment-2442602029 causes my training to hang indefinitely in yaml.dump called by _log_hyperparams.

  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/cli.py", line 398, in __init__
    self._run_subcommand(self.subcommand)
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/cli.py", line 708, in _run_subcommand
    fn(**fn_kwargs)
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 995, in _run
    _log_hyperparams(self)
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loggers/utilities.py", line 102, in _log_hyperparams
    logger.save()
  File "/usr/local/lib/python3.12/dist-packages/lightning_utilities/core/rank_zero.py", line 41, in wrapped_fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loggers/tensorboard.py", line 222, in save
    save_hparams_to_yaml(hparams_file, self.hparams)
  File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/core/saving.py", line 361, in save_hparams_to_yaml
    yaml.dump(v)
  File "/usr/local/lib/python3.12/dist-packages/yaml/__init__.py", line 253, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "/usr/local/lib/python3.12/dist-packages/yaml/__init__.py", line 241, in dump_all
    dumper.represent(data)
  File "/usr/local/lib/python3.12/dist-packages/yaml/representer.py", line 28, in represent
    self.serialize(node)
  File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 54, in serialize
    self.serialize_node(node, None, None)
  File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 108, in serialize_node
    self.serialize_node(value, node, key)
  File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 98, in serialize_node
    self.serialize_node(item, node, index)
  File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 100, in serialize_node
    self.emit(SequenceEndEvent())
  File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 115, in emit
    self.state()
  File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 374, in expect_first_block_sequence_item
    return self.expect_block_sequence_item(first=True)
  File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 384, in expect_block_sequence_item
    self.expect_node(sequence=True)
  File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 242, in expect_node
    self.process_tag()
  File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 473, in process_tag
    self.style = self.choose_scalar_style()
  File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 496, in choose_scalar_style
    self.analysis = self.analyze_scalar(self.event.value)
  File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 668, in analyze_scalar
    while index < len(scalar):

E1k3 avatar Apr 18 '25 11:04 E1k3

@E1k3 with the workaround, the feature from #18105 is completely disabled. This means that save_hyperparameters tries to save whatever __init__ received, which might not be serializable using pyyaml. I guess this could be the problem that you are having.

mauvilsa avatar Apr 20 '25 15:04 mauvilsa

I created pull request #20777 with a potential fix for this. It would be nice if those of you affected could review it and test it out.

mauvilsa avatar Apr 30 '25 05:04 mauvilsa

I also ran into this problem, and it forced me to revert my PyTorch Lightning version. One quick way to work around it is to pass cfg without linked arguments (deeper dict etc.).

Detailed comment here: https://github.com/Lightning-AI/pytorch-lightning/issues/17558#issuecomment-2842777482

Kin-Zhang avatar Apr 30 '25 17:04 Kin-Zhang