
The step at which validation happens drifts with `val_check_interval` when gradient accumulation is turned on

Open hrukalive opened this issue 2 years ago • 4 comments

Bug description

My task is driven by step count rather than epochs, so I run validation every N steps and save a checkpoint after each validation. However, once I turned on gradient accumulation with a number of batches per epoch that is not divisible by the accumulation factor, the step at which validation actually happens (and therefore the checkpointing) started to drift.

In the example below, I override `_save_checkpoint` to print the actual checkpoint file name, which shows the drift. My settings are `accumulate_grad_batches=3` and `val_check_interval = 5 * accumulate_grad_batches` (so that validation runs every 5 effective optimizer steps), with 67 batches per epoch, so one batch is left over at the end of each epoch.
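The leftover batch affects the step count. Assuming Lightning still takes an optimizer step on the final, partial accumulation group of an epoch (the logs below are consistent with this), each epoch produces 23 optimizer steps rather than 67 / 3 ≈ 22.33. A quick check of that arithmetic; `optimizer_steps_per_epoch` is a hypothetical helper, not a Lightning API:

```python
import math


def optimizer_steps_per_epoch(num_batches: int, accumulate: int) -> int:
    # Assumption: one optimizer step per full group of `accumulate` batches,
    # plus one extra step for a trailing partial group at the end of the epoch.
    return math.ceil(num_batches / accumulate)


num_batches = 2144 // 32  # 67 batches, as in the reproduction script
steps = optimizer_steps_per_epoch(num_batches, 3)
print(num_batches, steps)  # 67 23
```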

How to reproduce the bug

import numpy as np
import pathlib

import time
import torch
import torch.nn as nn
import torch.optim

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

class Quadratic(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.c = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        time.sleep(0.02)
        return self.a * x * x + self.b * x + self.c
    
    def _common_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        return loss 

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

class CustomModelCheckpoint(ModelCheckpoint):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # Print when (and at which global_step) a checkpoint would be saved instead of writing it to disk
        monitor_candidates = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self._monitor_candidates(trainer).items()}
        print("\n", "Save checkpoint, global_step: ", trainer.global_step, pathlib.Path(filepath).stem, "monitor_candidates: " + str(monitor_candidates), "\n", flush=True)

    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # Nothing is written to disk, so there is nothing to remove
        pass
    
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-a', type=float, default=2.0)
    parser.add_argument('-b', type=float, default=3.0)
    parser.add_argument('-c', type=float, default=4.0)
    parser.add_argument('--epoch', type=int, default=500)
    args = parser.parse_args()

    x = torch.from_numpy(np.random.uniform(-10, 10, 2144)).float()  # 2144 samples / batch_size 32 = 67 batches
    y = args.a * x * x + args.b * x + args.c
    x2 = torch.from_numpy(np.random.uniform(-10, 10, 100)).float()
    y2 = args.a * x2 * x2 + args.b * x2 + args.c

    dataset = torch.utils.data.TensorDataset(x, y)
    val_dataset = torch.utils.data.TensorDataset(x2, y2)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

    model = Quadratic()
    
    ####
    accumulate_grad_batches = 3
    val_check_interval = 5 * accumulate_grad_batches  # 15 batches, intended as 5 effective (optimizer) steps
    ####

    trainer = pl.Trainer(max_epochs=args.epoch, accelerator='cpu', callbacks=[CustomModelCheckpoint(
                    dirpath='.',
                    filename='steps_{step}',
                    monitor='step',
                    mode='max',
                    save_last=False,
                    save_top_k=5
                )],
            val_check_interval=val_check_interval,
            check_val_every_n_epoch=None,
            num_sanity_val_steps=0,
            accumulate_grad_batches=accumulate_grad_batches)
    trainer.fit(model, dataloader, val_dataloader)
    
    # Print the results
    print("a = ", model.a.item())
    print("b = ", model.b.item())
    print("c = ", model.c.item())

Error messages and logs

Save checkpoint, global_step:  5 steps_step=5 monitor_candidates: {'epoch': 0, 'step': 5}
Save checkpoint, global_step:  10 steps_step=10 monitor_candidates: {'epoch': 0, 'step': 10}
Save checkpoint, global_step:  15 steps_step=15 monitor_candidates: {'epoch': 0, 'step': 15}
Save checkpoint, global_step:  20 steps_step=20 monitor_candidates: {'epoch': 0, 'step': 20}
Save checkpoint, global_step:  25 steps_step=25 monitor_candidates: {'epoch': 1, 'step': 25}
Save checkpoint, global_step:  30 steps_step=30 monitor_candidates: {'epoch': 1, 'step': 30}
Save checkpoint, global_step:  35 steps_step=35 monitor_candidates: {'epoch': 1, 'step': 35}
Save checkpoint, global_step:  40 steps_step=40 monitor_candidates: {'epoch': 1, 'step': 40}

Save checkpoint, global_step:  46 steps_step=46 monitor_candidates: {'epoch': 2, 'step': 46}  <-- drift
Save checkpoint, global_step:  51 steps_step=51 monitor_candidates: {'epoch': 2, 'step': 51}
Save checkpoint, global_step:  56 steps_step=56 monitor_candidates: {'epoch': 2, 'step': 56}
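The logged steps can be reproduced with a small simulation. This is a simplified model inferred from the logs, not Lightning's actual loop code; it assumes validation fires when the batch count summed over all epochs is a multiple of `val_check_interval`, and that `global_step` increases after every `accumulate` batches within an epoch plus once on the last (partial) group of the epoch. `simulate_validation_steps` is a hypothetical name. With the script's settings it yields the same eleven (epoch, step) pairs as the log above, followed by steps 61 and 66:

```python
def simulate_validation_steps(num_batches=67, accumulate=3, val_check_interval=15, epochs=3):
    """Return (epoch, global_step) for every validation run under a simplified model.

    Assumptions inferred from the logs, not taken from Lightning's source:
    - validation fires when the batch count summed over all epochs is a
      multiple of `val_check_interval`;
    - global_step increases after every `accumulate` batches within an epoch
      and after the last (possibly partial) group of the epoch.
    """
    events = []
    global_step = 0
    total_batches_seen = 0
    for epoch in range(epochs):
        for batch_idx in range(num_batches):  # accumulation restarts each epoch
            last_in_epoch = batch_idx == num_batches - 1
            if (batch_idx + 1) % accumulate == 0 or last_in_epoch:
                global_step += 1  # optimizer step
            total_batches_seen += 1
            if total_batches_seen % val_check_interval == 0:
                events.append((epoch, global_step))
    return events


for epoch, step in simulate_validation_steps():
    print(f"epoch={epoch} step={step}")
```

Under this model, the batch-based interval only lines up with optimizer steps when every 3 batches yield exactly one step. With 67 batches, each epoch adds 23 steps instead of about 22.33, so the step at which validation fires gradually runs ahead of the multiples of 5, first visible at step 46. With 66 batches per epoch, the same model keeps every validation on a multiple of 5.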

Environment

Current environment
* CUDA:
        - GPU:
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
                - NVIDIA RTX A5000
        - available:         True
        - version:           11.7
* Lightning:
        - lightning:         2.0.0
        - lightning-cloud:   0.5.32
        - lightning-lite:    1.8.6
        - lightning-utilities: 0.8.0
        - pytorch-lightning: 2.0.0
        - torch:             1.13.1
        - torchaudio:        0.13.1
        - torchcrepe:        0.0.17
        - torchmetrics:      0.11.4
        - torchvision:       0.14.1
* Packages:
        - absl-py:           1.3.0
        - aiobotocore:       2.4.2
        - aiohttp:           3.8.4
        - aioitertools:      0.11.0
        - aiosignal:         1.3.1
        - altgraph:          0.17.3
        - anyio:             3.6.2
        - appdirs:           1.4.4
        - arrow:             1.2.3
        - async-timeout:     4.0.2
        - attrs:             22.2.0
        - audioread:         3.0.0
        - backcall:          0.2.0
        - beautifulsoup4:    4.12.0
        - blessed:           1.20.0
        - blinker:           1.4
        - botocore:          1.27.59
        - brotlipy:          0.7.0
        - cachetools:        5.3.0
        - certifi:           2022.12.7
        - cffi:              1.15.1
        - charset-normalizer: 2.0.4
        - click:             8.1.3
        - contourpy:         1.0.7
        - croniter:          1.3.8
        - cryptography:      39.0.1
        - cycler:            0.11.0
        - dateutils:         0.6.12
        - decorator:         5.1.1
        - deepdiff:          6.3.0
        - distance:          0.1.3
        - dnspython:         2.3.0
        - einops:            0.6.0
        - email-validator:   1.3.1
        - et-xmlfile:        1.0.1
        - fastapi:           0.88.0
        - fire:              0.5.0
        - flit-core:         3.8.0
        - fonttools:         4.39.2
        - frozenlist:        1.3.3
        - fsspec:            2023.3.0
        - future:            0.18.2
        - g2p-en:            2.1.0
        - g2pm:              0.1.2.5
        - google-auth:       2.16.3
        - google-auth-oauthlib: 0.4.6
        - grpcio:            1.51.3
        - h11:               0.14.0
        - h5py:              3.7.0
        - httpcore:          0.16.3
        - httptools:         0.5.0
        - httpx:             0.23.3
        - idna:              3.4
        - imageio:           2.23.0
        - importlib-metadata: 6.1.0
        - inflect:           6.0.2
        - inquirer:          3.1.3
        - itsdangerous:      2.1.2
        - jinja2:            3.1.2
        - jmespath:          1.0.1
        - joblib:            1.2.0
        - kiwisolver:        1.4.4
        - librosa:           0.9.1
        - lightning:         2.0.0
        - lightning-cloud:   0.5.32
        - lightning-lite:    1.8.6
        - lightning-utilities: 0.8.0
        - llvmlite:          0.39.1
        - markdown:          3.4.3
        - markdown-it-py:    2.2.0
        - markupsafe:        2.1.2
        - matplotlib:        3.6.2
        - mdurl:             0.1.2
        - mkl-fft:           1.3.1
        - mkl-random:        1.2.2
        - mkl-service:       2.4.0
        - multidict:         6.0.4
        - networkx:          3.0
        - nltk:              3.8.1
        - numba:             0.56.4
        - numpy:             1.23.5
        - oauthlib:          3.2.2
        - ordered-set:       4.1.0
        - orjson:            3.8.8
        - packaging:         23.0
        - pillow:            9.4.0
        - pip:               23.0.1
        - platformdirs:      3.1.1
        - pooch:             1.7.0
        - praat-parselmouth: 0.4.3
        - protobuf:          3.13.0
        - psutil:            5.9.4
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycparser:         2.21
        - pycwt:             0.3.0a22
        - pydantic:          1.10.7
        - pygments:          2.14.0
        - pyjwt:             2.6.0
        - pyloudnorm:        0.1.0
        - pyopenssl:         23.0.0
        - pyparsing:         3.0.9
        - pypinyin:          0.39.0
        - pysocks:           1.7.1
        - python-dateutil:   2.8.2
        - python-dotenv:     1.0.0
        - python-editor:     1.0.4
        - python-levenshtein: 0.12.2
        - python-multipart:  0.0.6
        - pytorch-lightning: 2.0.0
        - pytz:              2022.7.1
        - pywavelets:        1.4.1
        - pyyaml:            6.0
        - readchar:          4.0.5
        - regex:             2023.3.23
        - requests:          2.28.1
        - requests-oauthlib: 1.3.1
        - resampy:           0.4.2
        - resemblyzer:       0.1.1.dev0
        - rfc3986:           1.5.0
        - rich:              13.3.2
        - rsa:               4.9
        - s3fs:              2023.3.0
        - scikit-image:      0.19.3
        - scikit-learn:      1.2.2
        - scipy:             1.9.3
        - setuptools:        65.6.3
        - six:               1.16.0
        - snakeviz:          2.1.1
        - sniffio:           1.3.0
        - soundfile:         0.12.1
        - soupsieve:         2.4
        - starlette:         0.22.0
        - starsessions:      1.3.0
        - tensorboard:       2.11.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - tensorboardx:      2.6
        - termcolor:         2.2.0
        - threadpoolctl:     3.1.0
        - tifffile:          2023.3.21
        - torch:             1.13.1
        - torchaudio:        0.13.1
        - torchcrepe:        0.0.17
        - torchmetrics:      0.11.4
        - torchvision:       0.14.1
        - tornado:           6.2
        - tqdm:              4.65.0
        - traitlets:         5.9.0
        - typing:            3.7.4.3
        - typing-extensions: 4.4.0
        - ujson:             5.7.0
        - urllib3:           1.26.14
        - uvicorn:           0.21.1
        - uvloop:            0.17.0
        - watchfiles:        0.18.1
        - wcwidth:           0.2.6
        - webrtcvad:         2.0.10
        - websocket-client:  1.5.1
        - websockets:        10.4
        - werkzeug:          2.2.3
        - wheel:             0.38.4
        - wrapt:             1.15.0
        - yarl:              1.8.2
        - zipp:              3.15.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.9.16
        - version:           #153-Ubuntu SMP Thu Nov 24 15:56:58 UTC 2022

More info

Besides this behavior, I have two more questions:

  1. Why is `val_check_interval` tied to the number of batches rather than to `global_step`?
  2. Why is validation re-run after resuming from a checkpoint that was saved right after a validation step? This also produces a duplicate checkpoint, which is very frustrating.

cc @carmocca @justusschock

hrukalive • Mar 27 '23 17:03

Amazing, thank you!

idan • May 09 '24 18:05