pytorch-lightning
trainer.validate() gets a different result from trainer.fit
Bug description
I'm training a ResNet50 model and using ModelCheckpoint (weights only) to save the best model. However, I found that the results differ between fit and validate. I start training with the Lightning CLI with run=False and manually call fit and validate.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Main code
import lightning as L
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from lightning.pytorch.cli import LightningCLI


class ResNet(L.LightningModule):
    def __init__(
        self,
        num_classes: int,
        use_pretrained: bool = True,
        lr: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = torchvision.models.resnet50(
            weights=(
                torchvision.models.ResNet50_Weights.DEFAULT if use_pretrained else None
            ),
        )
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        # A single Accuracy instance shared across train/val/test steps
        self.acc = Accuracy(task="multiclass", num_classes=num_classes)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]

    def forward(self, x):
        x = self.model(x)
        return x

    def shared_step(self, batch):
        img, label, _ = batch
        logits = self(img)
        loss = F.cross_entropy(logits, label)
        self.acc(logits, label)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self.shared_step(batch)
        self.log("val_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)

    def test_step(self, batch, batch_idx):
        _ = self.shared_step(batch)
        self.log("test_acc", self.acc, prog_bar=True, on_epoch=True, on_step=True)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    cli = LightningCLI(save_config_callback=None, run=False)
    cli.trainer.fit(cli.model, cli.datamodule)
    cli.trainer.validate(cli.model, cli.datamodule, ckpt_path="checkpoints/resnet_CUB.ckpt")
Config file
seed_everything: 42
trainer:
  max_epochs: 1000
  log_every_n_steps: 1
  logger:
    class_path: aim.pytorch_lightning.AimLogger
    init_args:
      run_name: resnet_CUB
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        dirpath: checkpoints/
        filename: resnet_CUB
        save_top_k: 1
        mode: max
        save_weights_only: True
        enable_version_counter: False
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: val_acc
        mode: max
        min_delta: 0.001
        patience: 0
model:
  class_path: model.ResNet
  init_args:
    num_classes: 200
data:
  class_path: dataset.CUB
  init_args:
    data_path: ./data
    batch_size: 128
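With run=False, LightningCLI still parses this config but does not dispatch a subcommand, so the script above would be launched as, e.g., `python main.py --config config.yaml` (file names assumed here for illustration), after which it calls fit and validate itself.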
Error messages and logs
During fit, the final epoch-level val_acc is 0.86, while trainer.validate on the saved checkpoint reports 0.74.
Environment
- CUDA: - GPU: - NVIDIA GeForce RTX 4090 - available: True - version: 11.8
- Lightning: - lightning: 2.2.5 - lightning-utilities: 0.11.2 - pytorch-lightning: 2.2.2 - torch: 2.1.1 - torchattacks: 3.5.1 - torchaudio: 2.1.1 - torchmetrics: 1.4.0 - torchvision: 0.16.1
- Packages: - absl-py: 2.1.0 - aim: 3.22.0 - aim-ui: 3.22.0 - aimrecords: 0.0.7 - aimrocks: 0.5.2 - aiofiles: 23.2.1 - aiohttp: 3.9.5 - aiosignal: 1.3.1 - alembic: 1.13.0 - annotated-types: 0.6.0 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.1 - argcomplete: 3.4.0 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asttokens: 2.4.1 - astunparse: 1.6.3 - async-lru: 2.0.4 - async-timeout: 4.0.3 - attrs: 23.1.0 - babel: 2.14.0 - backoff: 2.2.1 - base58: 2.0.1 - beautifulsoup4: 4.12.3 - bitsandbytes: 0.41.0 - black: 24.2.0 - bleach: 6.1.0 - boto3: 1.34.62 - botocore: 1.34.62 - bottleneck: 1.3.5 - brotli: 1.0.9 - cachetools: 5.3.2 - certifi: 2024.7.4 - cffi: 1.16.0 - chardet: 4.0.0 - charset-normalizer: 2.0.4 - click: 8.1.7 - colorama: 0.4.6 - comm: 0.2.1 - contourpy: 1.2.0 - cryptography: 41.0.7 - cycler: 0.11.0 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - docopt: 0.6.2 - docstring-parser: 0.16 - exceptiongroup: 1.2.0 - executing: 2.0.1 - fastapi: 0.104.1 - fastjsonschema: 2.19.1 - filelock: 3.13.1 - fonttools: 4.25.0 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2023.10.0 - gmpy2: 2.1.2 - greenlet: 3.0.2 - grpcio: 1.48.2 - h11: 0.14.0 - httpcore: 1.0.4 - httpx: 0.27.0 - hydra-core: 1.3.2 - idna: 2.10 - importlib-metadata: 7.0.1 - importlib-resources: 6.1.0 - ipyflow: 0.0.198 - ipyflow-core: 0.0.198 - ipykernel: 6.29.0 - ipython: 8.18.1 - ipython-genutils: 0.2.0 - ipywidgets: 8.1.2 - isoduration: 20.11.0 - jedi: 0.19.1 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.2 - json5: 0.9.17 - jsonargparse: 4.29.0 - jsonnet: 0.17.0 - jsonpointer: 2.4 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.12.1 - jupyter: 1.0.0 - jupyter-client: 8.6.0 - jupyter-console: 6.6.3 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.3 - jupyter-server: 2.12.5 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.2 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.3 - jupyterlab-widgets: 3.0.10 - kiwisolver: 1.4.4 - lightning: 2.2.5 - lightning-utilities: 0.11.2 - mako: 1.3.0 - markdown: 3.6 - markdown-it-py: 3.0.0 - markupsafe: 2.1.1 - matplotlib: 3.8.0 - matplotlib-inline: 0.1.6 - mdurl: 0.1.2 - mistune: 3.0.2 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - monotonic: 1.6 - mpmath: 1.3.0 - multidict: 6.0.5 - munkres: 1.1.4 - mypy-extensions: 1.0.0 - nbclassic: 1.0.0 - nbclient: 0.9.0 - nbconvert: 7.16.1 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.1 - notebook: 7.1.1 - notebook-shim: 0.2.4 - numexpr: 2.8.7 - numpy: 1.26.2 - omegaconf: 2.3.0 - overrides: 7.7.0 - packaging: 23.1 - pandas: 2.1.1 - pandocfilters: 1.5.1 - parso: 0.8.3 - pathspec: 0.12.1 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.0.1 - pip: 23.3.1 - pipreqs: 0.4.13 - platformdirs: 4.1.0 - ply: 3.11 - pretty-errors: 1.2.25 - prometheus-client: 0.20.0 - prompt-toolkit: 3.0.42 - protobuf: 3.20.3 - psutil: 5.9.1 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py3nvml: 0.2.7 - pyccolo: 0.0.52 - pycparser: 2.21 - pydantic: 2.5.2 - pydantic-core: 2.14.5 - pygments: 2.17.2 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pyqt5: 5.15.10 - pyqt5-sip: 12.13.0 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - python-json-logger: 2.0.7 - pytorch-lightning: 2.2.2 - pytz: 2023.3.post1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - qtconsole: 5.5.1 - qtpy: 2.4.1 - referencing: 0.30.2 - requests: 2.25.1 - restrictedpython: 7.0 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.7.1 - rpds-py: 0.10.6 - s3transfer: 0.10.0 - scikit-learn: 1.4.1.post1 - scipy: 1.12.0 - seaborn: 0.12.2 - segment-analytics-python: 2.2.3 - send2trash: 1.8.2 - setuptools: 68.0.0 - sip: 6.7.12 - six: 1.16.0 - sniffio: 1.3.0 - soupsieve: 2.5 - sqlalchemy: 1.4.50 - stack-data: 0.6.2 - starlette: 0.27.0 - sympy: 1.12 - tensorboard: 2.17.0 - tensorboard-data-server: 0.7.0 - tensorboardx: 2.6.2.2 - terminado: 0.18.0 - threadpoolctl: 3.4.0 - tinycss2: 1.2.1 - tomli: 2.0.1 - torch: 2.1.1 - torchattacks: 3.5.1 - torchaudio: 2.1.1 - torchmetrics: 1.4.0 - torchvision: 0.16.1 - tornado: 6.3.3 - tqdm: 4.65.0 - traitlets: 5.14.1 - triton: 2.1.0 - types-python-dateutil: 2.8.19.20240106 - typeshed-client: 2.5.1 - typing-extensions: 4.9.0 - tzdata: 2023.3 - uri-template: 1.3.0 - urllib3: 1.26.18 - uvicorn: 0.24.0.post1 - validators: 0.18.2 - wcwidth: 0.2.13 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - websockets: 12.0 - werkzeug: 3.0.3 - wheel: 0.41.2 - widgetsnbextension: 4.0.10 - xmltodict: 0.13.0 - yarg: 0.1.9 - yarl: 1.9.4 - zipp: 3.11.0
- System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.18 - release: 6.2.0-39-generic - version: #40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023
More info
No response
I found that if I update acc only in the validation and test steps, the results are the same. However, when one acc metric is used in both training and validation, the validation result during fit differs from a standalone call to validate().
It is a common pitfall to reuse the same metric instance for different dataloaders, e.g. for both training and validation. See:
https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#common-pitfalls
Not using separate metrics for the different stages means that during fit the metric is updated with data from both training batches and validation batches, resulting in a wrong metric value. This is also why it seems to work when calling validate and test: only one dataloader is used by each of these methods.
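To make the pitfall concrete, here is a minimal standalone sketch (plain torchmetrics, with made-up tensors, not the reporter's data) showing how a shared metric instance pools samples from both stages until it is reset:

import torch
from torchmetrics import Accuracy

# One shared metric, as in the original LightningModule
shared_acc = Accuracy(task="multiclass", num_classes=2)

train_logits = torch.tensor([[0.9, 0.1], [0.8, 0.2]])  # argmax -> class 0
train_labels = torch.tensor([0, 0])                    # train batch: 2/2 correct
val_logits = torch.tensor([[0.9, 0.1], [0.8, 0.2]])    # argmax -> class 0
val_labels = torch.tensor([1, 1])                      # val batch: 0/2 correct

shared_acc.update(train_logits, train_labels)  # happens in training_step
shared_acc.update(val_logits, val_labels)      # happens in validation_step
print(shared_acc.compute())                    # tensor(0.5000): train and val pooled

# A dedicated metric sees only the validation samples
val_acc = Accuracy(task="multiclass", num_classes=2)
val_acc.update(val_logits, val_labels)
print(val_acc.compute())                       # tensor(0.): the true validation accuracy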
Here is a corrected version of the provided example with a metric for each stage:
class ResNet(L.LightningModule):
    def __init__(
        self,
        num_classes: int,
        use_pretrained: bool = True,
        lr: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = torchvision.models.resnet50(
            weights=(
                torchvision.models.ResNet50_Weights.DEFAULT if use_pretrained else None
            ),
        )
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        # One metric instance per stage; clone() copies the configuration
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = self.train_acc.clone()
        self.test_acc = self.train_acc.clone()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.step_size, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]

    def forward(self, x):
        x = self.model(x)
        return x

    def shared_step(self, batch, stage):
        img, label, _ = batch
        logits = self(img)
        loss = F.cross_entropy(logits, label)
        getattr(self, f"{stage}_acc")(logits, label)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, stage="train")
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        _ = self.shared_step(batch, stage="val")
        self.log("val_acc", self.val_acc, prog_bar=True, on_epoch=True, on_step=True)

    def test_step(self, batch, batch_idx):
        _ = self.shared_step(batch, stage="test")
        self.log("test_acc", self.test_acc, prog_bar=True, on_epoch=True, on_step=True)
Closing issue.