Model does not update its weights
Bug description
Hi, I am using PyTorch Lightning to implement some new optimization strategies with automatic_optimization=False. For certain settings my strategy (with automatic_optimization=False) should yield exactly the same results as the standard optimization process (automatic_optimization=True). However, I could not make it work: my optimization process kept returning slightly different results than the default one. After a while I figured out that PyTorch Lightning sometimes does not update the model weights even when using the default automatic_optimization=True. I have put together a minimal example in which the model weights are not updated on step 5. The weights also fail to update with different hyper-parameters (e.g., batch size, lr), just at a different training step.
Am I missing something, or does this look like a bug? Thanks!
What version are you seeing the problem on?
v2.4
How to reproduce the bug
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(1, 64, 3, 1),
                nn.Conv2d(64, 64, 3, 1),
                nn.Conv2d(64, 128, 3, 1),
            ]
        )
        self.fc1 = nn.Linear(128, 10)

    def forward(self, x, target):
        for conv in self.convs:
            x = conv(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        logits = F.log_softmax(x, dim=1)
        return F.nll_loss(logits, target)


class MRELoop(pl.LightningModule):
    def __init__(self):
        super(MRELoop, self).__init__()
        self.model = CNN()
        self.dataset = datasets.MNIST(root=".mnist_data", download=True, transform=transforms.ToTensor())
        self.previous_params = None

    def training_step(self, batch, batch_idx):
        # Check whether the new model weights differ from the previous ones
        params = torch.cat([param.view(-1) for param in self.model.parameters()])
        if self.previous_params is not None:
            num_different_values = (self.previous_params != params).sum().item()
            self.trainer.should_stop = num_different_values == 0
        else:
            num_different_values = None
        self.previous_params = params
        loss = self.model.forward(*batch)
        print(
            f"step {batch_idx} | diff weights: {num_different_values} | all weights: {params.numel()} | weights mean: {torch.mean(params)} | loss: {loss.item()}"
        )
        return loss

    def configure_optimizers(self):
        # The bug also occurs with a different lr, only at a different training step
        return torch.optim.AdamW(self.parameters(), lr=2e-3)
        # return torch.optim.SGD(self.parameters(), lr=9e-4)  # Also happens with SGD

    def train_dataloader(self):
        return DataLoader(self.dataset)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    pl.seed_everything(1337)
    pl_trainer = pl.Trainer(
        precision="16-mixed",  # So far the bug has occurred only with 16-mixed
        deterministic=True,
        enable_progress_bar=False,
    )
    pl_trainer.fit(MRELoop())
Error messages and logs
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python mvp.py ...
Using 16bit Automatic Mixed Precision (AMP)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/amp.py:52: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
---------------------------------------
0 | model | CNN | 112 K | train
---------------------------------------
112 K Trainable params
0 Non-trainable params
112 K Total params
0.451 Total estimated model params size (MB)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
step 0 | diff weights: None | all weights: 112714 | weights mean: 1.6999114450300112e-05 | loss: 2.334902763366699
step 1 | diff weights: 112714 | all weights: 112714 | weights mean: 3.690078665385954e-05 | loss: 2.32588529586792
step 2 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00010425636719446629 | loss: 2.621901512145996
step 3 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00030326732667163014 | loss: 2.4029626846313477
step 4 | diff weights: 112714 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.657553195953369
step 5 | diff weights: 0 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.5822641849517822
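A guess at what might be happening here (unverified): with precision="16-mixed", torch's GradScaler silently skips optimizer.step() for any batch whose gradients contain inf/NaN values and then lowers its loss scale, which would leave the weights unchanged exactly as at step 5 above. A rough way to check for a skipped step from inside training_step is sketched below; the precision_plugin/scaler attribute path is an assumption about Lightning internals and may differ between versions:

def training_step(self, batch, batch_idx):
    # Grab the AMP GradScaler from Lightning internals (assumed attribute path)
    scaler = getattr(self.trainer.strategy.precision_plugin, "scaler", None)
    if scaler is not None:
        scale = scaler.get_scale()
        # scaler.update() lowers the scale right after a skipped step, so a
        # drop between consecutive steps means the previous optimizer.step()
        # was skipped because of inf/NaN gradients
        if getattr(self, "_last_scale", None) is not None and scale < self._last_scale:
            print(f"step {batch_idx}: scale dropped from {self._last_scale} to {scale}"
                  " -> previous optimizer step was likely skipped")
        self._last_scale = scale
    return self.model(*batch)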
Environment
* CUDA:
- GPU:
- NVIDIA A100-PCIE-40GB
- available: True
- version: 12.1
* Lightning:
- lightning-utilities: 0.11.6
- pytorch-lightning: 2.3.3
- torch: 2.4.0
- torchmetrics: 1.4.1
- torchvision: 0.19.0
* Packages:
- absl-py: 2.1.0
- aiohappyeyeballs: 2.3.4
- aiohttp: 3.10.1
- aiosignal: 1.3.1
- asttokens: 2.4.1
- attrs: 24.1.0
- autocommand: 2.2.2
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.12.3
- black: 24.8.0
- certifi: 2024.7.4
- charset-normalizer: 3.3.2
- click: 8.1.7
- comm: 0.2.2
- datasets: 2.20.0
- debugpy: 1.8.5
- decorator: 5.1.1
- dill: 0.3.8
- exceptiongroup: 1.2.2
- executing: 2.0.1
- filelock: 3.15.4
- frozenlist: 1.4.1
- fsspec: 2024.5.0
- gdown: 5.2.0
- grpcio: 1.65.4
- huggingface-hub: 0.24.5
- idna: 3.7
- importlib-metadata: 8.2.0
- importlib-resources: 6.4.0
- inflect: 7.3.1
- ipykernel: 6.29.5
- ipython: 8.26.0
- isort: 5.13.2
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.1
- jinja2: 3.1.4
- jupyter-client: 8.6.2
- jupyter-core: 5.7.2
- lightning-utilities: 0.11.6
- markdown: 3.6
- markupsafe: 2.1.5
- matplotlib-inline: 0.1.7
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.0.5
- multiprocess: 0.70.16
- mypy-extensions: 1.0.0
- nest-asyncio: 1.6.0
- networkx: 3.3
- numpy: 2.0.1
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.6.20
- nvidia-nvtx-cu12: 12.1.105
- ordered-set: 4.1.0
- packaging: 24.1
- pandas: 2.2.2
- parso: 0.8.4
- pathspec: 0.12.1
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 10.4.0
- pip: 24.2
- platformdirs: 4.2.2
- prompt-toolkit: 3.0.47
- protobuf: 4.25.4
- psutil: 6.0.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- pyarrow: 17.0.0
- pyarrow-hotfix: 0.6
- pygments: 2.18.0
- pynvml: 11.5.3
- pysocks: 1.7.1
- python-dateutil: 2.9.0
- pytorch-lightning: 2.3.3
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.1.0
- regex: 2024.7.24
- requests: 2.32.3
- safetensors: 0.4.4
- setuptools: 72.1.0
- six: 1.16.0
- soupsieve: 2.5
- stack-data: 0.6.2
- sympy: 1.13.1
- tensorboard: 2.17.0
- tensorboard-data-server: 0.7.2
- tiktoken: 0.7.0
- tokenizers: 0.19.1
- tomli: 2.0.1
- torch: 2.4.0
- torchmetrics: 1.4.1
- torchvision: 0.19.0
- tornado: 6.4.1
- tqdm: 4.66.5
- traitlets: 5.14.3
- transformers: 4.44.0
- triton: 3.0.0
- typeguard: 4.3.0
- typing-extensions: 4.12.2
- tzdata: 2024.1
- urllib3: 2.2.2
- wcwidth: 0.2.13
- werkzeug: 3.0.3
- wheel: 0.44.0
- xxhash: 3.4.1
- yarl: 1.9.4
- zipp: 3.19.2
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.12.4
- release: 3.10.0-1160.71.1.el7.x86_64
- version: #1 SMP Tue Jun 28 15:37:28 UTC 2022
More info
No response