`save_hyperparameters` no longer respects linked arguments
Bug description
As of lightning 2.3.0 save_hyperparameters no longer seems to respect linked arguments.
Based on my investigation, this appears to be caused by https://github.com/Lightning-AI/pytorch-lightning/pull/18105, which also introduced other errors that have since been resolved. As far as I can tell, this one persists in the latest release (2.4.0) and on the master branch (commit 66508ff4b7d49264e37d3e8926fa6e39bcb1217c).
What version are you seeing the problem on?
v2.3, v2.4, master
How to reproduce the bug
Save the following script as: lightning_cli_save_hyperaparams_error_on_link_args.py
import torch
import torch.nn
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from typing import List, Dict
class MWE_Model(pl.LightningModule):
    """
    Minimal heterogeneous-sequence model.

    Per-(sensor, channels) 1x1 conv stems turn each input observation into
    tokens, a small ``torch.nn.Transformer`` decodes one token per requested
    output location, and one 1x1 conv head per task produces logits.

    Args:
        sorting (bool): if True, build stems and heads by iterating the
            sensor/channel combinations and tasks in sorted order.
        dataset_stats (dict): required. Must provide ``known_modalities``
            (dicts with ``sensor``, ``channels``, ``num_bands``) and
            ``known_tasks`` (dicts with ``name`` and ``classes``).
        d_model (int): token / feature dimension.

    Raises:
        ValueError: if ``dataset_stats`` is None.

    Example:
        >>> dataset = MWE_Dataset()
        >>> self = MWE_Model(dataset_stats=dataset.dataset_stats)
        >>> batch = [dataset[i] for i in range(2)]
        >>> self.forward(batch)
    """
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        # NOTE: under LightningCLI with Lightning 2.3+, ``dataset_stats`` is
        # linked from the datamodule with apply_on="instantiate". It goes
        # missing from the hyperparameters this call records, which is the
        # bug this MWE reproduces.
        self.save_hyperparameters()
        if dataset_stats is None:
            raise ValueError('must be given dataset stats')
        self.d_model = d_model
        self.dataset_stats = dataset_stats
        # Unique (sensor, channels, num_bands) combinations. This is a set,
        # so its iteration order is unspecified unless ``sorting`` is enabled.
        self.known_sensorchan = {
            (mode['sensor'], mode['channels'], mode['num_bands'])
            for mode in self.dataset_stats['known_modalities']
        }
        self.known_tasks = self.dataset_stats['known_tasks']
        if sorting:
            self.known_sensorchan = sorted(self.known_sensorchan)
            self.known_tasks = sorted(self.known_tasks, key=lambda t: t['name'])
        # Construct stems based on the dataset: stems[sensor][channels] is a
        # 1x1 conv mapping that modality's bands to d_model features.
        self.stems = torch.nn.ModuleDict()
        for sensor, channels, num_bands in self.known_sensorchan:
            if sensor not in self.stems:
                self.stems[sensor] = torch.nn.ModuleDict()
            self.stems[sensor][channels] = torch.nn.Conv2d(num_bands, self.d_model, kernel_size=1)
        # Backbone is small generic transformer
        self.backbone = torch.nn.Transformer(
            d_model=self.d_model,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            batch_first=True
        )
        # Construct heads based on the dataset: one 1x1 conv per task that
        # maps d_model features to that task's number of classes.
        self.heads = torch.nn.ModuleDict()
        for head_info in self.known_tasks:
            head_name = head_info['name']
            head_classes = head_info['classes']
            num_classes = len(head_classes)
            self.heads[head_name] = torch.nn.Conv2d(
                self.d_model, num_classes, kernel_size=1)

    @property
    def main_device(self):
        """ Helper to get a device for the model.

        Returns the device of the first ``state_dict`` entry, or ``None``
        (implicitly) if the model has no state.
        """
        for key, item in self.state_dict().items():
            return item.device

    def tokenize_inputs(self, item: Dict):
        """
        Process a single batch item's heterogeneous sequence into flat lists
        of tokens for the encoder and decoder.

        Args:
            item (Dict): one (uncollated) dataset item with ``inputs`` and
                ``outputs`` lists.

        Returns:
            Tuple: ``(in_tokens, out_tokens)``, each of shape
            ``(num_tokens, d_model)``, or ``(None, None)`` if the item has no
            inputs or no outputs.
        """
        device = self.device
        input_sequence = []
        for input_item in item['inputs']:
            stem = self.stems[input_item['sensor_code']][input_item['channel_code']]
            out = stem(input_item['data'])
            # Flatten the spatial grid: one d_model token per pixel.
            # Assumes ``data`` is unbatched (num_bands, H, W).
            tokens = out.view(self.d_model, -1).T
            input_sequence.append(tokens)
        output_sequence = []
        for output_item in item['outputs']:
            # Decoder queries are random placeholders, one per output pixel.
            shape = tuple(output_item['dims']) + (self.d_model,)
            tokens = torch.rand(shape, device=device).view(-1, self.d_model)
            output_sequence.append(tokens)
        if len(input_sequence) == 0 or len(output_sequence) == 0:
            return None, None
        in_tokens = torch.concat(input_sequence, dim=0)
        out_tokens = torch.concat(output_sequence, dim=0)
        return in_tokens, out_tokens

    def forward(self, batch: List[Dict]) -> List[Dict]:
        """
        Runs prediction on multiple batch items. The input is assumed to be an
        uncollated list of dictionaries, each containing information about some
        heterogeneous sequence. The output is a corresponding list of
        dictionaries containing the logits for each head. Items that could
        not be tokenized map to an empty dict.
        """
        batch_in_tokens = []
        batch_out_tokens = []
        given_batch_size = len(batch)
        valid_batch_indexes = []
        # Prepopulate an output for each input
        batch_logits = [{} for _ in range(given_batch_size)]
        # Handle heterogeneous style inputs on a per-item level
        for batch_idx, item in enumerate(batch):
            in_tokens, out_tokens = self.tokenize_inputs(item)
            if in_tokens is not None:
                valid_batch_indexes.append(batch_idx)
                batch_in_tokens.append(in_tokens)
                batch_out_tokens.append(out_tokens)
        # Some batch items might not be valid
        valid_batch_size = len(valid_batch_indexes)
        if not valid_batch_size:
            # No inputs were valid
            return batch_logits
        # Pad everything into a batch to be more efficient
        # Sentinel for padded positions. Padding is detected by comparing the
        # first feature against it.
        # NOTE(review): assumes real token values never fall to <= -9999.
        padding_value = -9999.0
        input_seqs = nn.utils.rnn.pad_sequence(
            batch_in_tokens,
            batch_first=True,
            padding_value=padding_value,
        )
        output_seqs = nn.utils.rnn.pad_sequence(
            batch_out_tokens,
            batch_first=True,
            padding_value=padding_value,
        )
        # True where a token is real (not padding).
        input_masks = input_seqs[..., 0] > padding_value
        output_masks = output_seqs[..., 0] > padding_value
        # Zero the padded positions in place. They are also excluded through
        # the key padding masks passed to the backbone below.
        input_seqs[~input_masks] = 0.
        output_seqs[~output_masks] = 0.
        decoded = self.backbone(
            src=input_seqs,
            tgt=output_seqs,
            src_key_padding_mask=~input_masks,
            tgt_key_padding_mask=~output_masks,
        )
        B = valid_batch_size
        # Note output h/w is hardcoded here and uses the fact that the mwe only
        # has one task; could be generalized.
        oh, ow = 3, 3
        # Assumes every output contributes exactly oh * ow tokens, so the
        # padded output length is a multiple of oh * ow.
        decoded_features = decoded.view(B, -1, oh, ow, self.d_model)
        decoded_masks = output_masks.view(B, -1, oh, ow)
        # Reconstruct outputs corresponding to the inputs
        for batch_idx, feat, mask in zip(valid_batch_indexes, decoded_features, decoded_masks):
            # Keep only non-padded positions, then go channels-first for the conv heads.
            item_feat = feat[mask].view(-1, oh, ow, self.d_model).permute(0, 3, 1, 2)
            item_logits = batch_logits[batch_idx]
            for head_name, head_layer in self.heads.items():
                head_logits = head_layer(item_feat)
                item_logits[head_name] = head_logits
        return batch_logits

    def forward_step(self, batch: List[Dict], with_loss=False, stage='unspecified'):
        """
        Generic forward step used for test / train / validation

        Returns:
            Dict: always has ``logits``. When ``with_loss`` is set it also has
            ``loss``, which is None if no item produced logits.
        """
        batch_logits : List[Dict] = self.forward(batch)
        outputs = {}
        outputs['logits'] = batch_logits
        if with_loss:
            losses = []
            # Number of items that produced logits; used as the logging batch size.
            valid_batch_size = 0
            for item, item_logits in zip(batch, batch_logits):
                if len(item_logits):
                    valid_batch_size += 1
                for head_name, head_logits in item_logits.items():
                    # Assumes one label per decoded output frame for this
                    # head, so the stacked target matches head_logits.
                    head_target = torch.stack([label['data'] for label in item['labels'] if label['head'] == head_name], dim=0)
                    # dummy loss function
                    head_loss = torch.nn.functional.mse_loss(head_logits, head_target)
                    losses.append(head_loss)
            total_loss = sum(losses) if len(losses) > 0 else None
            if total_loss is not None:
                self.log(f'{stage}_loss', total_loss, prog_bar=True, batch_size=valid_batch_size, sync_dist=True)
            outputs['loss'] = total_loss
        return outputs

    def training_step(self, batch, batch_idx=None):
        """Compute the training loss.

        Returns None when no loss could be computed (presumably so Lightning
        skips the batch; confirm).
        """
        outputs = self.forward_step(batch, with_loss=True, stage='train')
        if outputs['loss'] is None:
            return None
        return outputs

    def validation_step(self, batch, batch_idx=None):
        """Compute and log the validation loss."""
        outputs = self.forward_step(batch, with_loss=True, stage='val')
        return outputs

    def test_step(self, batch, batch_idx=None):
        """Compute and log the test loss."""
        outputs = self.forward_step(batch, with_loss=True, stage='test')
        return outputs
