pytorch-lightning
Trainer.predict and Trainer.test reset the model to evaluation mode
Bug description
I want to keep dropout active at inference time so that the model produces random (stochastic) outputs. To do this, I put the model in train mode before running the prediction. However, Lightning's Trainer.predict switches the model back to evaluation mode, which disables dropout and makes the outputs deterministic. Running the same prediction directly on the raw model works as expected.
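For context, the effect comes entirely from the `training` flag that `nn.Module.train()` and `eval()` set. A minimal plain-PyTorch sketch (no Lightning involved; the tensor size and seed are arbitrary choices) of what happens to a dropout layer when the mode is switched back to evaluation, as described above for Trainer.predict:

```python
import torch

torch.manual_seed(0)
# A dropout layer only zeroes activations while module.training is True.
dropout = torch.nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train(True)  # what I set before predicting
train_a = dropout(x)
train_b = dropout(x)

dropout.eval()  # what Trainer.predict effectively does before the predict loop
eval_a = dropout(x)
eval_b = dropout(x)

# Train mode: each pass draws a new random mask; survivors are scaled by 1 / (1 - p).
print("train mode, outputs equal:", torch.equal(train_a, train_b))
# Eval mode: dropout is the identity, so outputs are deterministic.
print("eval mode, outputs equal:", torch.equal(eval_a, eval_b))
print("eval mode is identity:", torch.equal(eval_a, x))
```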
Note: Contrary to the version dropdown selection, I am running version 2.6.0.
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
```python
# %%
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

# %%
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)
        self.dropout = torch.nn.Dropout(0.5)  # only active in train mode

    def forward(self, x):
        return self.dropout(self.layer(x))

    def test_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = torch.nn.functional.mse_loss(out, y)
        self.log("test_loss", loss)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self(x)


# Data
X = torch.randn(100, 10)
y = torch.randn(100, 1)
ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=10)

# Model
pl.seed_everything(42)
model = SimpleModel()
trainer = pl.Trainer(accelerator="cpu", devices=1)

# %%
print("--- Evaluate model using Trainer ---")
print("--- Run 1: train(False) ---")
model.train(False)
predictions = trainer.predict(model, dataloaders=dl)
y_pred_deterministic = torch.cat(predictions)

print("--- Run 2: train(True) ---")
model.train(True)  # should keep dropout active during predict
predictions = trainer.predict(model, dataloaders=dl)
y_pred_stochastic = torch.cat(predictions)

# Expected False (dropout active in Run 2), but Trainer.predict makes both runs identical.
are_same = torch.allclose(y_pred_deterministic, y_pred_stochastic)
print(f"Predictions are the same: {are_same}")
print("------------------------------------")

# %%
print("--- Evaluate model using Raw Model Loop ---")
print("--- Run 3: raw train(False) ---")
model.train(False)
with torch.no_grad():
    raw_predictions = [model(x) for x, _ in dl]
y_pred_raw_deterministic = torch.cat(raw_predictions)

print("--- Run 4: raw train(True) ---")
model.train(True)
with torch.no_grad():
    raw_predictions = [model(x) for x, _ in dl]
y_pred_raw_stochastic = torch.cat(raw_predictions)

# The raw model respects train(True), so these differ as expected.
are_same = torch.allclose(y_pred_raw_deterministic, y_pred_raw_stochastic)
print(f"Raw predictions are the same: {are_same}")

# %%
```
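Until Trainer.predict preserves the mode, one possible workaround is Monte Carlo dropout applied after `eval()`: switch only the dropout layers back to train mode and leave everything else (e.g. BatchNorm) in eval mode. This is a sketch, not Lightning API; `enable_dropout` and `mc_dropout_predict` are hypothetical helper names, and the model is a plain `nn.Sequential` with the same layers as SimpleModel so it runs without Lightning.

```python
import torch
from torch import nn

# Dropout variants whose behavior depends on module.training.
DROPOUT_TYPES = (
    nn.Dropout,
    nn.Dropout1d,
    nn.Dropout2d,
    nn.Dropout3d,
    nn.AlphaDropout,
    nn.FeatureAlphaDropout,
)


def enable_dropout(model: nn.Module) -> int:
    """Put only the dropout layers of `model` into train mode.

    Every other submodule keeps its current mode.
    Returns the number of dropout layers that were switched on.
    """
    count = 0
    for module in model.modules():
        if isinstance(module, DROPOUT_TYPES):
            module.train(True)
            count += 1
    return count


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, n_samples: int = 20):
    """Average `n_samples` stochastic forward passes (Monte Carlo dropout)."""
    model.eval()           # deterministic mode for every layer...
    enable_dropout(model)  # ...except dropout
    samples = torch.stack([model(x) for _ in range(n_samples)])
    return samples.mean(dim=0), samples.std(dim=0)


torch.manual_seed(0)
# Same layers as SimpleModel above, without Lightning.
model = nn.Sequential(nn.Linear(10, 1), nn.Dropout(0.5))
x = torch.randn(8, 10)
mean, std = mc_dropout_predict(model, x)
print("model.training:", model.training)       # False
print("dropout.training:", model[1].training)  # True
print("per-sample std:", std.squeeze(1))
```

Inside a LightningModule, calling enable_dropout(self) at the top of predict_step should be robust to whatever mode the loop sets, since it runs on every batch right before the forward pass. Calling it from on_predict_start may also work, but that assumes the hook fires after the prediction loop has switched the model to eval mode; the ordering is not verified here for 2.6.0.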
Error messages and logs
No error is raised; Trainer.predict silently returns deterministic outputs even though the model was put in train mode.
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: 12.8
- Lightning:
- lightning-utilities: 0.15.2
- pytorch-lightning: 2.6.0
- torch: 2.9.1
- torchinfo: 1.8.0
- torchmetrics: 1.8.2
- torchview: 0.2.7
- Packages:
- absl-py: 2.3.1
- aiohappyeyeballs: 2.6.1
- aiohttp: 3.13.2
- aiosignal: 1.4.0
- alembic: 1.17.2
- alibi-detect: 0.13.0
- altair: 6.0.0
- annotated-doc: 0.0.4
- annotated-types: 0.7.0
- anyio: 4.12.0
- appdirs: 1.4.4
- argon2-cffi: 25.1.0
- argon2-cffi-bindings: 25.1.0
- arrow: 1.4.0
- astor: 0.8.1
- asttokens: 3.0.1
- astunparse: 1.6.3
- async-lru: 2.0.5
- attrs: 25.4.0
- autocommand: 2.2.2
- babel: 2.17.0
- backports-zstd: 1.2.0
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.14.3
- bleach: 6.3.0
- blinker: 1.9.0
- bokeh: 3.8.1
- brotli: 1.2.0
- cachetools: 6.2.2
- cartes: 0.8.5
- cartopy: 0.25.0
- catalogue: 2.0.10
- certifi: 2025.11.12
- cffi: 2.0.0
- charset-normalizer: 3.4.4
- cheroot: 11.1.2
- click: 8.2.1
- cloudpickle: 3.1.2
- cmdstanpy: 1.3.0
- coloredlogs: 15.0.1
- comm: 0.2.3
- contourpy: 1.3.3
- cramjam: 2.11.0
- cryptography: 46.0.3
- cycler: 0.12.1
- databricks-sdk: 0.73.0
- debugpy: 1.8.17
- decorator: 5.2.1
- defusedxml: 0.7.1
- dill: 0.3.9
- dm-tree: 0.1.9
- docker: 7.1.0
- etils: 1.13.0
- executing: 2.2.1
- fastapi: 0.124.0
- fastjsonschema: 2.21.2
- fastparquet: 2024.11.0
- filelock: 3.20.0
- flask: 3.1.2
- flask-cors: 6.0.1
- flatbuffers: 25.9.23
- flexcache: 0.3
- flexparser: 0.4
- fonttools: 4.61.0
- fqdn: 1.5.1
- frozenlist: 1.8.0
- fsspec: 2025.12.0
- gast: 0.7.0
- gcsfs: 2025.12.0
- geopandas: 1.1.1
- gitdb: 4.0.12
- gitpython: 3.1.45
- google-api-core: 2.28.1
- google-auth: 2.43.0
- google-auth-oauthlib: 1.2.2
- google-cloud-core: 2.5.0
- google-cloud-storage: 3.6.0
- google-cloud-storage-control: 1.8.0
- google-crc32c: 1.7.1
- google-pasta: 0.2.0
- google-resumable-media: 2.8.0
- googleapis-common-protos: 1.72.0
- graphene: 3.4.3
- graphql-core: 3.2.7
- graphql-relay: 3.2.0
- graphviz: 0.21
- greenlet: 3.3.0
- grpc-google-iam-v1: 0.14.3
- grpcio: 1.76.0
- grpcio-status: 1.76.0
- gunicorn: 23.0.0
- gviz-api: 1.10.0
- h11: 0.16.0
- h2: 4.3.0
- h5py: 3.15.1
- hf-xet: 1.2.0
- holidays: 0.86
- hpack: 4.1.0
- httpcore: 1.0.9
- httpx: 0.28.1
- huggingface-hub: 0.36.0
- humanfriendly: 10.0
- hyperframe: 6.1.0
- idna: 3.11
- imageio: 2.37.2
- importlib-metadata: 8.7.0
- importlib-resources: 6.5.2
- impunity: 1.0.5
- inflate64: 1.0.4
- inflect: 7.3.1
- iniconfig: 2.3.0
- ipykernel: 7.1.0
- ipython: 9.8.0
- ipython-pygments-lexers: 1.1.1
- ipywidgets: 8.1.8
- isoduration: 20.11.0
- itsdangerous: 2.2.0
- jaraco-functools: 4.3.0
- jaraco.classes: 3.4.0
- jaraco.collections: 5.1.0
- jaraco.context: 6.0.1
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.2
- jeepney: 0.9.0
- jinja2: 3.1.6
- joblib: 1.5.2
- json5: 0.12.1
- jsonpointer: 3.0.0
- jsonschema: 4.25.1
- jsonschema-specifications: 2025.9.1
- jupyter: 1.1.1
- jupyter-client: 8.6.3
- jupyter-console: 6.6.3
- jupyter-core: 5.9.1
- jupyter-events: 0.12.0
- jupyter-lsp: 2.3.0
- jupyter-server: 2.17.0
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.5.0
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.28.0
- jupyterlab-widgets: 3.0.16
- keopscore: 2.2.3
- keras: 3.12.0
- keras-tuner: 1.4.8
- keyring: 25.7.0
- kiwisolver: 1.4.9
- kt-legacy: 1.0.5
- lark: 1.3.1
- lazy-loader: 0.4
- libclang: 18.1.1
- librt: 0.7.3
- lightning-utilities: 0.15.2
- llvmlite: 0.46.0
- lxml: 6.0.2
- lz4: 4.4.5
- mako: 1.3.10
- markdown: 3.10
- markdown-it-py: 4.0.0
- markupsafe: 3.0.3
- matplotlib: 3.10.7
- matplotlib-inline: 0.2.1
- mdurl: 0.1.2
- metar: 1.11.0
- minio: 7.2.20
- mistune: 3.1.4
- ml-dtypes: 0.5.4
- mlflow: 3.5.1
- mlflow-skinny: 3.5.1
- mlflow-tracing: 3.5.1
- more-itertools: 10.8.0
- mpmath: 1.3.0
- msgpack: 1.1.2
- multidict: 6.7.0
- multivolumefile: 0.2.3
- mypy: 1.19.0
- mypy-extensions: 1.1.0
- namex: 0.1.0
- narwhals: 2.13.0
- nbclient: 0.10.2
- nbconvert: 7.16.6
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.6.1
- notebook: 7.5.0
- notebook-shim: 0.2.4
- numba: 0.63.0
- numpy: 2.3.5
- nvidia-cublas-cu12: 12.8.4.1
- nvidia-cuda-cupti-cu12: 12.8.90
- nvidia-cuda-nvcc-cu12: 12.9.86
- nvidia-cuda-nvrtc-cu12: 12.8.93
- nvidia-cuda-runtime-cu12: 12.8.90
- nvidia-cudnn-cu12: 9.10.2.21
- nvidia-cufft-cu12: 11.3.3.83
- nvidia-cufile-cu12: 1.13.1.3
- nvidia-curand-cu12: 10.3.9.90
- nvidia-cusolver-cu12: 11.7.3.90
- nvidia-cusparse-cu12: 12.5.8.93
- nvidia-cusparselt-cu12: 0.7.1
- nvidia-nccl-cu12: 2.27.5
- nvidia-nvjitlink-cu12: 12.8.93
- nvidia-nvshmem-cu12: 3.3.20
- nvidia-nvtx-cu12: 12.8.90
- oauthlib: 3.3.1
- onnxruntime: 1.23.2
- openap: 2.4
- opencv-python: 4.11.0.86
- opentelemetry-api: 1.39.0
- opentelemetry-proto: 1.39.0
- opentelemetry-sdk: 1.39.0
- opentelemetry-semantic-conventions: 0.60b0
- opt-einsum: 3.4.0
- optree: 0.18.0
- orjson: 3.11.5
- overrides: 7.7.0
- packaging: 25.0
- pandas: 2.3.3
- pandocfilters: 1.5.1
- parso: 0.8.5
- pathspec: 0.12.1
- patsy: 1.0.2
- pexpect: 4.9.0
- pillow: 10.4.0
- pint: 0.25.2
- pitot: 0.3.2
- platformdirs: 4.5.1
- plotly: 6.5.0
- pluggy: 1.6.0
- prometheus-client: 0.23.1
- prompt-toolkit: 3.0.52
- propcache: 0.4.1
- properscoring: 0.1
- prophet: 1.2.1
- proto-plus: 1.26.1
- protobuf: 6.33.2
- psutil: 7.1.3
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- py7zr: 1.0.0
- pyarrow: 21.0.0
- pyasn1: 0.6.1
- pyasn1-modules: 0.4.2
- pybcj: 1.0.7
- pybind11: 3.0.1
- pycparser: 2.23
- pycryptodome: 3.23.0
- pycryptodomex: 3.23.0
- pydantic: 2.12.5
- pydantic-core: 2.41.5
- pygments: 2.19.2
- pyjwt: 2.10.1
- pykeops: 2.2.3
- pynverse: 0.1.4.6
- pyod: 2.0.6
- pyogrio: 0.12.1
- pyopensky: 2.15
- pyparsing: 3.2.5
- pyppmd: 1.2.0
- pyproj: 3.7.2
- pyshp: 3.0.3
- pytest: 9.0.2
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.2.1
- python-json-logger: 4.0.0
- pytorch-lightning: 2.6.0
- pytz: 2025.2
- pyyaml: 6.0.3
- pyzmq: 27.1.0
- pyzstd: 0.19.0
- quantile-forest: 1.4.1
- ray: 2.52.1
- referencing: 0.37.0
- regex: 2025.11.3
- requests: 2.32.5
- requests-oauthlib: 2.0.0
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rfc3987-syntax: 1.1.0
- rich: 14.2.0
- rpds-py: 0.30.0
- rs1090: 0.4.14
- rsa: 4.9.1
- ruff: 0.14.8
- safetensors: 0.7.0
- scikit-image: 0.25.2
- scikit-learn: 1.7.2
- scipy: 1.16.3
- seaborn: 0.13.2
- secretstorage: 3.5.0
- send2trash: 1.8.3
- setuptools: 80.9.0
- sfoutils: 0.2.0
- shap: 0.50.0
- shapely: 2.1.2
- six: 1.17.0
- sklearn-quantile: 0.1.1
- slicer: 0.0.8
- smmap: 5.0.2
- soupsieve: 2.8
- sqlalchemy: 2.0.44
- sqlparse: 0.5.4
- stack-data: 0.6.3
- stanio: 0.5.1
- starlette: 0.50.0
- statsmodels: 0.14.6
- sympy: 1.14.0
- tensorboard: 2.20.0
- tensorboard-data-server: 0.7.2
- tensorboard-plugin-profile: 2.21.3
- tensorboardx: 2.6.4
- tensorflow: 2.20.0
- tensorflow-docs: 2025.12.2.70325
- tensorflow-probability: 0.25.0
- termcolor: 3.2.0
- terminado: 0.18.1
- texttable: 1.7.0
- tf-keras: 2.20.1
- threadpoolctl: 3.6.0
- tifffile: 2025.10.16
- tinycss2: 1.4.0
- tokenizers: 0.21.4
- toml: 0.10.2
- tomli: 2.0.1
- torch: 2.9.1
- torchinfo: 1.8.0
- torchmetrics: 1.8.2
- torchview: 0.2.7
- tornado: 6.5.2
- tqdm: 4.67.1
- traffic: 2.13.post16.dev0+9e00a83
- traitlets: 5.14.3
- transformers: 4.51.3
- trino: 0.336.0
- triton: 3.5.1
- tudcolors: 0.0.1
- typeguard: 4.3.0
- types-protobuf: 6.32.1.20251105
- types-requests: 2.32.4.20250913
- types-tensorflow: 2.18.0.20251008
- typing-extensions: 4.15.0
- typing-inspection: 0.4.2
- tzdata: 2025.2
- tzlocal: 5.3.1
- uri-template: 1.3.0
- urllib3: 2.6.1
- uvicorn: 0.38.0
- wcwidth: 0.2.14
- webcolors: 25.10.0
- webencodings: 0.5.1
- websocket-client: 1.9.0
- werkzeug: 3.1.4
- wheel: 0.45.1
- widgetsnbextension: 4.0.15
- wrapt: 2.0.1
- xprof: 2.21.3
- xyzservices: 2025.11.0
- yarl: 1.22.0
- zipp: 3.23.0
- zstandard: 0.25.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor:
- python: 3.13.9
- release: 6.17.9-200.fc42.x86_64
- version: #1 SMP PREEMPT_DYNAMIC Mon Nov 24 22:28:05 UTC 2025
More info
No response
cc @ethanwharris