
Make `save_hyperparameters` consistent for CLI and hardcoded training for custom python objects


Description & Motivation

Problem

Given the script below, when running it with the Lightning CLI, hparams.yaml becomes

optimizer:
  class_path: __main__.SGD
  init_args:
    params:
    - 1
    - 2
    - 3
    lr: 123.0
myds:
  x: hello
_instantiator: lightning.pytorch.cli.instantiate_module

When running the hardcoded training script instead, hparams.yaml becomes

myds: !!python/object:__main__.MyDataclass
  x: hello
optimizer: !!python/object:__main__.SGD {}

In other words, even though the hyperparameters are the same, the two hparams.yaml files look different. An alternative framing of the question: what is the best practice for defining more complex hyperparameters?

Script below

"""
# How to trigger hardcoded training

Comment out `main()` at the very end and run `python script.py`.

# How to trigger CLI training

Run `python script.py --config cfg.yaml`, where `cfg.yaml` contains:

model:
  optimizer:
    class_path: SGD
    init_args:
      lr: 123
      params: [1, 2, 3]
"""

import os

import torch
from torch import optim, nn, utils
from torch.utils.data import random_split, DataLoader

import lightning as L
from lightning.pytorch import cli

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, v2

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# custom objects
from dataclasses import dataclass
import abc
from typing import Iterable, Callable


class Optimizer(abc.ABC):
    def __init__(self, params: Iterable = [1, 2, 3]):
        pass


class SGD(Optimizer):
    def __init__(self, params: Iterable, lr: float):
        super().__init__()


@dataclass
class MyDataclass:
    x: str = "hello"


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(
        self,
        # optimizer: Callable[[], Optimizer],
        optimizer: Optimizer,
        myds: MyDataclass,
    ):
        print(type(optimizer))
        # print(type(optimizer()))

        super().__init__()
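        # save_hyperparameters() captures the __init__ arguments; how custom
        # objects end up in hparams.yaml differs between CLI and hardcoded runs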
        self.save_hyperparameters()

        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        transform = v2.ToTensor()  # mainly to convert PIL to tensor
        self.mnist_test = MNIST(
            self.data_dir,
            train=False,
            download=True,
            transform=transform,
        )
        self.mnist_predict = self.mnist_test
        mnist_full = MNIST(
            self.data_dir,
            train=True,
            download=True,
            transform=transform,
        )
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)


def train_hardcode():
    autoencoder = LitAutoEncoder(
        optimizer=SGD(params=[1, 2, 3], lr=123),
        myds=MyDataclass(),
    )

    # setup data
    dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
    train_loader = utils.data.DataLoader(dataset)

    # train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
    trainer = L.Trainer(limit_train_batches=10, max_epochs=1)
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)


def main():
    cli.LightningCLI(
        LitAutoEncoder,
        MNISTDataModule,
        trainer_defaults=dict(
            max_epochs=1,
            limit_train_batches=10,
            limit_val_batches=10,
        ),
    )


if __name__ == "__main__":
    main()
    # train_hardcode()

Pitch

No response

Alternatives

No response

Additional context

No response

cc @borda

cgebbe avatar Nov 19 '24 14:11 cgebbe

This is because of feature #18105. The CLI output is different so that saved checkpoints have all of the information required to instantiate everything from scratch. You can disable this feature by subclassing LightningCLI and adding:

    def _add_instantiators(self):
        pass

With this, hparams.yaml will be consistent, containing !!python objects in both cases. I think I will create a pull request with a more official way to disable this feature. Just note that disabling means that, depending on the case, load_from_checkpoint might not work correctly, or might even fail because pyyaml is not able to serialize/deserialize every kind of object.
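Put together, a minimal sketch of the workaround (the subclass name is just an example):

    from lightning.pytorch.cli import LightningCLI

    class MyLightningCLI(LightningCLI):
        def _add_instantiators(self):
            # Skip attaching the instantiator from #18105, so
            # save_hyperparameters falls back to dumping the raw objects.
            pass

    # Use the subclass exactly like LightningCLI:
    MyLightningCLI(LitAutoEncoder, MNISTDataModule)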

mauvilsa avatar Apr 27 '25 07:04 mauvilsa

In pull request #20777 I propose an official way to disable #18105: LightningCLI(... load_from_checkpoint_support=False).
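Applied to the script above, that would look like the following (assuming the PR is merged with this keyword):

    cli.LightningCLI(
        LitAutoEncoder,
        MNISTDataModule,
        load_from_checkpoint_support=False,  # proposed in #20777
    )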

mauvilsa avatar Apr 30 '25 05:04 mauvilsa

Closing issue, as PR #20777 provides an official way to disable this saving behavior. Please ping and reopen if necessary.

SkafteNicki avatar Sep 02 '25 08:09 SkafteNicki