class MWE_Dataset(Dataset):
    """
    A dataset that produces heterogeneous outputs.

    Each item is a dict of plain Python lists (``inputs``, ``outputs``,
    ``labels``). Batches are therefore kept as uncollated lists of items;
    see :meth:`make_loader`.

    Args:
        max_items_per_epoch (int): value reported by ``len(self)``.

    Example:
        >>> self = MWE_Dataset()
        >>> self[0]
    """
    def __init__(self, max_items_per_epoch=100):
        super().__init__()
        self.max_items_per_epoch = max_items_per_epoch
        # The global numpy RNG module. Sampling is only reproducible if
        # numpy is seeded globally by the caller.
        self.rng = np.random
        self.dataset_stats = {
            'known_modalities': [
                {'sensor': 'sensor1', 'channels': 'rgb', 'num_bands': 3, 'dims': (23, 23)},
            ],
            'known_tasks': [
                {'name': 'class', 'classes': ['a', 'b', 'c', 'd', 'e'], 'dims': (3, 3)},
            ]
        }

    def __len__(self):
        return self.max_items_per_epoch

    def __getitem__(self, index) -> Dict:
        """
        Build one random item. ``index`` is not used; every item is freshly
        sampled.

        Returns:
            Dict: containing
                * inputs - a list of observations
                * outputs - a list of what we want to predict
                * labels - ground truth if we have it
        """
        inputs = []
        outputs = []
        labels = []
        max_timesteps_per_item = 5
        num_frames = max_timesteps_per_item
        p_drop_input = 0
        for frame_index in range(num_frames):
            had_input = 0
            # In general we may have any number of observations per frame
            for modality in self.dataset_stats['known_modalities']:
                sensor = modality['sensor']
                channels = modality['channels']
                c = modality['num_bands']
                h, w = modality['dims']
                # Randomly include each sensorchan on each frame
                # (with p_drop_input == 0 every observation is kept).
                if self.rng.rand() >= p_drop_input:
                    had_input = 1
                    inputs.append({
                        'type': 'input',
                        'channel_code': channels,
                        'sensor_code': sensor,
                        'frame_index': frame_index,
                        'data': torch.rand(c, h, w),
                    })
            # Only request predictions for frames that have an observation.
            if had_input:
                for task_info in self.dataset_stats['known_tasks']:
                    task = task_info['name']
                    oh, ow = task_info['dims']
                    oc = len(task_info['classes'])
                    outputs.append({
                        'type': 'output',
                        'head': task,
                        'frame_index': frame_index,
                        'dims': (oh, ow),
                    })
                    labels.append({
                        'type': 'label',
                        'head': task,
                        'frame_index': frame_index,
                        'data': torch.rand(oc, oh, ow),
                    })
        item = {
            'inputs': inputs,
            'outputs': outputs,
            'labels': labels,
        }
        return item

    @staticmethod
    def _collate_identity(batch):
        """Return the sampled items unchanged (no collation)."""
        return batch

    def make_loader(self, batch_size=1, num_workers=0, shuffle=False,
                    pin_memory=False):
        """
        Create a dataloader option with sensible defaults for the problem.

        Batches are plain lists of dataset items (no collation).

        Returns:
            torch.utils.data.DataLoader
        """
        loader = torch.utils.data.DataLoader(
            self, batch_size=batch_size, num_workers=num_workers,
            shuffle=shuffle, pin_memory=pin_memory,
            # Use a named function rather than ``lambda x: x``. With
            # num_workers > 0 and the "spawn" start method, collate_fn must
            # be picklable.
            collate_fn=self._collate_identity,
        )
        return loader
