Remove the `optimizer_to_device` logic if possible
Outline & Motivation
The trainer uses a function optimizer_to_device here:
https://github.com/Lightning-AI/pytorch-lightning/blob/631911c00413ad028e2887d83eb264cb4822097e/src/lightning/pytorch/strategies/strategy.py#L160-L161
In #19955, it was reported that the function moved the "step" entry in the optimizer state to the CUDA device, causing device-to-host syncs during `optimizer.step()` because the "step" tensor was expected to remain on the CPU. #20019 fixed this by special-casing that key. However, good arguments were made in #19955 that `optimizer_to_device` shouldn't be necessary in the first place (https://github.com/Lightning-AI/pytorch-lightning/issues/19955#issuecomment-2197353178).
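For readers unfamiliar with the problem, here is a minimal sketch of what "special treatment of that key" can look like. It is not Lightning's actual implementation: `FakeTensor` and `move_optimizer_state` are hypothetical names, and the stand-in tensor class exists only so the snippet runs without PyTorch or a GPU.

```python
from dataclasses import dataclass
from typing import Any, Dict, Hashable, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor; only tracks which device it lives on."""

    value: float
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        # Mimics the tensor `.to(device)` convention by returning a moved copy.
        return FakeTensor(self.value, device)


def move_optimizer_state(
    state: Dict[Hashable, Dict[str, Any]],
    device: str,
    skip_keys: Tuple[str, ...] = ("step",),
) -> Dict[Hashable, Dict[str, Any]]:
    """Return a copy of per-parameter optimizer state moved to `device`.

    Entries whose key is in `skip_keys` are left where they are, so a
    CPU-resident "step" counter is not pushed to the accelerator.
    """
    moved = {}
    for param, param_state in state.items():
        moved[param] = {
            key: value.to(device)
            if key not in skip_keys and hasattr(value, "to")
            else value
            for key, value in param_state.items()
        }
    return moved


# Assumed layout: one parameter "w" with Adam-like state entries.
state = {"w": {"step": FakeTensor(3.0), "exp_avg": FakeTensor(0.1)}}
moved = move_optimizer_state(state, "cuda:0")
print(moved["w"]["step"].device, moved["w"]["exp_avg"].device)
```

The point of the sketch is only that every such helper has to know which state entries must stay put, which is part of why removing the function altogether is attractive.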
Pitch
Remove `optimizer_to_device` and show that it is redundant by running the tests. We will still need an `optimizer_to_cpu` for teardown.
cc @justusschock @awaelchli @borda
Thanks @awaelchli. My analysis of the code agrees that we can remove this function (see https://github.com/Lightning-AI/pytorch-lightning/issues/19955#issuecomment-2232309700). The remaining points of note, I think, are:
- This function is also called by `Strategy.teardown()`. The idea is to transfer the optimizer back from the GPU to the CPU. Maybe we don't really care about this, since once the fit is complete the optimizer will be either checkpointed or discarded. The question is whether users really depend on this final transfer behavior, but I would guess not. This factors into whether we feel we need `optimizer_to_cpu` for teardown, as mentioned above.
- I created additional tests in https://github.com/Lightning-AI/pytorch-lightning/pull/20062 that I think help verify that the function is not needed. Some of them may be useful in a PR that removes this function.
Also tagging @janeyx99 for input.
- I am likely not the right person to respond to this question, but my instinct agrees with you: I don't think moving the optimizer back to the CPU needs to be part of the teardown. I'd defer to others on this though.
- I reviewed the part of the tests that move the state_dict back and forth. Left my comments on the PR!