
Add support to `saving.py` for loading GPU-trained models on CPU-only machines

Open • amorehead opened this pull request 2 years ago • 13 comments

What does this PR do?

  • Adds support to `saving.py` for loading GPU-trained models on CPU-only machines.
  • Without this fix, a `.to()` call during CPU-only inference may raise `AssertionError: Torch not compiled with CUDA enabled`.
Before submitting
  • [ ] Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [x] Did you make sure to update the documentation with your changes? (if necessary)
  • [x] Did you write any new necessary tests? (not for typos and docs)
  • [ ] Did you verify new and existing tests pass locally with your changes?
  • [x] Did you list all the breaking changes introduced by this pull request?
  • [x] Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR. Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:

Reviewer checklist
  • [x] Is this pull request ready for review? (if not, please submit in draft mode)
  • [x] Check that all items from Before submitting are resolved
  • [x] Make sure the title is self-explanatory and the description concisely explains the PR
  • [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--19024.org.readthedocs.build/en/19024/

amorehead, Nov 17 '23 16:11

@amorehead Have you tried reporting this to PyTorch? You would expect `cpu_thing.to(cpu)` to always be a no-op.

carmocca, Nov 17 '23 16:11

@carmocca, great point. I'll open an issue for PyTorch as well, linked to this one for Lightning. However, since it may take a while for PyTorch to fix the issue on their end, I think this Lightning PR is still useful in the meantime, in case other users run into the same issue I am facing.

amorehead, Nov 17 '23 19:11

Yes, we can merge this, but I would like to hear from their team before moving forward. Then we could have this:

if not _TORCH_GREATER_EQUAL_2_2:
    # your patch
    ...

carmocca, Nov 17 '23 20:11

@amorehead Great find. If you still have it, could you provide the full stack trace of the error?

awaelchli, Nov 18 '23 14:11

@awaelchli, yes, the stack trace is as follows.

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1552, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/pytorch/core/saving.py", line 97, in _load_from_checkpoint
    return model.to(device)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/fabric/utilities/device_dtype_mixin.py", line 54, in to
    return super().to(*args, **kwargs)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torchmetrics/metric.py", line 808, in _apply
    self._device = fn(torch.zeros(1, device=self.device)).device
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/cuda/__init__.py", line 289, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled

This is triggered by calling:

my_lightning_model.__class__.load_from_checkpoint(
    checkpoint_path=ckpt_path,
    map_location="cpu",
    strict=True,
)

The issue happens with both Lightning 2.1.0 and 2.1.2 (note the `__class__` access used for 2.1.2). When I install my patched version of Lightning (as packaged in this PR), the issue goes away because these `.to()` calls are skipped altogether.

amorehead, Nov 18 '23 19:11

Given the stack trace, we see that it goes through torchmetrics and fails at this line:

    self._device = fn(torch.zeros(1, device=self.device)).device

Maybe `self.device` is (for some reason) `cuda` and not `cpu`? In any case, it would be good to identify whether it's an issue in PyTorch or in torchmetrics. I couldn't reproduce it on macOS. @amorehead, any chance you could help here by sanity-checking that `self.device` is CPU in the line above?

awaelchli, Nov 18 '23 20:11

This seems to be a torchmetrics bug; see the discussion on the PyTorch issue tracker (https://github.com/pytorch/pytorch/issues/113973).

tringwald, Nov 19 '23 17:11

@amorehead Did you actually end up with the entire torchmetrics object pickled in the checkpoint, as described by this user (https://github.com/Lightning-AI/torchmetrics/issues/2223), or was it a proper checkpoint containing the metric's state dict? The former would indeed explain your issue, but then the fix should be to avoid pickling the metric in the first place.

awaelchli, Nov 24 '23 01:11

@awaelchli,

You have described it perfectly. The checkpoints I am trying to load on a CPU-only machine unintentionally contain full TorchMetrics objects, which is clearly not best practice. Are you aware of any workarounds, given that the metrics are already fully saved in my checkpoint files, or is the only solution to save just the `state_dict`s in the first place?

amorehead, Nov 29 '23 04:11
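Why a pickled metric would produce the traceback above can be illustrated without torch. The sketch uses a hypothetical `FakeMetric` stand-in (not the torchmetrics API) to show that pickle restores an object's attributes exactly as they were saved, so a device recorded on the GPU machine comes back unchanged on a CPU-only one. `map_location` in `torch.load` is understood to remap tensor storages only, not plain attributes such as a cached device, which is consistent with the failure inside torchmetrics' `_apply`. The second half contrasts this with saving only the state and rebuilding the object on the loading machine.

```python
import pickle


class FakeMetric:
    """Hypothetical stand-in for a metric that caches the device it lives on."""

    def __init__(self, device: str = "cpu") -> None:
        self._device = device  # plain attribute, not a tensor
        self.total = 0.0       # the metric's accumulated state

    def state_dict(self) -> dict:
        return {"total": self.total}

    def load_state_dict(self, state: dict) -> None:
        self.total = state["total"]


# On the GPU machine: a metric that lives on CUDA and has been updated.
gpu_metric = FakeMetric(device="cuda:0")
gpu_metric.total = 3.0

# Option 1: pickle the whole object; the stale device survives unpickling.
whole = pickle.loads(pickle.dumps(gpu_metric))
print(whole._device)  # cuda:0

# Option 2: save only the state and rebuild the object where it is loaded.
fresh = FakeMetric(device="cpu")
fresh.load_state_dict(pickle.loads(pickle.dumps(gpu_metric.state_dict())))
print(fresh._device, fresh.total)  # cpu 3.0
```

Option 2 corresponds to saving only the state dicts. If the metric was captured by `save_hyperparameters`, that method's `ignore` argument (for example `self.save_hyperparameters(ignore=["metric"])`, where `metric` is a hypothetical `__init__` argument name) keeps the object out of the checkpoint.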

Hi, I was having the same issue, and this commit fixed it for me! I would be very happy if this gets merged.

sfalkena, Dec 07 '23 19:12

Codecov Report

Attention: Patch coverage is 83.33333% with 1 line in your changes missing coverage. Please review.

Project coverage is 47%. Comparing base (bb14a97) to head (c856888). Report is 339 commits behind head on master.

❗ There is a different number of reports uploaded between BASE (bb14a97) and HEAD (c856888).

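HEAD has 205 fewer uploads than BASE.

| Flag | BASE (bb14a97) | HEAD (c856888) |
|---|---|---|
| lightning | 44 | 15 |
| cpu | 74 | 24 |
| pytest | 56 | 0 |
| python3.10 | 21 | 9 |
| app | 9 | 0 |
| examples | 9 | 0 |
| gpu | 4 | 0 |
| lightning_fabric | 10 | 0 |
| python3.9 | 6 | 3 |
| python3.11 | 15 | 6 |
| python3.8 | 12 | 6 |
| tpu | 2 | 0 |
| pytorch_lightning | 10 | 9 |
| lightning_app | 5 | 0 |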
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19024      +/-   ##
==========================================
- Coverage      83%      47%     -36%     
==========================================
  Files         445      437       -8     
  Lines       37289    37140     -149     
==========================================
- Hits        31119    17586   -13533     
- Misses       6170    19554   +13384     

codecov[bot], Dec 07 '23 22:12

⚠️ GitGuardian has uncovered 2 secrets following the scan of your pull request.

Please consider investigating the findings and remediating the incidents. Failure to do so may lead to compromising the associated services or software components.

🔎 Detected hardcoded secrets in your pull request
| GitGuardian id | Secret | Commit | Filename |
|---|---|---|---|
| - | Generic High Entropy Secret | 78fa3afdfbf964c19b4b2d36b91560698aa83178 | tests/tests_app/utilities/test_login.py |
| - | Base64 Basic Authentication | 78fa3afdfbf964c19b4b2d36b91560698aa83178 | tests/tests_app/utilities/test_login.py |
🛠 Guidelines to remediate hardcoded secrets
  1. Understand the implications of revoking this secret by investigating where it is used in your code.
  2. Replace and store your secret safely, following secret-management best practices.
  3. Revoke and rotate this secret.
  4. If possible, rewrite git history. Rewriting git history is not a trivial act. You might completely break other contributing developers' workflow and you risk accidentally deleting legitimate data.

To avoid such incidents in the future consider


🦉 GitGuardian detects secrets in your source code to help developers and security teams secure the modern development process. You are seeing this because you or someone else with access to this repository has authorized GitGuardian to scan your pull request.


gitguardian[bot], Jan 16 '24 09:01