class MWE_Datamodule(pl.LightningDataModule):
    """
    Datamodule serving three independent ``MWE_Dataset`` splits ('train',
    'test', 'vali').

    ``dataset_stats`` stays None until :meth:`setup` has run.

    Args:
        batch_size (int): items per (uncollated) batch.
        num_workers (int): DataLoader worker processes.
        max_items_per_epoch (int): forwarded to every ``MWE_Dataset``.
    """
    def __init__(self, batch_size=1, num_workers=0, max_items_per_epoch=100):
        super().__init__()
        # Records the constructor arguments; _make_dataloader reads them
        # back from self.hparams.
        self.save_hyperparameters()
        self.torch_datasets = {}
        self.dataset_stats = None
        self.dataset_kwargs = dict(max_items_per_epoch=max_items_per_epoch)
        self._did_setup = False

    def setup(self, stage):
        """Create the splits once; later calls do nothing. ``stage`` is ignored."""
        if not self._did_setup:
            for split_name in ('train', 'test', 'vali'):
                self.torch_datasets[split_name] = MWE_Dataset(**self.dataset_kwargs)
            self.dataset_stats = self.torch_datasets['train'].dataset_stats
            self._did_setup = True
            print('Setup MWE_Datamodule')
            print(self.__dict__)

    def train_dataloader(self):
        return self._make_dataloader(stage='train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(stage='vali', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(stage='test', shuffle=False)

    @property
    def train_dataset(self):
        return self.torch_datasets.get('train')

    @property
    def test_dataset(self):
        return self.torch_datasets.get('test')

    @property
    def vali_dataset(self):
        return self.torch_datasets.get('vali')

    def _make_dataloader(self, stage, shuffle=False):
        """Build a loader for split ``stage`` using the saved hyperparameters."""
        hparams = self.hparams
        dataset = self.torch_datasets[stage]
        return dataset.make_loader(
            batch_size=hparams.batch_size,
            num_workers=hparams.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )
class MWE_LightningCLI(LightningCLI):
    """
    Customized LightningCLI to ensure the expected model inputs / outputs are
    coupled with what the dataset is able to provide.

    ``model.dataset_stats`` is not set by the user. It is linked from the
    instantiated datamodule (``apply_on="instantiate"``).

    NOTE: with Lightning 2.3+, an argument linked on instantiate is missing
    from the model's saved hyperparameters
    (Lightning-AI/pytorch-lightning#20311).
    """
    def add_arguments_to_parser(self, parser):
        def make_datamodule_getter(attr_name):
            """Build a link ``compute_fn`` that reads ``attr_name`` from the datamodule."""
            def _read_attr(datamodule):
                # Hack: force setup so the attribute is populated before the
                # value is linked into the model.
                if not datamodule._did_setup:
                    datamodule.setup('fit')
                return getattr(datamodule, attr_name)
            return _read_attr

        # pass dataset stats to model after datamodule initialization
        stats_from_data = make_datamodule_getter('dataset_stats')
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=stats_from_data,
            apply_on="instantiate",
        )
        super().add_arguments_to_parser(parser)
def main():
    """Build (and thereby run) the CLI with the MWE model and datamodule."""
    cli_kwargs = dict(
        model_class=MWE_Model,
        datamodule_class=MWE_Datamodule,
    )
    MWE_LightningCLI(**cli_kwargs)
if __name__ == '__main__':
    # CommandLine:
    #     cd ~/code/geowatch/dev/mwe/
    main()
Apologies for the length of the MWE; it could probably be a few hundred lines shorter, but I had it on hand and it demonstrates the issue well enough. The `link_arguments` call and the model `__init__` are the important parts:
class MWE_LightningCLI(LightningCLI):
    """
    LightningCLI that links ``model.dataset_stats`` from the instantiated
    datamodule (``apply_on="instantiate"``).

    NOTE: with Lightning 2.3+, an argument linked on instantiate is missing
    from the model's saved hyperparameters
    (Lightning-AI/pytorch-lightning#20311).
    """
    def add_arguments_to_parser(self, parser):
        def make_datamodule_getter(attr_name):
            """Build a link ``compute_fn`` that reads ``attr_name`` from the datamodule."""
            def _read_attr(datamodule):
                # Hack: force setup so the attribute is populated before the
                # value is linked into the model.
                if not datamodule._did_setup:
                    datamodule.setup('fit')
                return getattr(datamodule, attr_name)
            return _read_attr

        # pass dataset stats to model after datamodule initialization
        stats_from_data = make_datamodule_getter('dataset_stats')
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            compute_fn=stats_from_data,
            apply_on="instantiate",
        )
        super().add_arguments_to_parser(parser)
