
Weird bug when setting `val_check_interval` dynamically in `setup()`

Open davidgill97 opened this issue 7 months ago • 4 comments

Bug description

I want to set val_check_interval dynamically based on the total number of training steps. Specifically, I calculate val_check_interval as self.trainer.estimated_stepping_batches // 10 in the setup() method, aiming for 10 validation runs over training.

When I assign a constant value to self.trainer.val_check_interval, it works as expected, but when I use the dynamically calculated value (self.trainer.estimated_stepping_batches // 10), it doesn't seem to take effect, even though the calculated value is correct and everything else is identical. I also set self.trainer.check_val_every_n_epoch=None, as per the documentation.

What could be causing this weird bug, and how can I ensure that the dynamically calculated value is applied properly?

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

class Dummy(LightningModule):
    def setup(self, stage: str):
        if self.trainer:
            # This does not seem to change the validation interval:
            self.trainer.val_check_interval = self.trainer.estimated_stepping_batches // 100
            # ...but assigning a constant value does:
            self.trainer.val_check_interval = 10

trainer.fit(dummy)

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0): 2.5.0.post0
#- PyTorch Version (e.g., 2.5): 2.6.0+cu126
#- Python version (e.g., 3.12): 3.10.11
#- OS (e.g., Linux): Windows 11
#- CUDA/cuDNN version: 12.8
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source): pip

More info

No response

davidgill97 avatar Jun 11 '25 12:06 davidgill97

Reproduced it. It seems val_check_interval and check_val_every_n_epoch need to be set at Trainer initialization to control validation scheduling. The FitLoop configures its schedule before LightningModule.setup(stage="fit") is called, so changes within setup() don't affect the loop's validation frequency.
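
For reference, a minimal sketch of the supported route, passing the interval when constructing the Trainer (the numbers here are placeholders):

trainer = Trainer(
    max_epochs=1,
    val_check_interval=50,          # run validation every 50 training batches
    check_val_every_n_epoch=None,   # let the interval span epoch boundaries
)
trainer.fit(model)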

MrAnayDongre avatar Jun 12 '25 00:06 MrAnayDongre

Just found another internal variable, val_check_batch, in FitLoop; setting both it and val_check_interval to the dynamically calculated value does the job. This behavior is still weird, though, since assigning a constant value to val_check_interval alone works as expected.
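
Roughly what worked for me, as a sketch (the divisor is arbitrary):

def setup(self, stage: str):
    if stage == "fit":
        interval = max(1, self.trainer.estimated_stepping_batches // 10)
        # Updating val_check_interval alone is not enough; the internal
        # val_check_batch counter used by the FitLoop must be updated too.
        self.trainer.val_check_interval = interval
        self.trainer.val_check_batch = interval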

davidgill97 avatar Jun 12 '25 01:06 davidgill97

Good find. Depending on feedback from @lantiga and @Borda, maybe we can add this to the documentation to guide future users.

MrAnayDongre avatar Jun 12 '25 05:06 MrAnayDongre

Reproduced it. It seems val_check_interval and check_val_every_n_epoch need to be set at Trainer initialization to control validation scheduling. The FitLoop configures its schedule before LightningModule.setup(stage="fit") is called, so changes within setup() don't affect the loop's validation frequency.

Yes, that sounds like a reasonable change

Borda avatar Jun 13 '25 16:06 Borda

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

stale[bot] avatar Jul 19 '25 05:07 stale[bot]

Just to follow up: setting trainer arguments dynamically is not really a supported use case, because the internal Lightning logic is built around certain variables being available at certain times. As stated in the docstring of the setup hook, the intended use case is making changes to the model, not the trainer: https://github.com/Lightning-AI/pytorch-lightning/blob/79dc82c66cc38523abaf71a42fc5ebc8e06fb015/src/lightning/pytorch/core/hooks.py#L421-L423

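For contrast, a sketch of the kind of model-side change setup() is intended for (the num_classes attribute on the datamodule is hypothetical):

def setup(self, stage: str):
    if stage == "fit":
        # Building model pieces once data properties are known is fine here;
        # rescheduling the trainer is not.
        self.classifier = torch.nn.Linear(128, self.trainer.datamodule.num_classes)

Anything that touches the trainer's validation scheduling is better done via the Trainer constructor.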
That said, in this specific case it is possible to set the validation checking interval dynamically, but it requires setting not only val_check_interval but also the internal val_check_batch. A simple script:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TestModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)
        self.validation_steps = []
    
    def setup(self, stage: str):
        if stage == "fit":
            desired_interval = max(1, self.trainer.estimated_stepping_batches // 10)

            # Both the public attribute and the internal counter used by the
            # FitLoop have to be updated for the new interval to take effect.
            self.trainer.val_check_interval = desired_interval
            self.trainer.val_check_batch = desired_interval

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
    
    def train_dataloader(self):
        x, y = torch.randn(200, 10), torch.randn(200, 1)
        return DataLoader(TensorDataset(x, y), batch_size=10)
    
    def val_dataloader(self):
        x, y = torch.randn(50, 10), torch.randn(50, 1)
        return DataLoader(TensorDataset(x, y), batch_size=10)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)
    
    def validation_step(self, batch, batch_idx):
        if batch_idx == 0:  # Record only once per validation run
            self.validation_steps.append(self.trainer.global_step)
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)


def main():
    model = TestModule()
    trainer = pl.Trainer(
        max_epochs=1,
        val_check_interval=100,
        check_val_every_n_epoch=None,
        enable_progress_bar=False,
        logger=False
    )
    
    trainer.fit(model)
    
    print(f"\nResults:")
    print(f"  Validation steps: {model.validation_steps}")
    print(f"  Total validations: {len(model.validation_steps)}")
    print(f"  Expected ~{trainer.estimated_stepping_batches // 10} step intervals")


if __name__ == "__main__":
    main()

Closing issue, but feel free to ping and reopen if necessary.

SkafteNicki avatar Oct 10 '25 08:10 SkafteNicki