
Model does not update its weights

Open · kopalja opened this issue 6 months ago · 3 comments

Bug description

Hi, I am using PyTorch Lightning to implement some new optimization strategies with automatic_optimization=False. For certain settings, my strategy (using automatic_optimization=False) should yield the same results as the standard optimization process (automatic_optimization=True). However, I could not make them match: my optimization process kept returning slightly different results from the default one. After a while I figured out that PyTorch Lightning sometimes does not update the model weights even with the default automatic_optimization=True. I have put together a minimal example in which the model weights do not get updated on step 5. With different hyper-parameters (e.g., batch size, lr) the weights also fail to update, just at a different training step.

Am I missing something, or does this look like a bug? Thanks!
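
For reference, the manual loop I am comparing against follows the standard torch.amp mixed-precision pattern below (a minimal, illustrative sketch of plain torch.amp usage, not Lightning's actual internals):

import torch

# Minimal sketch of one standard 16-bit AMP training step (plain torch.amp,
# illustrative only).
scaler = torch.amp.GradScaler("cuda")

def amp_step(model, optimizer, batch):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(*batch)
    scaler.scale(loss).backward()
    # scaler.step() silently skips the underlying optimizer.step() when the
    # scaled gradients contain infs/NaNs; scaler.update() then lowers the
    # scale. A skipped step leaves the weights unchanged for that batch.
    scaler.step(optimizer)
    scaler.update()
    return loss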

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(1, 64, 3, 1),
                nn.Conv2d(64, 64, 3, 1),
                nn.Conv2d(64, 128, 3, 1),
            ]
        )
        self.fc1 = nn.Linear(128, 10)

    def forward(self, x, target):
        for conv in self.convs:
            x = conv(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        logits = F.log_softmax(x, dim=1)
        return F.nll_loss(logits, target)


class MRELoop(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = CNN()
        self.dataset = datasets.MNIST(root=".mnist_data", download=True, transform=transforms.ToTensor())
        self.previous_params = None

    def training_step(self, batch, batch_idx):
        # Check whether the new model weights differ from the previous ones
        params = torch.cat([param.view(-1) for param in self.model.parameters()])
        if self.previous_params is not None:
            num_different_values = (self.previous_params != params).sum().item()
            self.trainer.should_stop = num_different_values == 0
        else:
            num_different_values = None

        self.previous_params = params
        loss = self.model(*batch)
        print(
            f"step {batch_idx} | diff weights: {num_different_values} | all weights: {params.numel()} | weights mean: {torch.mean(params)} | loss: {loss.item()}"
        )
        return loss

    def configure_optimizers(self):
        # Bug also occurs with a different lr, just at a different training step
        return torch.optim.AdamW(self.parameters(), lr=2e-3)
        # return torch.optim.SGD(self.parameters(), lr=9e-4) # Also with SGD

    def train_dataloader(self):
        return DataLoader(self.dataset)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    pl.seed_everything(1337)
    pl_trainer = pl.Trainer(
        precision="16-mixed",  # So far bug has occured only with 16-mixed
        deterministic=True,
        enable_progress_bar=False,
    )
    pl_trainer.fit(MRELoop())

Error messages and logs

/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python mvp.py ...
Using 16bit Automatic Mixed Precision (AMP)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/amp.py:52: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | CNN  | 112 K  | train
---------------------------------------
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.451     Total estimated model params size (MB)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
step 0 | diff weights: None | all weights: 112714 | weights mean: 1.6999114450300112e-05 | loss: 2.334902763366699
step 1 | diff weights: 112714 | all weights: 112714 | weights mean: 3.690078665385954e-05 | loss: 2.32588529586792
step 2 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00010425636719446629 | loss: 2.621901512145996
step 3 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00030326732667163014 | loss: 2.4029626846313477
step 4 | diff weights: 112714 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.657553195953369
step 5 | diff weights: 0 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.5822641849517822
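
To check whether the AMP loss scaler is what skips the update at step 5, one can watch GradScaler.get_scale(): the scaler lowers its scale immediately after a step it skipped because of inf/NaN gradients. Below is a hedged sketch of two hooks that could be added to MRELoop (reading trainer.precision_plugin.scaler is an assumption about Lightning 2.x internals):

    def on_train_batch_start(self, batch, batch_idx):
        # Assumes the AMP GradScaler is exposed at trainer.precision_plugin.scaler.
        scaler = getattr(self.trainer.precision_plugin, "scaler", None)
        self._scale_before = scaler.get_scale() if scaler is not None else None

    def on_train_batch_end(self, outputs, batch, batch_idx):
        scaler = getattr(self.trainer.precision_plugin, "scaler", None)
        if scaler is not None and scaler.get_scale() < self._scale_before:
            # GradScaler reduces its scale after skipping a step due to
            # inf/NaN gradients, so a drop means optimizer.step() did not run.
            print(f"step {batch_idx}: optimizer step was skipped by GradScaler")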

Environment

Current environment
* CUDA:
	- GPU:
		- NVIDIA A100-PCIE-40GB
	- available:         True
	- version:           12.1
* Lightning:
	- lightning-utilities: 0.11.6
	- pytorch-lightning: 2.3.3
	- torch:             2.4.0
	- torchmetrics:      1.4.1
	- torchvision:       0.19.0
* Packages:
	- absl-py:           2.1.0
	- aiohappyeyeballs:  2.3.4
	- aiohttp:           3.10.1
	- aiosignal:         1.3.1
	- asttokens:         2.4.1
	- attrs:             24.1.0
	- autocommand:       2.2.2
	- backports.tarfile: 1.2.0
	- beautifulsoup4:    4.12.3
	- black:             24.8.0
	- certifi:           2024.7.4
	- charset-normalizer: 3.3.2
	- click:             8.1.7
	- comm:              0.2.2
	- datasets:          2.20.0
	- debugpy:           1.8.5
	- decorator:         5.1.1
	- dill:              0.3.8
	- exceptiongroup:    1.2.2
	- executing:         2.0.1
	- filelock:          3.15.4
	- frozenlist:        1.4.1
	- fsspec:            2024.5.0
	- gdown:             5.2.0
	- grpcio:            1.65.4
	- huggingface-hub:   0.24.5
	- idna:              3.7
	- importlib-metadata: 8.2.0
	- importlib-resources: 6.4.0
	- inflect:           7.3.1
	- ipykernel:         6.29.5
	- ipython:           8.26.0
	- isort:             5.13.2
	- jaraco.context:    5.3.0
	- jaraco.functools:  4.0.1
	- jaraco.text:       3.12.1
	- jedi:              0.19.1
	- jinja2:            3.1.4
	- jupyter-client:    8.6.2
	- jupyter-core:      5.7.2
	- lightning-utilities: 0.11.6
	- markdown:          3.6
	- markupsafe:        2.1.5
	- matplotlib-inline: 0.1.7
	- more-itertools:    10.3.0
	- mpmath:            1.3.0
	- multidict:         6.0.5
	- multiprocess:      0.70.16
	- mypy-extensions:   1.0.0
	- nest-asyncio:      1.6.0
	- networkx:          3.3
	- numpy:             2.0.1
	- nvidia-cublas-cu12: 12.1.3.1
	- nvidia-cuda-cupti-cu12: 12.1.105
	- nvidia-cuda-nvrtc-cu12: 12.1.105
	- nvidia-cuda-runtime-cu12: 12.1.105
	- nvidia-cudnn-cu12: 9.1.0.70
	- nvidia-cufft-cu12: 11.0.2.54
	- nvidia-curand-cu12: 10.3.2.106
	- nvidia-cusolver-cu12: 11.4.5.107
	- nvidia-cusparse-cu12: 12.1.0.106
	- nvidia-nccl-cu12:  2.20.5
	- nvidia-nvjitlink-cu12: 12.6.20
	- nvidia-nvtx-cu12:  12.1.105
	- ordered-set:       4.1.0
	- packaging:         24.1
	- pandas:            2.2.2
	- parso:             0.8.4
	- pathspec:          0.12.1
	- pexpect:           4.9.0
	- pickleshare:       0.7.5
	- pillow:            10.4.0
	- pip:               24.2
	- platformdirs:      4.2.2
	- prompt-toolkit:    3.0.47
	- protobuf:          4.25.4
	- psutil:            6.0.0
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.3
	- pyarrow:           17.0.0
	- pyarrow-hotfix:    0.6
	- pygments:          2.18.0
	- pynvml:            11.5.3
	- pysocks:           1.7.1
	- python-dateutil:   2.9.0
	- pytorch-lightning: 2.3.3
	- pytz:              2024.1
	- pyyaml:            6.0.1
	- pyzmq:             26.1.0
	- regex:             2024.7.24
	- requests:          2.32.3
	- safetensors:       0.4.4
	- setuptools:        72.1.0
	- six:               1.16.0
	- soupsieve:         2.5
	- stack-data:        0.6.2
	- sympy:             1.13.1
	- tensorboard:       2.17.0
	- tensorboard-data-server: 0.7.2
	- tiktoken:          0.7.0
	- tokenizers:        0.19.1
	- tomli:             2.0.1
	- torch:             2.4.0
	- torchmetrics:      1.4.1
	- torchvision:       0.19.0
	- tornado:           6.4.1
	- tqdm:              4.66.5
	- traitlets:         5.14.3
	- transformers:      4.44.0
	- triton:            3.0.0
	- typeguard:         4.3.0
	- typing-extensions: 4.12.2
	- tzdata:            2024.1
	- urllib3:           2.2.2
	- wcwidth:           0.2.13
	- werkzeug:          3.0.3
	- wheel:             0.44.0
	- xxhash:            3.4.1
	- yarl:              1.9.4
	- zipp:              3.19.2
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.12.4
	- release:           3.10.0-1160.71.1.el7.x86_64
	- version:           #1 SMP Tue Jun 28 15:37:28 UTC 2022

More info

No response

kopalja · Aug 19 '24 15:08