class MWE_Model(pl.LightningModule):
    """Excerpt of the model above: only the hyperparameter saving matters here."""
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        # With Lightning 2.3+ under LightningCLI, ``dataset_stats`` is linked
        # from the datamodule on instantiate and is missing from what this
        # call records.
        self.save_hyperparameters()
        ...
Given the above script saved as lightning_cli_save_hyperaparams_error_on_link_args.py, I invoke it as:
# Train the MWE through LightningCLI. The YAML config is passed inline:
# its indentation is significant, and $DEFAULT_ROOT_DIR is expanded by the shell.
DEFAULT_ROOT_DIR=./mwe_train_dir
python lightning_cli_save_hyperaparams_error_on_link_args.py fit --config "
model:
  sorting: True
data:
  num_workers: 8
  batch_size: 2
  max_items_per_epoch: 200
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 1e-7
trainer:
  default_root_dir : $DEFAULT_ROOT_DIR
  accelerator : gpu
  devices : 1
  max_epochs: 100
"
# Pick the most recently written checkpoint and hparams.yaml. Sorting paths
# by name is lexicographic, so e.g. version_10 would sort before version_9;
# choose by modification time instead.
CKPT_FPATH=$(python -c "import pathlib; print(max(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/checkpoints/*.ckpt'), key=lambda p: p.stat().st_mtime))")
HPARAM_FPATH=$(python -c "import pathlib; print(max(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'), key=lambda p: p.stat().st_mtime))")
# Should include the linked model argument dataset_stats.
cat "$HPARAM_FPATH"
Error messages and logs
When using pytorch_lightning 2.2.5, running:
# Show the most recently written hparams.yaml (chosen by modification time,
# since name order is lexicographic and would put version_10 before version_9).
HPARAM_FPATH=$(python -c "import pathlib; print(max(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'), key=lambda p: p.stat().st_mtime))")
cat "$HPARAM_FPATH"
This correctly prints hyperparameters that include the linked `dataset_stats` argument:
sorting: true
dataset_stats:
  known_modalities:
  - sensor: sensor1
    channels: rgb
    num_bands: 3
    dims:
    - 23
    - 23
  known_tasks:
  - name: class
    classes:
    - a
    - b
    - c
    - d
    - e
    dims:
    - 3
    - 3
