pytorch-lightning icon indicating copy to clipboard operation
pytorch-lightning copied to clipboard

trainer.validate() gives a different result from trainer.fit()

Open matrix72c opened this issue 1 year ago • 1 comments

Bug description

I'm training a ResNet50 model and using a model checkpoint callback (weights only) to save the best model. However, I found that the results differ between fit and validate. I start training with the Lightning CLI with run=False and manually call fit and validate.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Main code


import lightning as L
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from lightning.pytorch.cli import LightningCLI


class ResNet(L.LightningModule):
    """ResNet-50 classifier used to reproduce the reported fit/validate mismatch.

    Wraps ``torchvision.models.resnet50`` with its final fully connected layer
    replaced by a ``num_classes``-way linear layer, and tracks multiclass
    accuracy with a single ``Accuracy`` metric (``self.acc``) that is updated
    from the training, validation and test steps alike.
    """

    def __init__(
        self,
        num_classes: int,
        use_pretrained: bool = True,
        lr: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
    ):
        """Build the backbone, classification head and accuracy metric.

        Args:
            num_classes: Output size of the replacement ``fc`` layer and the
                ``num_classes`` passed to the accuracy metric.
            use_pretrained: If true, load ``ResNet50_Weights.DEFAULT``;
                otherwise pass ``weights=None``.
            lr: Learning rate given to Adam in ``configure_optimizers``.
            step_size: ``step_size`` given to ``StepLR``.
            gamma: ``gamma`` given to ``StepLR``.
        """
        super().__init__()
        # Stores the constructor arguments; they are read back later through
        # ``self.hparams`` in configure_optimizers.
        self.save_hyperparameters()

        self.model = torchvision.models.resnet50(
            weights=(
                torchvision.models.ResNet50_Weights.DEFAULT if use_pretrained else None
            ),
        )
        # Swap the stock classification head for one sized to this task.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        # NOTE: one metric instance for every stage. shared_step updates it
        # from training_step, validation_step and test_step, so during fit it
        # receives both training and validation batches.
        self.acc = Accuracy(task="multiclass", num_classes=num_classes)

    def configure_optimizers(self):
        """Return Adam over all parameters together with a StepLR scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet-50 and return its output."""
        x = self.model(x)
        return x

    def shared_step(self, batch):
        """Compute cross-entropy loss for ``batch`` and update ``self.acc``.

        ``batch`` is unpacked as a 3-tuple ``(img, label, _)``; the third
        element is ignored here.
        """
        img, label, _ = batch
        logits = self(img)
        loss = F.cross_entropy(logits, label)
        # Updates the shared metric regardless of which stage called us.
        self.acc(logits, label)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` and per-step ``train_acc``; return the loss."""
        loss = self.shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Update the metric on a validation batch and log it as ``val_acc``."""
        _ = self.shared_step(batch)
        self.log("val_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)

    def test_step(self, batch, batch_idx):
        """Update the metric on a test batch and log it as ``test_acc``."""
        _ = self.shared_step(batch)
        self.log("test_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)

if __name__ == "__main__":
    # Set PyTorch's float32 matmul precision mode to "high".
    torch.set_float32_matmul_precision("high")
    # run=False: the CLI is only used to build the trainer, model and
    # datamodule; fit and validate are invoked manually below.
    cli = LightningCLI(save_config_callback=None, run=False)

    cli.trainer.fit(cli.model, cli.datamodule)
    # NOTE(review): this loads from "checkpoint/", whereas the ModelCheckpoint
    # entry in the config file uses dirpath "checkpoints/" — confirm that the
    # checkpoint validated here is the one written by the fit call above.
    cli.trainer.validate(cli.model, cli.datamodule, ckpt_path="checkpoint/resnet_CUB.ckpt")

config file

# LightningCLI configuration for the reproduction script.
seed_everything: 42

trainer:
  max_epochs: 1000
  log_every_n_steps: 1
  logger:
    class_path: aim.pytorch_lightning.AimLogger
    init_args:
      run_name: resnet_CUB
  callbacks:
    # Keeps only the single best checkpoint, ranked by the "val_acc" name
    # that the model passes to self.log in validation_step.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        # NOTE(review): the script validates from "checkpoint/resnet_CUB.ckpt"
        # (no trailing "s") — confirm both refer to the same location.
        dirpath: checkpoints/
        filename: resnet_CUB
        save_top_k: 1
        mode: max
        save_weights_only: True
        enable_version_counter: False
    # Early stopping watches the same "val_acc" value as the checkpoint callback.
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: val_acc
        mode: max
        min_delta: 0.001
        # presumably stops at the first validation check that fails to improve
        # by min_delta — confirm against the EarlyStopping documentation
        patience: 0

model:
  class_path: model.ResNet
  init_args:
    num_classes: 200

data:
  class_path: dataset.CUB
  init_args:
    data_path: ./data
    batch_size: 128

Error messages and logs

In fit, the final acc_epoch is 0.86, while in trainer.validate, it becomes 0.74.

Environment

Current environment
  • CUDA: - GPU: - NVIDIA GeForce RTX 4090 - available: True - version: 11.8
  • Lightning: - lightning: 2.2.5 - lightning-utilities: 0.11.2 - pytorch-lightning: 2.2.2 - torch: 2.1.1 - torchattacks: 3.5.1 - torchaudio: 2.1.1 - torchmetrics: 1.4.0 - torchvision: 0.16.1
  • Packages: - absl-py: 2.1.0 - aim: 3.22.0 - aim-ui: 3.22.0 - aimrecords: 0.0.7 - aimrocks: 0.5.2 - aiofiles: 23.2.1 - aiohttp: 3.9.5 - aiosignal: 1.3.1 - alembic: 1.13.0 - annotated-types: 0.6.0 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.1 - argcomplete: 3.4.0 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asttokens: 2.4.1 - astunparse: 1.6.3 - async-lru: 2.0.4 - async-timeout: 4.0.3 - attrs: 23.1.0 - babel: 2.14.0 - backoff: 2.2.1 - base58: 2.0.1 - beautifulsoup4: 4.12.3 - bitsandbytes: 0.41.0 - black: 24.2.0 - bleach: 6.1.0 - boto3: 1.34.62 - botocore: 1.34.62 - bottleneck: 1.3.5 - brotli: 1.0.9 - cachetools: 5.3.2 - certifi: 2024.7.4 - cffi: 1.16.0 - chardet: 4.0.0 - charset-normalizer: 2.0.4 - click: 8.1.7 - colorama: 0.4.6 - comm: 0.2.1 - contourpy: 1.2.0 - cryptography: 41.0.7 - cycler: 0.11.0 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - docopt: 0.6.2 - docstring-parser: 0.16 - exceptiongroup: 1.2.0 - executing: 2.0.1 - fastapi: 0.104.1 - fastjsonschema: 2.19.1 - filelock: 3.13.1 - fonttools: 4.25.0 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2023.10.0 - gmpy2: 2.1.2 - greenlet: 3.0.2 - grpcio: 1.48.2 - h11: 0.14.0 - httpcore: 1.0.4 - httpx: 0.27.0 - hydra-core: 1.3.2 - idna: 2.10 - importlib-metadata: 7.0.1 - importlib-resources: 6.1.0 - ipyflow: 0.0.198 - ipyflow-core: 0.0.198 - ipykernel: 6.29.0 - ipython: 8.18.1 - ipython-genutils: 0.2.0 - ipywidgets: 8.1.2 - isoduration: 20.11.0 - jedi: 0.19.1 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.2 - json5: 0.9.17 - jsonargparse: 4.29.0 - jsonnet: 0.17.0 - jsonpointer: 2.4 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.12.1 - jupyter: 1.0.0 - jupyter-client: 8.6.0 - jupyter-console: 6.6.3 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.3 - jupyter-server: 2.12.5 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.2 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.3 - jupyterlab-widgets: 3.0.10 - kiwisolver: 1.4.4 - lightning: 2.2.5 - 
lightning-utilities: 0.11.2 - mako: 1.3.0 - markdown: 3.6 - markdown-it-py: 3.0.0 - markupsafe: 2.1.1 - matplotlib: 3.8.0 - matplotlib-inline: 0.1.6 - mdurl: 0.1.2 - mistune: 3.0.2 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - monotonic: 1.6 - mpmath: 1.3.0 - multidict: 6.0.5 - munkres: 1.1.4 - mypy-extensions: 1.0.0 - nbclassic: 1.0.0 - nbclient: 0.9.0 - nbconvert: 7.16.1 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.1 - notebook: 7.1.1 - notebook-shim: 0.2.4 - numexpr: 2.8.7 - numpy: 1.26.2 - omegaconf: 2.3.0 - overrides: 7.7.0 - packaging: 23.1 - pandas: 2.1.1 - pandocfilters: 1.5.1 - parso: 0.8.3 - pathspec: 0.12.1 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.0.1 - pip: 23.3.1 - pipreqs: 0.4.13 - platformdirs: 4.1.0 - ply: 3.11 - pretty-errors: 1.2.25 - prometheus-client: 0.20.0 - prompt-toolkit: 3.0.42 - protobuf: 3.20.3 - psutil: 5.9.1 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py3nvml: 0.2.7 - pyccolo: 0.0.52 - pycparser: 2.21 - pydantic: 2.5.2 - pydantic-core: 2.14.5 - pygments: 2.17.2 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pyqt5: 5.15.10 - pyqt5-sip: 12.13.0 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - python-json-logger: 2.0.7 - pytorch-lightning: 2.2.2 - pytz: 2023.3.post1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - qtconsole: 5.5.1 - qtpy: 2.4.1 - referencing: 0.30.2 - requests: 2.25.1 - restrictedpython: 7.0 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.7.1 - rpds-py: 0.10.6 - s3transfer: 0.10.0 - scikit-learn: 1.4.1.post1 - scipy: 1.12.0 - seaborn: 0.12.2 - segment-analytics-python: 2.2.3 - send2trash: 1.8.2 - setuptools: 68.0.0 - sip: 6.7.12 - six: 1.16.0 - sniffio: 1.3.0 - soupsieve: 2.5 - sqlalchemy: 1.4.50 - stack-data: 0.6.2 - starlette: 0.27.0 - sympy: 1.12 - tensorboard: 2.17.0 - tensorboard-data-server: 0.7.0 - tensorboardx: 2.6.2.2 - terminado: 0.18.0 - threadpoolctl: 3.4.0 - tinycss2: 1.2.1 - tomli: 2.0.1 - torch: 2.1.1 - torchattacks: 3.5.1 - torchaudio: 2.1.1 - torchmetrics: 1.4.0 - torchvision: 0.16.1 - 
tornado: 6.3.3 - tqdm: 4.65.0 - traitlets: 5.14.1 - triton: 2.1.0 - types-python-dateutil: 2.8.19.20240106 - typeshed-client: 2.5.1 - typing-extensions: 4.9.0 - tzdata: 2023.3 - uri-template: 1.3.0 - urllib3: 1.26.18 - uvicorn: 0.24.0.post1 - validators: 0.18.2 - wcwidth: 0.2.13 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - websockets: 12.0 - werkzeug: 3.0.3 - wheel: 0.41.2 - widgetsnbextension: 4.0.10 - xmltodict: 0.13.0 - yarg: 0.1.9 - yarl: 1.9.4 - zipp: 3.11.0
  • System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.18 - release: 6.2.0-39-generic - version: #40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023

More info

No response

matrix72c avatar Aug 08 '24 11:08 matrix72c

I found that if I update the accuracy metric only in the validation and test steps, the results are the same. However, when a single accuracy metric is used in both the training and validation steps, the validation result during fit differs from the result of calling validate() on its own.

matrix72c avatar Aug 20 '24 04:08 matrix72c

It is a common pitfall to reuse the same metric instance for different dataloaders, e.g. for both training and validation. See: https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#common-pitfalls Not using separate metrics for the different stages means that the metric will be updated with data from both training batches and validation batches during fit, resulting in a wrong metric value. This is also why it seems to work when calling validate and test, because only one dataloader is used by each of these methods.

Here is a corrected version of the provided example with a metric for each stage:

class ResNet(L.LightningModule):
    """ResNet-50 classifier with a separate accuracy metric per stage.

    Wraps ``torchvision.models.resnet50`` with its final fully connected layer
    replaced by a ``num_classes``-way linear layer. Training, validation and
    test each own their own ``Accuracy`` instance (``train_acc``, ``val_acc``,
    ``test_acc``) so that batches from one stage never accumulate into the
    metric reported for another.
    """

    def __init__(
        self,
        num_classes: int,
        use_pretrained: bool = True,
        lr: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
    ):
        """Build the backbone, classification head and per-stage metrics.

        Args:
            num_classes: Output size of the replacement ``fc`` layer and the
                ``num_classes`` passed to the accuracy metrics.
            use_pretrained: If true, load ``ResNet50_Weights.DEFAULT``;
                otherwise pass ``weights=None``.
            lr: Learning rate given to Adam in ``configure_optimizers``.
            step_size: ``step_size`` given to ``StepLR``.
            gamma: ``gamma`` given to ``StepLR``.
        """
        super().__init__()
        # Stores the constructor arguments; they are read back later through
        # ``self.hparams`` in configure_optimizers.
        self.save_hyperparameters()

        self.model = torchvision.models.resnet50(
            weights=(
                torchvision.models.ResNet50_Weights.DEFAULT if use_pretrained else None
            ),
        )
        # Swap the stock classification head for one sized to this task.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        # One metric per stage. The attribute names must follow the
        # "<stage>_acc" pattern because shared_step looks them up by name.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = self.train_acc.clone()
        self.test_acc = self.train_acc.clone()

    def configure_optimizers(self):
        """Return Adam over all parameters together with a StepLR scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet-50 and return its output."""
        x = self.model(x)
        return x

    def shared_step(self, batch, stage):
        """Compute the loss for ``batch`` and update the metric for ``stage``.

        Args:
            batch: A 3-tuple ``(img, label, _)``; the third element is ignored.
            stage: One of ``"train"``, ``"val"`` or ``"test"``; selects the
                ``<stage>_acc`` metric to update.

        Returns:
            The cross-entropy loss between the model's logits and ``label``.
        """
        img, label, _ = batch
        logits = self(img)
        loss = F.cross_entropy(logits, label)
        # Fixed: the original passed the undefined name ``labels`` here, which
        # raised NameError on the first step of any stage.
        getattr(self, f"{stage}_acc")(logits, label)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` and per-step ``train_acc``; return the loss."""
        loss = self.shared_step(batch, stage="train")
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Update ``val_acc`` on a validation batch and log it."""
        _ = self.shared_step(batch, stage="val")
        self.log("val_acc", self.val_acc, prog_bar=True, on_epoch=True, on_step=True)

    def test_step(self, batch, batch_idx):
        """Update ``test_acc`` on a test batch and log it."""
        _ = self.shared_step(batch, stage="test")
        self.log("test_acc", self.test_acc, prog_bar=True, on_epoch=True, on_step=True)

Closing issue.

SkafteNicki avatar Sep 02 '25 08:09 SkafteNicki