Add support to `saving.py` for loading GPU-trained models on CPU-only machines
What does this PR do?
- Adds support to `saving.py` for loading GPU-trained models on CPU-only machines.
- Without this fix, a `.to()` call in the context of CPU-only inference may lead to `AssertionError: Torch not compiled with CUDA enabled`.
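The core idea can be sketched as a small predicate that decides whether the trailing `model.to(...)` call is safe to skip. This is a hypothetical helper for illustration only (the names `should_skip_to`, `map_location`, and `cuda_available` are mine, not the actual API in `saving.py`; in practice `cuda_available` would come from `torch.cuda.is_available()`):

```python
def should_skip_to(map_location: str, cuda_available: bool) -> bool:
    """Decide whether a trailing model.to(...) call can be skipped.

    Hypothetical sketch of this PR's idea: when the process has no usable
    CUDA and torch.load already remapped every tensor to CPU via
    map_location, the .to() call is redundant and can only cause harm
    (e.g. a pickled metric probing a stale CUDA device).
    """
    return not cuda_available and map_location == "cpu"
```

On a CPU-only machine the checkpoint tensors are already where they need to be after `torch.load(..., map_location="cpu")`, so skipping the extra move changes nothing for correct checkpoints while avoiding the lazy CUDA init.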
Before submitting
- [ ] Was this discussed/agreed via a GitHub issue? (not for typos and docs)
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
- [x] Did you make sure to update the documentation with your changes? (if necessary)
- [x] Did you write any new necessary tests? (not for typos and docs)
- [ ] Did you verify new and existing tests pass locally with your changes?
- [x] Did you list all the breaking changes introduced by this pull request?
- [x] Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)
PR review
Anyone in the community is welcome to review the PR. Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
- [x] Is this pull request ready for review? (if not, please submit in draft mode)
- [x] Check that all items from Before submitting are resolved
- [x] Make sure the title is self-explanatory and the description concisely explains the PR
- [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified
:books: Documentation preview :books:: https://pytorch-lightning--19024.org.readthedocs.build/en/19024/
@amorehead Have you tried reporting this on PyTorch? You would expect that `cpu_thing.to(cpu)` is always a no-op
@carmocca, great point. I'll open an issue for PyTorch as well and link it to this one for Lightning. However, since it may take a while for PyTorch to fix the issue on their end, I think this PR for Lightning should still be useful in the meantime, in case other users run into the same issue I am facing.
Yes, we can merge this, but I would like to hear from their team first before moving forward. Then we could have this:

```python
if not _TORCH_GREATER_EQUAL_2_2:
    # your patch
```
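The `_TORCH_GREATER_EQUAL_2_2` flag boils down to a version comparison against the installed torch. A rough, self-contained stand-in (the real flag in Lightning is computed with proper PEP 440 version parsing, so treat this as a sketch, not the actual implementation):

```python
def torch_greater_equal(installed: str, minimum: str = "2.2.0") -> bool:
    """Rough stand-in for a _TORCH_GREATER_EQUAL_2_2-style flag.

    Strips local build suffixes like '+cu121', pads to three numeric
    components, and compares tuples. The real implementation should use
    packaging.version for full PEP 440 support.
    """
    def parts(v: str) -> tuple:
        nums = [int(p) for p in v.split("+")[0].split(".")[:3]]
        return tuple(nums + [0] * (3 - len(nums)))

    return parts(installed) >= parts(minimum)
```

With such a gate, the patch would only apply on torch versions that still exhibit the upstream bug, and become a no-op once PyTorch/torchmetrics ship a fix.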
@amorehead Great find. If you still have it, could you provide the full stack trace of the error?
@awaelchli, yes, the stack trace is as follows.

```
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1552, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/pytorch/core/saving.py", line 97, in _load_from_checkpoint
    return model.to(device)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/fabric/utilities/device_dtype_mixin.py", line 54, in to
    return super().to(*args, **kwargs)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torchmetrics/metric.py", line 808, in _apply
    self._device = fn(torch.zeros(1, device=self.device)).device
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/cuda/__init__.py", line 289, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
```
This is triggered by calling:

```python
my_lightning_model.__class__.load_from_checkpoint(
    checkpoint_path=ckpt_path,
    map_location="cpu",
    strict=True,
)
```
The issue happens with both Lightning 2.1.0 and 2.1.2 (note the `__class__` bit for 2.1.2). When I install my patched version of Lightning (as packaged in this PR), the issue goes away because these `.to()` calls are skipped altogether.
Given the stack trace, we see that it goes through torchmetrics and fails at this line:

```python
self._device = fn(torch.zeros(1, device=self.device)).device
```
Maybe `self.device` (for some reason) is cuda and not cpu? In any case, it would be good to identify whether this is an issue in PyTorch or in torchmetrics. I couldn't repro it on my macOS machine. @amorehead any chance you could help sanity-check that `self.device` is cpu in the line above?
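A toy stand-in makes the suspected failure mode concrete without needing torch or torchmetrics at all: if a metric object was fully pickled while living on a GPU, its `_device` attribute still says `cuda:...` after loading, and any `_apply` that probes that device trips lazy CUDA init on a CPU-only build. This class is purely a simulation of that behavior, not torchmetrics code:

```python
class PickledMetricSim:
    """Toy stand-in for a metric object restored from a full pickle.

    Simulation only: its _device still records the GPU it lived on
    when the checkpoint was written.
    """

    def __init__(self, device: str) -> None:
        self._device = device

    def _apply(self, cuda_available: bool) -> str:
        # torchmetrics' real _apply allocates a probe tensor on self._device;
        # on a CPU-only build, touching a cuda device triggers lazy CUDA init
        # and the AssertionError seen in the traceback above.
        if self._device.startswith("cuda") and not cuda_available:
            raise AssertionError("Torch not compiled with CUDA enabled")
        return self._device
```

Note that `map_location="cpu"` in `torch.load` remaps tensors, but it does not rewrite a plain Python attribute like `_device` on a pickled object, which is why the error survives the remapping.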
This seems to be a torchmetrics bug, see discussion on the PyTorch issue tracker (https://github.com/pytorch/pytorch/issues/113973).
@amorehead Did you actually end up with the entire torchmetric object pickled in the checkpoint like described by this user https://github.com/Lightning-AI/torchmetrics/issues/2223 or was it a proper checkpoint with the state dict of the metric? Because the former would indeed explain your issue, but then the fix should be not to pickle the metric in the first place.
@awaelchli,
You have described it perfectly. The checkpoints I am trying to load on a CPU-only machine unintentionally contain full TorchMetrics objects. This is clearly not best practice. Are you aware of any workaround, given that the metrics are fully pickled into my checkpoint files, or is the only solution to save just the state_dicts in the first place?
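One possible workaround for already-written checkpoints is to scrub the pickled metric objects out of the loaded checkpoint dict before handing it to `load_from_checkpoint`. The sketch below is hypothetical and assumes the metrics ended up under `hyper_parameters` via `save_hyperparameters()`, as described in the linked torchmetrics issue; the function name and the module-name filter are mine:

```python
def scrub_metric_objects(checkpoint: dict) -> dict:
    """Drop fully pickled torchmetrics objects from a checkpoint dict.

    Hypothetical workaround sketch: assumes the metric objects were captured
    via save_hyperparameters() and therefore live under 'hyper_parameters'.
    Filters by the object's top-level module name; returns a new dict and
    leaves the input untouched.
    """
    hparams = checkpoint.get("hyper_parameters", {})
    cleaned = {
        key: value
        for key, value in hparams.items()
        if type(value).__module__.split(".")[0] != "torchmetrics"
    }
    scrubbed = dict(checkpoint)
    scrubbed["hyper_parameters"] = cleaned
    return scrubbed
```

After scrubbing and re-saving, loading with `map_location="cpu"` would no longer walk into a pickled metric whose `_device` still points at a GPU, though the proper fix remains not pickling the metric in the first place (e.g. passing `ignore=` to `save_hyperparameters`).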
Hi, I was having the same issue, and this commit fixed it for me! I would be very happy if this gets merged.
Codecov Report
Attention: Patch coverage is 83.33333% with 1 line in your changes missing coverage. Please review.
Project coverage is 47%. Comparing base (bb14a97) to head (c856888). Report is 339 commits behind head on master.
:exclamation: There is a different number of reports uploaded between BASE (bb14a97) and HEAD (c856888). Click for more details.
HEAD has 205 uploads less than BASE
| Flag | BASE (bb14a97) | HEAD (c856888) |
|---|---|---|
| lightning | 44 | 15 |
| cpu | 74 | 24 |
| pytest | 56 | 0 |
| python3.10 | 21 | 9 |
| app | 9 | 0 |
| examples | 9 | 0 |
| gpu | 4 | 0 |
| lightning_fabric | 10 | 0 |
| python3.9 | 6 | 3 |
| python3.11 | 15 | 6 |
| python3.8 | 12 | 6 |
| tpu | 2 | 0 |
| pytorch_lightning | 10 | 9 |
| lightning_app | 5 | 0 |
Additional details and impacted files
```
@@ Coverage Diff @@
##           master   #19024      +/-   ##
==========================================
- Coverage      83%      47%     -36%
==========================================
  Files         445      437       -8
  Lines       37289    37140     -149
==========================================
- Hits        31119    17586   -13533
- Misses       6170    19554   +13384
```
⚠️ GitGuardian has uncovered 2 secrets following the scan of your pull request.
Please consider investigating the findings and remediating the incidents. Failure to do so may lead to compromising the associated services or software components.
🔎 Detected hardcoded secrets in your pull request
| GitGuardian id | Secret | Commit | Filename | |
|---|---|---|---|---|
| - | Generic High Entropy Secret | 78fa3afdfbf964c19b4b2d36b91560698aa83178 | tests/tests_app/utilities/test_login.py | View secret |
| - | Base64 Basic Authentication | 78fa3afdfbf964c19b4b2d36b91560698aa83178 | tests/tests_app/utilities/test_login.py | View secret |
🛠 Guidelines to remediate hardcoded secrets
- Understand the implications of revoking this secret by investigating where it is used in your code.
- Replace and store your secret safely. Learn here the best practices.
- Revoke and rotate this secret.
- If possible, rewrite git history. Rewriting git history is not a trivial act. You might completely break other contributing developers' workflow and you risk accidentally deleting legitimate data.
To avoid such incidents in the future consider
- following these best practices for managing and storing secrets including API keys and other credentials
- installing secret detection on pre-commit to catch secrets before they leave your machine and ease remediation.
🦉 GitGuardian detects secrets in your source code to help developers and security teams secure the modern development process. You are seeing this because you or someone else with access to this repository has authorized GitGuardian to scan your pull request.
Our GitHub checks need improvements? Share your feedback!