d_model: 16
batch_size: 2
num_workers: 8
max_items_per_epoch: 200
But on the latest master branch and 2.4.0 it incorrectly prints:
sorting: true
d_model: 16
_instantiator: pytorch_lightning.cli.instantiate_module
batch_size: 2
num_workers: 8
max_items_per_epoch: 200
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 3090
- NVIDIA GeForce RTX 3090
- available: True
- version: 12.4
- GPU:
- Lightning:
- lightning: 2.4.0
- lightning-utilities: 0.11.2
- perceiver-pytorch: 0.8.3
- performer-pytorch: 1.0.11
- pytorch-lightning: 2.4.0
- pytorch-msssim: 0.1.5
- pytorch-ranger: 0.1.1
- reformer-pytorch: 1.4.3
- torch: 2.4.0+cu124
- torch-liberator: 0.2.2
- torch-optimizer: 0.1.0
- torchaudio: 2.4.0+cu124
- torchmetrics: 0.11.0
- torchvision: 0.19.0
- Packages:
- absl-py: 1.4.0
- accelerate: 0.30.1
- addict: 2.4.0
- affine: 2.3.0
- aiobotocore: 2.5.4
- aiohttp: 3.9.5
- aiohttp-retry: 2.8.3
- aioitertools: 0.11.0
- aiosignal: 1.3.1
- alabaster: 0.7.16
- albumentations: 1.0.0
- amqp: 5.2.0
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.6.0
- anytree: 2.12.1
- appdirs: 1.4.4
- argcomplete: 3.5.0
- argo-workflows: 6.5.6
- arrow: 1.3.0
- asciitree: 0.3.3
- astor: 0.8.1
- astroid: 3.2.2
- asttokens: 2.4.1
- astunparse: 1.6.3
- asyncssh: 2.14.2
- atomicwrites: 1.4.0
- atpublic: 4.1.0
- attrs: 23.2.0
- auditwheel: 6.1.0
- autobahn: 24.4.2
- autodocsumm: 0.2.13
- automat: 22.10.0
- autopep8: 2.0.0
- axial-positional-embedding: 0.2.1
- babel: 2.15.0
- backports.tarfile: 1.2.0
- baron: 0.10.1
- bashlex: 0.18
- bcrypt: 4.1.3
- beautifulsoup4: 4.12.3
- bidict: 0.23.1
- billiard: 4.2.0
- black: 24.4.2
- blake3: 0.3.1
- bleach: 6.1.0
- blinker: 1.8.2
- boto: 2.49.0
- boto3: 1.28.17
- botocore: 1.31.17
- bpytop: 1.0.68
- bracex: 2.4
- brotli: 1.1.0
- build: 1.2.2
- cachecontrol: 0.14.0
- cachetools: 5.4.0
- celery: 5.4.0
- certifi: 2024.2.2
- cffi: 1.16.0
- cfgv: 3.4.0
- chardet: 5.2.0
- charset-normalizer: 2.0.12
- chromecontroller: 0.3.26
- cibuildwheel: 2.21.0
- cleo: 2.1.0
- click: 8.1.7
- click-didyoumean: 0.3.1
- click-plugins: 1.1.1
- click-repl: 0.3.0
- cligj: 0.7.2
- cloudpickle: 3.0.0
- cmake: 3.29.3
- cmd-queue: 0.1.21
- codecarbon: 2.2.4
- colorama: 0.4.6
- colormath: 3.0.0
- colt5-attention: 0.10.20
- comm: 0.2.2
- commonmark: 0.9.1
- configargparse: 1.7
- configobj: 5.0.8
- constantly: 23.10.4
- contourpy: 1.2.1
- coverage: 7.4.3
- crashtest: 0.4.1
- cryptography: 42.0.7
- cssutils: 2.10.2
- cycler: 0.12.1
- cython: 0.29.34
- dask: 2023.8.1
- dataframe-image: 0.1.13
- dataproperty: 1.0.1
- dbus-python: 1.3.2
- debugpy: 1.8.2
- decorator: 5.1.1
- defusedxml: 0.7.1
- delayed-image: 0.3.2
- delorean: 1.0.0
- detectron2: 0.6
- diceware: 0.10
- dictdiffer: 0.9.0
- diskcache: 5.6.3
- distinctipy: 1.2.1
- distlib: 0.3.8
- distro: 1.9.0
- docopt: 0.6.2
- docstring-parser: 0.16
- docutils: 0.20.1
- dominate: 2.9.1
- dpath: 2.1.6
- dtool-ibeis: 1.1.2
- dulwich: 0.22.1
- dvc: 3.51.2
- dvc-data: 3.15.1
- dvc-http: 2.32.0
- dvc-objects: 5.1.0
- dvc-render: 1.0.2
- dvc-s3: 3.2.0
- dvc-ssh: 4.1.1
- dvc-studio-client: 0.20.0
- dvc-task: 0.4.0
- einops: 0.6.0
- entrypoints: 0.4
- et-xmlfile: 1.1.0
- executing: 2.0.1
- faiss-cpu: 1.8.0
- fasteners: 0.17.3
- fastjsonschema: 2.19.1
- filelock: 3.15.4
- filterpy: 1.4.5
- fiona: 1.8.22
- fire: 0.4.0
- flake8: 7.0.0
- flask: 3.0.3
- flask-basicauth: 0.2.0
- flask-cors: 3.0.10
- flask-socketio: 5.3.6
- flatten-dict: 0.4.2
- flexcache: 0.3
- flexparser: 0.3.1
- flufl.lock: 7.1.1
- fonttools: 4.51.0
- frozenlist: 1.4.1
- fsspec: 2024.6.0
- funcy: 2.0
- futures-actors: 0.0.5
- fuzzywuzzy: 0.18.0
- fvcore: 0.1.5.post20221221
- gdal: 3.5.2
- geodatasets: 2023.12.0
- geographiclib: 2.0
- geojson: 3.0.1
- geomet: 1.1.0
- geopandas: 0.14.4
- geopy: 2.4.1
- geowatch: 0.18.4
- gevent: 24.2.1
- girder-client: 3.2.4.dev30+gcacd0e706
- git-of-theseus: 0.3.4
- git-python: 1.0.3
- git-well: 0.2.1
- gitdb: 4.0.11
- gitpython: 3.1.43
- google-api-core: 2.19.0
- google-api-python-client: 2.130.0
- google-auth: 2.29.0
- google-auth-httplib2: 0.2.0
- google-auth-oauthlib: 1.0.0
- googleapis-common-protos: 1.63.0
- grandalf: 0.8
- graphid: 0.1.0
- greenlet: 3.0.3
- grpcio: 1.63.0
- gto: 1.7.1
- guitool-ibeis: 2.2.0
- h11: 0.14.0
- h3: 3.7.7
- hardware: 0.31.0
- hkdf: 0.0.3
- html2image: 2.0.4.3
- httpcore: 0.16.3
- httplib2: 0.22.0
- httpx: 0.23.3
- huggingface-hub: 0.23.0
- humanize: 4.8.0
- hydra-core: 1.3.2
- hyperlink: 21.0.0
- ibeis: 2.3.2
- identify: 2.6.0
- idna: 3.7
- ijson: 3.2.1
- imageio: 2.34.1
- imagesize: 1.4.1
- importlib-metadata: 7.2.1
- importlib-resources: 6.4.0
- incremental: 24.7.2
- iniconfig: 2.0.0
- installer: 0.7.0
- instant-rst: 0.9.9.1
- iopath: 0.1.9
- ipykernel: 6.29.5
- ipython: 8.18.1
- ipython-genutils: 0.2.0
- isort: 5.13.2
- iterable-io: 1.0.0
- iterative-telemetry: 0.0.8
- itk: 5.4.0
- itk-core: 5.4.0
- itk-filtering: 5.4.0
- itk-io: 5.4.0
- itk-numerics: 5.4.0
- itk-registration: 5.4.0
- itk-segmentation: 5.4.0
- itsdangerous: 2.2.0
- jaraco.classes: 3.4.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.2
- jedi: 0.19.1
- jeepney: 0.8.0
- jellyfin-apiclient-python: 1.9.2
- jellyfin-migrator: 0.0.0
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- johnnydep: 1.20.4
- jq: 1.7.0
- jsonargparse: 4.32.1
- jsonnet: 0.20.0
- jsonpath: 0.82.2
- jsonschema: 4.19.2
- jsonschema-specifications: 2023.12.1
- jupyter-client: 8.6.1
- jupyter-core: 5.7.2
- jupyterlab-pygments: 0.3.0
- kafka-python: 2.0.2
- keyring: 24.3.1
- kiwisolver: 1.4.5
- kombu: 5.3.7
- kornia: 0.6.8
- kornia-rs: 0.1.3
- kubernetes: 29.0.0
- kwalop: 0.1.0
- kwarray: 0.6.19
- kwcoco: 0.8.5
- kwcoco-explorer: 0.0.1
- kwgis: 0.1.1
- kwimage: 0.10.1
- kwimage-ext: 0.2.1
- kwplot: 0.5.2
- kwutil: 0.3.3
- lark: 1.1.7
- lark-cython: 0.0.15
- lazy-loader: 0.3
- levenshtein: 0.25.1
- liberator: 0.1.0
- lightning: 2.4.0
- lightning-utilities: 0.11.2
- line-profiler: 4.1.3
- linkify-it-py: 2.0.3
- lit: 18.1.4
- livereload: 2.7.0
- llvmlite: 0.42.0
- local-attention: 1.9.1
- locket: 1.0.0
- lockfile: 0.12.2
- logmatic-python: 0.1.7
- lxml: 4.9.2
- magic-wormhole: 0.14.0
- markdown: 3.6
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- mathutf: 0.1.0
- matplotlib: 3.8.2
- matplotlib-inline: 0.1.7
- maturin: 1.7.4
- mbstrdecoder: 1.1.3
- mccabe: 0.7.0
- mdit-py-plugins: 0.4.1
- mdurl: 0.1.2
- mgrs: 1.4.6
- mistune: 3.0.2
- mkinit: 1.1.0
- mmcv: 2.0.0
- mmengine: 0.10.4
- monai: 0.8.0
- more-itertools: 8.12.0
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- munch: 4.0.0
- mutagen: 1.47.0
- mypy: 1.10.0
- mypy-extensions: 1.0.0
- myst-parser: 3.0.1
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- ndsampler: 0.7.9
- nest-asyncio: 1.6.0
- netharn: 0.6.2
- networkx: 3.3
- networkx-algo-common-subtree: 0.2.1
- nh3: 0.2.18
- nodeenv: 1.9.1
- nrtk: 0.11.0
- nrtk-explorer: 0.3.0
- numba: 0.59.1
- numcodecs: 0.13.0
- numexpr: 2.8.4
- numpy: 1.25.2
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cublas-cu12: 12.4.2.65
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-cupti-cu12: 12.4.99
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-nvrtc-cu12: 12.4.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cuda-runtime-cu12: 12.4.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-cufft-cu12: 11.2.0.44
- nvidia-curand-cu11: 10.2.10.91
- nvidia-curand-cu12: 10.3.5.119
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusolver-cu12: 11.6.0.99
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-cusparse-cu12: 12.3.0.142
- nvidia-nccl-cu11: 2.14.3
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.4.99
- nvidia-nvtx-cu11: 11.7.91
- nvidia-nvtx-cu12: 12.4.99
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- openapi-python-client: 0.20.0
- openapi-python-generator: 0.5.0
- openapi-schema-pydantic: 1.2.4
- opencv-python-headless: 4.10.0.84
- openpyxl: 3.0.9
- opentimestamps: 0.4.5
- opentimestamps-client: 0.7.1
- ordered-set: 4.1.0
- orjson: 3.10.3
- osmnx: 1.9.4
- oyaml: 1.0
- packaging: 24.1
- pandas: 1.5.3
- pandocfilters: 1.5.1
- parse: 1.19.0
- parso: 0.8.4
- partd: 1.4.2
- pathspec: 0.12.1
- pathvalidate: 3.2.1
- patsy: 0.5.6
- pbr: 6.0.0
- pendulum: 3.0.0
- perceiver-pytorch: 0.8.3
- performer-pytorch: 1.0.11
- pexpect: 4.9.0
- pillow: 10.3.0
- pint: 0.24.3
- pip: 24.2
- pkginfo: 1.10.0
- platformdirs: 3.11.0
- plotly: 5.24.0
- plottool-ibeis: 2.3.0
- pls-dont-shadow-me: 1.0.0
- pluggy: 1.5.0
- pockets: 0.9.1
- poetry: 1.8.3
- poetry-core: 1.9.0
- poetry-plugin-export: 1.8.0
- pooch: 1.8.2
- portalocker: 2.10.1
- portion: 2.4.1
- pre-commit: 3.8.0
- prettytable: 3.11.0
- product-key-memory: 0.2.2
- progiter: 2.0.0
- prometheus-client: 0.20.0
- prompt-toolkit: 3.0.43
- proto-plus: 1.23.0
- protobuf: 4.25.3
- psutil: 5.9.6
- psycopg2-binary: 2.9.5
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- purepy-root-demo-pkg-lzcsutvo: 1.0.0
- purepy-src-demo-pkg: 1.0.0
- purepy-src-demo-pkg-dbrmcjpb: 1.0.0
- purepy-src-demo-pkg-lzcsutvo: 1.0.0
- py-cpuinfo: 9.0.0
- pyasn1: 0.6.0
- pyasn1-modules: 0.4.0
- pybsm: 0.5.1
- pycocotools: 2.0.7
- pycodestyle: 2.11.1
- pycparser: 2.22
- pycryptodomex: 3.20.0
- pydantic: 2.7.1
- pydantic-core: 2.18.2
- pydot: 2.0.0
- pyelftools: 0.31
- pyfiglet: 1.0.2
- pyflakes: 3.2.0
- pyflann-ibeis: 2.4.0
- pygame: 2.6.0
- pygit2: 1.15.0
- pygments: 2.18.0
- pygraphviz: 1.13
- pygtrie: 2.5.0
- pyhesaff: 2.1.1
- pylatex: 0+untagged.769.gb48e8ec
- pylatexenc: 3.0a29
- pymongo: 3.13.0
- pynacl: 1.5.0
- pynmea2: 1.19.0
- pynndescent: 0.5.12
- pynvim: 0.5.0
- pynvml: 11.5.0
- pyo3-example: 0.1.0
- pyopenssl: 24.1.0
- pyparsing: 3.1.2
- pyperclip: 1.8.2
- pypistats: 1.6.0
- pypng: 0.20220715.0
- pypogo: 0.1.0
- pyproj: 3.4.1
- pyproject-api: 1.7.1
- pyproject-hooks: 1.1.0
- pyqrcode: 1.2.1
- pyqt5: 5.15.10
- pyqt5-qt5: 5.15.2
- pyqt5-sip: 12.13.0
- pyqtree: 1.0.0
- pysocks: 1.7.1
- pystac: 1.10.1
- pystac-client: 0.8.1
- pytablewriter: 1.2.0
- pytest: 8.0.2
- pytest-cov: 5.0.0
- pytest-subtests: 0.13.1
- python-bitcoinlib: 0.12.2
- python-dateutil: 2.9.0.post1.dev3+g9eaa5de
- python-engineio: 4.9.1
- python-gitlab: 4.6.0
- python-json-logger: 2.0.7
- python-levenshtein: 0.25.1
- python-slugify: 8.0.4
- python-socketio: 5.11.3
- pytimeparse: 1.1.8
- pytorch-lightning: 2.4.0
- pytorch-msssim: 0.1.5
- pytorch-ranger: 0.1.1
- pytz: 2024.1
- pywavelets: 1.6.0
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- quantities: 0.15.0
- rapidfuzz: 3.9.1
- rasterio: 1.3.10
- readme-renderer: 44.0
- reconplogger: 4.16.1
- redbaron: 0.9.2
- referencing: 0.35.1
- reformer-pytorch: 1.4.3
- regex: 2024.5.10
- requests: 2.32.2
- requests-oauthlib: 2.0.0
- requests-toolbelt: 1.0.0
- responses: 0.25.3
- rfc3986: 1.5.0
- rgd-client: 0.2.7
- rgd-imagery-client: 0.2.7
- rich: 12.5.1
- rich-argparse: 1.1.0
- rpds-py: 0.18.1
- rply: 0.7.8
- rsa: 4.9
- rtree: 1.0.1
- ruamel.yaml: 0.17.32
- ruamel.yaml.clib: 0.2.8
- ruff: 0.4.5
- ruyaml: 0.91.0
- s3fs: 2024.6.0
- s3transfer: 0.6.2
- s5cmd: 0.2.0
- safer: 4.12.3
- safetensors: 0.4.3
- scikit-build: 0.17.6
- scikit-image: 0.21.0
- scikit-learn: 1.5.1
- scipy: 1.14.0
- scmrepo: 3.3.5
- scriptconfig: 0.7.16
- seaborn: 0.13.2
- secretstorage: 3.3.3
- semver: 3.0.2
- service-identity: 24.1.0
- setuptools: 67.7.2
- shapely: 2.0.1
- shellingham: 1.5.4
- shitspotter: 0.0.1
- shortuuid: 1.0.13
- shtab: 1.7.1
- simple-dvc: 0.2.2
- simple-websocket: 1.0.0
- simpleitk: 2.3.1
- simplejson: 3.19.2
- simplekml: 1.3.3
- six: 1.16.0
- smartflow: 3.1.3
- smmap: 5.0.1
- smqtk-classifier: 0.19.0
- smqtk-core: 0.19.0
- smqtk-dataprovider: 0.18.0
- smqtk-descriptors: 0.19.0
- smqtk-detection: 0.20.1
- smqtk-image-io: 0.17.1
- smqtk-indexing: 0.18.0
- smqtk-iqr: 0.15.1
- smqtk-relevancy: 0.17.0
- sniffio: 1.3.1
- snowballstemmer: 2.2.0
- snuggs: 1.4.7
- sortedcontainers: 2.4.0
- soupsieve: 2.5
- spake2: 0.8
- sphinx: 7.3.7
- sphinx-autoapi: 3.1.1
- sphinx-autobuild: 2024.4.16
- sphinx-autodoc-typehints: 2.3.0
- sphinx-reredirects: 0.1.3
- sphinx-rtd-theme: 2.0.0
- sphinxcontrib-applehelp: 1.0.8
- sphinxcontrib-devhelp: 1.0.6
- sphinxcontrib-htmlhelp: 2.0.5
- sphinxcontrib-jquery: 4.1
- sphinxcontrib-jsmath: 1.0.1
- sphinxcontrib-napoleon: 0.7
- sphinxcontrib-qthelp: 1.0.7
- sphinxcontrib-serializinghtml: 1.1.10
- sqlalchemy: 1.4.50
- sqlalchemy-utils: 0.41.2
- sqltrie: 0.11.0
- sshfs: 2024.4.1
- stack-data: 0.6.3
- starlette: 0.37.2
- statsmodels: 0.14.2
- structlog: 24.2.0
- sympy: 1.12
- tabledata: 1.3.3
- tabulate: 0.9.0
- tcolorpy: 0.1.6
- tempenv: 0.2.0
- tenacity: 9.0.0
- tensorboard: 2.14.0
- tensorboard-data-server: 0.7.2
- tensorrt-bindings: 8.6.1
- tensorrt-cu12: 10.0.1
- tensorrt-cu12-bindings: 10.0.1
- tensorrt-cu12-libs: 10.0.1
- tensorrt-libs: 8.6.1
- termcolor: 2.4.0
- text-unidecode: 1.3
- textual: 0.1.18
- threadpoolctl: 3.5.0
- tifffile: 2024.5.22
- timerit: 1.1.0
- timezonefinder: 6.5.2
- timm: 0.6.13
- tinycss2: 1.3.0
- tokenizers: 0.15.2
- toml: 0.10.2
- tomli: 2.0.1
- tomlkit: 0.12.5
- toolz: 0.12.1
- torch: 2.4.0+cu124
- torch-liberator: 0.2.2
- torch-optimizer: 0.1.0
- torchaudio: 2.4.0+cu124
- torchmetrics: 0.11.0
- torchvision: 0.19.0
- tornado: 6.4
- tox: 4.17.1
- tqdm: 4.64.1
- traitlets: 5.14.3
- trame: 3.6.1
- trame-client: 3.0.3
- trame-plotly: 3.0.2
- trame-quasar: 0.2.1
- trame-server: 3.0.1
- trame-vuetify: 2.7.0
- transformers: 4.37.2
- triton: 3.0.0
- trove-classifiers: 2024.9.12
- twine: 5.1.1
- twisted: 24.3.0
- txaio: 23.1.1
- txtorcon: 23.11.0
- typepy: 1.3.2
- typer: 0.12.3
- types-python-dateutil: 2.9.0.20240316
- types-pyyaml: 6.0.12.20240808
- types-requests: 2.32.0.20240907
- types-setuptools: 70.0.0.20240524
- typeshed-client: 2.5.1
- typing-extensions: 4.11.0
- tzdata: 2024.1
- tzlocal: 5.2
- ubelt: 1.3.6
- uc-micro-py: 1.0.3
- ujson: 5.6.0
- umap-learn: 0.5.6
- uncertainties: 3.2.2
- uritemplate: 4.1.1
- uritools: 4.0.2
- urllib3: 1.26.20
- utm: 0.7.0
- utool: 2.2.0
- uv: 0.3.4
- uvicorn: 0.29.0
- validators: 0.28.1
- vimtk: 0.5.0
- vine: 5.1.0
- virtualenv: 20.26.3
- voluptuous: 0.14.2
- vtool-ibeis: 2.3.0
- vtool-ibeis-ext: 0.1.1
- watchfiles: 0.21.0
- wcmatch: 8.5.2
- wcwidth: 0.2.13
- webencodings: 0.5.1
- websocket-client: 1.8.0
- websockets: 12.0
- werkzeug: 3.0.4
- wheel: 0.40.0
- wimpy: 0.6
- wrapt: 1.14.1
- wslink: 2.0.4
- wsproto: 1.2.0
- xarray: 0.17.0
- xcookie: 0.2.2
- xdev: 1.5.2
- xdoctest: 1.1.5
- xinspect: 0.2.0
- xmltodict: 0.12.0
- xxhash: 3.4.1
- yacs: 0.1.8
- yapf: 0.40.2
- yarl: 1.9.4
- yt-dlp: 2024.8.6
- zarr: 2.18.2
- zc.lockfile: 3.0.post1
- zipp: 3.18.1
- zipstream-ng: 1.7.1
- zope.event: 5.0
- zope.interface: 6.4.post2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.11.9
- release: 6.8.0-45-generic
- version: #45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024
More info
No response
Was the reproduction script supposed to be attached, but it isn't?
The MWE is in the details. You can click the arrow to expand it. For convenience here it is in a gist as well: https://gist.github.com/Erotemic/dfdbf192004486e9f108b0334dd7fdcd
I am also affected by this issue... :(
Now I understand what the problem is. Later I will think about a proper solution and create a pull request. For the time being, to disable the current behavior, you can implement the following in your `LightningCLI` subclass:
def _add_instantiators(self):
    """
    Workaround, to be defined on a ``LightningCLI`` subclass: register no
    instantiators.

    ``save_hyperparameters`` then saves whatever ``__init__`` received,
    including values linked with ``apply_on="instantiate"``.

    NOTE: this completely disables the feature from #18105
    (``load_from_checkpoint`` support for CLI-instantiated modules). It also
    means hyperparameters are whatever ``__init__`` received, which may not
    be serializable with pyyaml.
    """
    pass
This behavior might have been introduced with the given_hparams mechanism in the new implementation of save_hyperparameters():
https://github.com/Lightning-AI/pytorch-lightning/blob/1129d4cecf2bcbf78d9340655c3950e744a019e1/src/lightning/pytorch/core/mixins/hparams_mixin.py#L125-L131
This behavior might have been introduced with the given_hparams
@huangyxi yes, it is because of that, which is a change needed for the load_from_checkpoint support from #18105 as mentioned in the bug description. Note that the issue only affects links applied on instantiate.
When I have the time I will work on a fix for this. I mentioned a workaround in my comment above.
@mauvilsa, thanks for the quick workaround. Could it be, though, that it slows down logging substantially? I am using the standard CSVLogger. After adding the workaround, training easily hangs for 20 s in the steps where logs are written. Usually I wouldn't even notice the logging. I struggle to see a direct connection.
I don't think that would slow down logging. There is no relation to it.
Has this issue been fixed in the newest version? I am also affected by this issue... :(
No, this issue has not been fixed yet. You have the workaround in https://github.com/Lightning-AI/pytorch-lightning/issues/20311#issuecomment-2442602029
> Now I understand what the problem is. Later I will think about a proper solution and create a pull request. For the time being, to disable the current behavior, you can implement the following in your
> `LightningCLI` subclass: `def _add_instantiators(self): pass`
I had a similar problem as well, and your fix finally worked. Thank you, @mauvilsa! In the meantime, maybe you could add a warning or reference about it in the documentation? (I spent a lot of time reading the documentation, debugging the code line by line, and searching the forum etc. until I found this thread.)
The workaround mentioned in https://github.com/Lightning-AI/pytorch-lightning/issues/20311#issuecomment-2442602029 causes my training to hang indefinitely in yaml.dump called by _log_hyperparams.
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/cli.py", line 398, in __init__
self._run_subcommand(self.subcommand)
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/cli.py", line 708, in _run_subcommand
fn(**fn_kwargs)
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 995, in _run
_log_hyperparams(self)
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loggers/utilities.py", line 102, in _log_hyperparams
logger.save()
File "/usr/local/lib/python3.12/dist-packages/lightning_utilities/core/rank_zero.py", line 41, in wrapped_fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loggers/tensorboard.py", line 222, in save
save_hparams_to_yaml(hparams_file, self.hparams)
File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/core/saving.py", line 361, in save_hparams_to_yaml
yaml.dump(v)
File "/usr/local/lib/python3.12/dist-packages/yaml/__init__.py", line 253, in dump
return dump_all([data], stream, Dumper=Dumper, **kwds)
File "/usr/local/lib/python3.12/dist-packages/yaml/__init__.py", line 241, in dump_all
dumper.represent(data)
File "/usr/local/lib/python3.12/dist-packages/yaml/representer.py", line 28, in represent
self.serialize(node)
File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 54, in serialize
self.serialize_node(node, None, None)
File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 108, in serialize_node
self.serialize_node(value, node, key)
File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 98, in serialize_node
self.serialize_node(item, node, index)
File "/usr/local/lib/python3.12/dist-packages/yaml/serializer.py", line 100, in serialize_node
self.emit(SequenceEndEvent())
File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 115, in emit
self.state()
File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 374, in expect_first_block_sequence_item
return self.expect_block_sequence_item(first=True)
File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 384, in expect_block_sequence_item
self.expect_node(sequence=True)
File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 242, in expect_node
self.process_tag()
File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 473, in process_tag
self.style = self.choose_scalar_style()
File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 496, in choose_scalar_style
self.analysis = self.analyze_scalar(self.event.value)
File "/usr/local/lib/python3.12/dist-packages/yaml/emitter.py", line 668, in analyze_scalar
while index < len(scalar):
@E1k3 with the workaround, the feature from #18105 is completely disabled. This means that save_hyperparameters tries to save whatever __init__ received, which might not be serializable using pyyaml. I guess this could be the problem that you are having.
I created pull request #20777 with a potential fix for this. Would be nice if those of you affected review and test it out.
I also ran into this problem, and it forced me to downgrade pytorch-lightning. One quick way to work around it is to pass the config without linked arguments (deeper dicts, etc.).
detail comment here: https://github.com/Lightning-AI/pytorch-lightning/issues/17558#issuecomment-2842777482