Fit loop and validation loop teardown does not dereference passed dataloaders
Bug description
When a dataloader is passed to pl.Trainer's fit, the trainer does not clear all references to it after fitting finishes. As a result, pickling the trainer also pickles the whole dataloader, which isn't ideal for large datasets.
Steps to reproduce
Run the MRE below.
Expected Behaviour
The trainer's size after fitting should not increase by 1MB (the size of the dataset).
Actual Behaviour
The trainer's size after fitting increased by 1MB (the size of the dataset).
What version are you seeing the problem on?
master
How to reproduce the bug
from pickle import dumps
import torch
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

model = BoringModel()
dataset = RandomDataset(32, 8192)  # roughly 1 MB when pickled
data_size = len(dumps(dataset))

trainer = pl.Trainer(fast_dev_run=True)
prev_size = len(dumps(trainer))
trainer.fit(model, train_dataloaders=torch.utils.data.DataLoader(dataset))
new_size = len(dumps(trainer))

size_diff_mb = (new_size - prev_size) // (1024 ** 2)
data_size_mb = data_size // (1024 ** 2)
print(size_diff_mb == data_size_mb)  # True: the trainer grew by the dataset's size
assert size_diff_mb == 0  # fails while the bug is present
Error messages and logs
No response
Environment
Current environment
- CUDA: - GPU: None - available: False - version: None
- Lightning: - lightning: 2.0.6 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - pytorch-lightning: 2.0.6 - torch: 2.0.1+cpu - torch-model-archiver: 0.8.1 - torch-workflow-archiver: 0.2.9 - torchmetrics: 0.11.4 - torchserve: 0.8.1 - torchvision: 0.15.2a0
- Packages: - absl-py: 1.4.0 - anyio: 3.7.1 - arrow: 1.2.3 - astunparse: 1.6.3 - attrs: 23.1.0 - backoff: 2.2.1 - backports.cached-property: 1.0.2 - backports.functools-lru-cache: 1.6.5 - beautifulsoup4: 4.12.2 - blessed: 1.19.1 - blinker: 1.4 - brotlipy: 0.7.0 - build: 0.10.0 - cachecontrol: 0.12.11 - cachetools: 4.2.2 - captum: 0.6.0 - certifi: 2023.7.22 - cffi: 1.15.0 - charset-normalizer: 3.2.0 - cleo: 2.0.1 - click: 8.1.6 - colorama: 0.4.6 - contourpy: 1.1.0 - crashtest: 0.4.1 - croniter: 1.4.1 - cryptography: 41.0.2 - cycler: 0.11.0 - dateutils: 0.6.12 - deepdiff: 5.8.1 - distlib: 0.3.7 - docker: 6.1.3 - docstring-parser: 0.15 - dulwich: 0.21.3 - enum-compat: 0.0.3 - exceptiongroup: 1.1.2 - fastapi: 0.100.1 - filelock: 3.12.2 - flatbuffers: 2.0 - fonttools: 4.42.0 - fsspec: 2023.6.0 - gast: 0.4.0 - ghp-import: 2.1.0 - google-api-core: 2.11.1 - google-auth: 1.21.3 - google-auth-oauthlib: 0.5.2 - google-cloud-core: 2.3.3 - google-cloud-storage: 2.10.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 2.5.0 - googleapis-common-protos: 1.59.1 - griffe: 0.32.3 - grpcio: 1.48.2 - h11: 0.14.0 - h5py: 3.7.0 - html5lib: 1.1 - idna: 3.4 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - installer: 0.7.0 - itsdangerous: 2.1.2 - jaraco.classes: 3.3.0 - jeepney: 0.8.0 - jinja2: 3.1.2 - joblib: 1.3.1 - jsonschema: 4.17.3 - keras: 2.12.0 - keras-pickle-wrapper: 1.0.5 - keras-preprocessing: 1.1.2 - keyring: 23.13.1 - kfp: 2.0.1 - kfp-pipeline-spec: 0.2.2 - kfp-server-api: 2.0.0 - kiwisolver: 1.4.4 - kubernetes: 26.1.0 - lightning: 2.0.6 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - lockfile: 0.12.2 - markdown: 3.4.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.1 - material: 0.1 - materialx: 0.0.0.dev1 - matplotlib: 3.7.2 - mdurl: 0.1.0 - mergedeep: 1.3.4 - minio: 7.1.15 - mkdocs: 1.5.2 - mkdocs-autorefs: 0.5.0 - mkdocs-material: 9.1.21 - mkdocs-material-extensions: 1.1.1 - mkdocstrings: 0.22.0 - 
mkdocstrings-python: 1.3.0 - mlframework: 2.2.1 - more-itertools: 10.0.0 - mpmath: 1.3.0 - msgpack: 1.0.3 - networkx: 3.1 - numpy: 1.22.3 - oauthlib: 3.2.2 - opt-einsum: 3.3.0 - ordered-set: 4.1.0 - packaging: 23.1 - pandas: 2.0.3 - pathspec: 0.11.2 - pexpect: 4.8.0 - pillow: 9.4.0 - pip: 23.2.1 - pipdeptree: 2.12.0 - pkginfo: 1.9.6 - pkgutil-resolve-name: 1.3.10 - platformdirs: 3.10.0 - pluggy: 1.2.0 - poetry: 1.5.1 - poetry-core: 1.6.1 - poetry-plugin-export: 1.4.0 - protobuf: 3.20.3 - psutil: 5.9.0 - psycopg2: 2.9.6 - ptyprocess: 0.7.0 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycparser: 2.21 - pydantic: 1.10.8 - pygments: 2.15.1 - pyjwt: 2.8.0 - pymdown-extensions: 10.1 - pympler: 1.0.1 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pyproject-hooks: 1.0.0 - pyrsistent: 0.18.0 - pysocks: 1.7.1 - pytest: 7.4.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.6 - pytz: 2023.3 - pyyaml: 6.0.1 - pyyaml-env-tag: 0.1 - rapidfuzz: 2.13.7 - readchar: 4.0.5.dev0 - regex: 2023.8.8 - requests: 2.31.0 - requests-oauthlib: 1.3.0 - requests-toolbelt: 0.9.1 - rich: 13.5.1 - rsa: 4.7.2 - scikit-learn: 1.3.0 - scipy: 1.11.1 - secretstorage: 3.3.3 - setuptools: 68.0.0 - shellingham: 1.5.1 - six: 1.16.0 - sniffio: 1.3.0 - soupsieve: 2.3.2.post1 - starlette: 0.27.0 - starsessions: 1.3.0 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboard: 2.12.1 - tensorboard-data-server: 0.7.0 - tensorboard-plugin-wit: 1.8.1 - tensorflow: 2.12.0 - tensorflow-estimator: 2.12.0 - termcolor: 2.1.0 - threadpoolctl: 3.2.0 - tomli: 2.0.1 - tomlkit: 0.12.1 - torch: 2.0.1+cpu - torch-model-archiver: 0.8.1 - torch-workflow-archiver: 0.2.9 - torchmetrics: 0.11.4 - torchserve: 0.8.1 - torchvision: 0.15.2a0 - tqdm: 4.65.0 - traitlets: 5.9.0 - trove-classifiers: 2023.7.6 - typeguard: 4.0.0 - typing-extensions: 4.7.1 - tzdata: 2023.3 - urllib3: 1.26.15 - uvicorn: 0.23.2 - virtualenv: 20.24.2 - watchdog: 3.0.0 - wcwidth: 0.2.6 - webencodings: 0.5.1 - 
websocket-client: 1.6.1 - websockets: 10.4 - werkzeug: 2.2.3 - wheel: 0.38.4 - wrapt: 1.14.1 - zipp: 3.16.2
- System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 5.4.0-1106-gcp - version: #115~18.04.1-Ubuntu SMP Mon May 22 20:46:39 UTC 2023
More info
Possible Fix
Set trainer.fit_loop._data_source.instance = None and trainer.fit_loop._combined_loader = None in the loop's teardown; these two attributes appear to be holding references to the passed dataloader. I can create a PR that applies this fix.
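The proposed teardown change can be sketched with a stdlib-only toy. ToyFitLoop and _ToyDataSource are hypothetical stand-ins, not Lightning classes; only the attribute names (_data_source.instance, _combined_loader) are taken from the issue, and the real loop API differs.

```python
class _ToyDataSource:
    """Toy stand-in for the loop's data source object."""
    def __init__(self):
        self.instance = None  # mirrors _data_source.instance from the issue


class ToyFitLoop:
    """Toy stand-in for Lightning's fit loop; attribute names mirror the issue."""
    def __init__(self):
        self._data_source = _ToyDataSource()
        self._combined_loader = None

    def teardown(self):
        # Proposed fix: drop the dataloader references here so a pickled
        # trainer no longer carries the dataset after fitting.
        self._data_source.instance = None
        self._combined_loader = None


loop = ToyFitLoop()
loop._data_source.instance = list(range(1000))  # pretend this is the dataloader
loop._combined_loader = object()
loop.teardown()
print(loop._data_source.instance is None and loop._combined_loader is None)  # True
```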
Workaround
- Set the above variables to None manually just before pickling the trainer.
- Since this issue doesn't happen if the dataloader is passed by overriding train_dataloader, don't pass the dataloader to fit.
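The first workaround can be illustrated with a stdlib-only toy that reproduces the pickle-size effect. ToyTrainer and ToyFitLoop are hypothetical stand-ins for pl.Trainer and its fit loop; in the real trainer the reference to clear is trainer.fit_loop._data_source.instance (and _combined_loader), as noted above.

```python
import pickle


class ToyFitLoop:
    def __init__(self):
        self.dataloader = None  # stands in for _data_source.instance


class ToyTrainer:
    def __init__(self):
        self.fit_loop = ToyFitLoop()

    def fit(self, dataloader):
        # The loop keeps a reference to the dataloader after fitting,
        # which is what makes the pickled trainer large.
        self.fit_loop.dataloader = dataloader


trainer = ToyTrainer()
big_data = list(range(100_000))  # stands in for the ~1 MB dataset
trainer.fit(big_data)
size_with_ref = len(pickle.dumps(trainer))

# Workaround: drop the reference manually just before pickling.
trainer.fit_loop.dataloader = None
size_without_ref = len(pickle.dumps(trainer))

print(size_without_ref < size_with_ref)  # True: the pickle shrinks
```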
cc @justusschock @awaelchli @borda
Thanks @gov-ind for investigating. I think it makes sense to do this. You are welcome to send a PR, thanks for the help!
I would strongly suggest that you don't pickle the Trainer object. This is a bad idea as you are pickling the precise code and imports, which might break with future changes.
Since dereferencing the dataloader after training finishes is a breaking change (the tests in your PR showed this), my suggestion is that we leave the behaviour as-is, given that you can still dereference manually. This is also your suggestion in point (1): https://github.com/Lightning-AI/lightning/pull/18293/files#r1307301711
However, you could add a test that demonstrates your use case so that it's considered in future changes.