`Trainer`'s `.init_module()` context does not initialize model on target device

Open jin-zhe opened this issue 4 months ago

Bug description

I refer to the documentation, which states that "you can force PyTorch to create the model directly on the target device" when using the `.init_module()` context. However, I have verified on several different GPU machines that this is not the case. The minimal script below prints the model's device right after initialization under the context. It always prints 'cpu'.

How to reproduce the bug

from torch import nn, optim
from pytorch_lightning import Trainer, LightningModule

class LitAutoEncoder(LightningModule):
  # Model taken from
  # Details unimportant
  def __init__(self):
    super().__init__()  # must run before submodules are assigned
    self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

  def training_step(self, batch, batch_idx):
    x, _ = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    return loss

  def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=1e-3)
    return optimizer

trainer = Trainer(accelerator='gpu', devices=[0])
with trainer.init_module():
  model = LitAutoEncoder()
  print(model.device) # => cpu
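
To narrow the report down, it may help to check where the parameters themselves live, rather than relying only on the `model.device` property. The sketch below uses a hypothetical helper, `parameter_devices` (not part of Lightning or PyTorch), and plain PyTorch's `torch.device` context manager with the `meta` device as a stand-in for the GPU so it runs on any machine. It does not establish what `Trainer.init_module()` does internally.

```python
import torch
from torch import nn


def parameter_devices(module: nn.Module) -> set:
    """Hypothetical helper: devices that hold the module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {str(t.device) for t in tensors}


# Stand-in for a target-device context: tensors created here land on "meta".
with torch.device("meta"):
    layer = nn.Linear(4, 2)

print(parameter_devices(layer))            # devices of tensors created inside the context
print(parameter_devices(nn.Linear(4, 2)))  # devices of tensors created outside it
```

Calling `print(parameter_devices(model))` inside the `with trainer.init_module():` block of the reproduction would show whether the weights are already on the GPU while `model.device` still reports `cpu`, or whether they are really created on the CPU.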

Current environment

* CUDA:
  - GPU:
    - Tesla V100-SXM2-32GB-LS
    - Tesla V100-SXM2-32GB
    - Tesla V100-SXM2-32GB-LS
    - Tesla V100-SXM2-32GB-LS
    - Tesla V100-SXM2-32GB-LS
    - Tesla V100-SXM2-32GB-LS
    - Tesla V100-SXM2-32GB-LS
    - Tesla V100-SXM2-32GB-LS
  - available: True
  - version: 11.8
* Lightning:
  - lightning: 2.4.0
  - lightning-utilities: 0.11.6
  - open-clip-torch: 2.26.1
  - pytorch-lightning: 2.4.0
  - torch: 2.1.0
  - torchaudio: 2.1.0
  - torchmetrics: 1.4.0.post0
  - torchvision: 0.16.0
* Packages:
  - aiohappyeyeballs: 2.3.5
  - aiohttp: 3.10.3
  - aiosignal: 1.3.1
  - altair: 5.4.1
  - antlr4-python3-runtime: 4.9.3
  - appdirs: 1.4.4
  - asttokens: 2.4.1
  - async-timeout: 4.0.3
  - attrs: 24.2.0
  - autocommand: 2.2.2
  - backports.tarfile: 1.2.0
  - blinker: 1.8.2
  - brotli: 1.1.0
  - cachetools: 5.5.0
  - certifi: 2024.8.30
  - cffi: 1.17.0
  - charset-normalizer: 3.3.2
  - click: 8.1.7
  - colorama: 0.4.6
  - comm: 0.2.2
  - datasets: 2.20.0
  - debugpy: 1.8.5
  - decorator: 5.1.1
  - dill: 0.3.8
  - docker-pycreds: 0.4.0
  - einops: 0.8.0
  - exceptiongroup: 1.2.2
  - executing: 2.1.0
  - filelock: 3.15.4
  - frozenlist: 1.4.1
  - fsspec: 2024.5.0
  - ftfy: 6.2.3
  - gitdb: 4.0.11
  - gitpython: 3.1.43
  - gmpy2: 2.1.5
  - h2: 4.1.0
  - hpack: 4.0.0
  - huggingface-hub: 0.24.5
  - hyperframe: 6.0.1
  - idna: 3.7
  - importlib-metadata: 7.2.1
  - importlib-resources: 6.4.5
  - inflect: 7.3.1
  - ipykernel: 6.29.5
  - ipython: 8.27.0
  - ipywidgets: 8.1.5
  - jaraco.context: 5.3.0
  - jaraco.functools: 4.0.1
  - jaraco.text: 3.12.1
  - jedi: 0.19.1
  - jinja2: 3.1.4
  - jsonlines: 4.0.0
  - jsonschema: 4.23.0
  - jsonschema-specifications: 2023.12.1
  - jupyter-client: 8.6.3
  - jupyter-core: 5.7.2
  - jupyterlab-widgets: 3.0.13
  - lightning: 2.4.0
  - lightning-utilities: 0.11.6
  - markdown-it-py: 3.0.0
  - markupsafe: 2.1.5
  - matplotlib-inline: 0.1.7
  - mdurl: 0.1.2
  - more-itertools: 10.3.0
  - mpmath: 1.3.0
  - multidict: 6.0.5
  - multiprocess: 0.70.16
  - narwhals: 1.8.2
  - nest-asyncio: 1.6.0
  - networkx: 3.3
  - numpy: 1.26.4
  - omegaconf: 2.3.0
  - open-clip-torch: 2.26.1
  - opencv-python: 4.10.0
  - opencv-python-headless: 4.10.0
  - ordered-set: 4.1.0
  - packaging: 24.1
  - pandas: 2.2.2
  - parso: 0.8.4
  - pathtools: 0.1.2
  - pexpect: 4.9.0
  - pickleshare: 0.7.5
  - pillow: 10.4.0
  - pip: 24.2
  - pkgutil-resolve-name: 1.3.10
  - platformdirs: 4.3.6
  - prompt-toolkit: 3.0.47
  - protobuf: 4.25.3
  - psutil: 6.0.0
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.3
  - pyarrow: 17.0.0
  - pyarrow-hotfix: 0.6
  - pycparser: 2.22
  - pydeck: 0.8.0b4
  - pygments: 2.18.0
  - pysocks: 1.7.1
  - python-dateutil: 2.9.0
  - pytorch-lightning: 2.4.0
  - pytz: 2024.1
  - pyyaml: 6.0.2
  - pyzmq: 26.2.0
  - referencing: 0.35.1
  - regex: 2024.7.24
  - requests: 2.32.3
  - rich: 13.8.1
  - rpds-py: 0.20.0
  - safetensors: 0.4.4
  - sentry-sdk: 2.12.0
  - setproctitle: 1.3.3
  - setuptools: 72.1.0
  - six: 1.16.0
  - smmap: 5.0.0
  - stack-data: 0.6.2
  - streamlit: 1.38.0
  - sympy: 1.13.2
  - tenacity: 8.5.0
  - timm: 1.0.8
  - tokenizers: 0.19.1
  - toml: 0.10.2
  - tomli: 2.0.1
  - torch: 2.1.0
  - torchaudio: 2.1.0
  - torchmetrics: 1.4.0.post0
  - torchvision: 0.16.0
  - tornado: 6.4.1
  - tqdm: 4.66.5
  - traitlets: 5.14.3
  - transformers: 4.44.2
  - triton: 2.1.0
  - typeguard: 4.3.0
  - typing-extensions: 4.12.2
  - tzdata: 2024.1
  - tzlocal: 5.2
  - urllib3: 2.2.2
  - validators: 0.34.0
  - wandb: 0.16.6
  - watchdog: 4.0.1
  - wcwidth: 0.2.13
  - wheel: 0.44.0
  - widgetsnbextension: 4.0.13
  - xformers: 0.0.22.post7
  - xxhash: 3.4.1
  - yarl: 1.9.4
  - zipp: 3.20.2
  - zstandard: 0.23.0
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.10.14
  - release: 4.15.0-55-generic
  - version: #60-Ubuntu SMP Tue Jul 2 18:22:20 UTC 2019

jin-zhe, Sep 27 '24 05:09