Weird bug when setting `val_check_interval` dynamically in `setup()`
Bug description
I want to dynamically set `val_check_interval` based on the total number of training steps. Specifically, I calculate it as `self.trainer.estimated_stepping_batches // 10` in the `setup()` method, aiming for 10 validation runs over training.
When I assign a constant value to `self.trainer.val_check_interval`, it works as expected, but when I use the dynamic calculation (`self.trainer.estimated_stepping_batches // 10`), it doesn't seem to take effect, even though the calculated value is correct and everything else is identical. I also set `self.trainer.check_val_every_n_epoch = None`, as per the documentation.
What could be causing this weird bug, and how can I ensure that the dynamically calculated value is applied properly?
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
from pytorch_lightning import LightningModule, Trainer

class Dummy(LightningModule):
    def setup(self, stage: str):
        if self.trainer:
            # This does not seem to change the validation interval
            self.trainer.val_check_interval = self.trainer.estimated_stepping_batches // 100
            # But this does
            self.trainer.val_check_interval = 10

trainer = Trainer(check_val_every_n_epoch=None)
dummy = Dummy()
trainer.fit(dummy)
Error messages and logs
# Error messages and logs here please
Environment
Current environment
- PyTorch Lightning Version: 2.5.0.post0
- PyTorch Version: 2.6.0+cu126
- Python version: 3.10.11
- OS: Windows 11
- CUDA/cuDNN version: 12.8
- GPU models and configuration:
- How you installed Lightning (`conda`, `pip`, source): pip
More info
No response
Reproduced it. It seems `val_check_interval` and `check_val_every_n_epoch` need to be set at `Trainer` initialization to control validation scheduling. The `FitLoop` configures its schedule before `LightningModule.setup(stage="fit")` is called, so changes made within `setup()` don't affect the loop's validation frequency.
Just found another internal variable, `trainer.val_check_batch` (used by the `FitLoop`); setting both it and `val_check_interval` to the dynamically calculated value does the job. This behavior is still odd, since assigning a constant value to `val_check_interval` alone works as expected.
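Roughly, the hook ends up looking like this (a minimal sketch of the workaround just described, mirroring the full script later in the thread):

    def setup(self, stage: str):
        if stage == "fit":
            interval = max(1, self.trainer.estimated_stepping_batches // 10)
            # Setting val_check_interval alone is too late for the FitLoop's schedule,
            # so the internal val_check_batch is updated as well.
            self.trainer.val_check_interval = interval
            self.trainer.val_check_batch = interval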
Good find. Depending on @lantiga @Borda feedback, maybe we can add this in the documentation to guide future users.
> Reproduced it. It seems `val_check_interval` and `check_val_every_n_epoch` need to be set at `Trainer` initialization to control validation scheduling. The `FitLoop` configures its schedule before `LightningModule.setup(stage="fit")` is called, so changes within `setup()` don't affect the loop's validation frequency.
Yes, that sounds reasonable to be changed
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!
Just to follow up: setting the trainer arguments dynamically is not really a supported use case, because the internal Lightning logic is built around certain variables being available at certain times. As stated in the docstring of the `setup` hook, the intended use case is making changes to the model, not the trainer:
https://github.com/Lightning-AI/pytorch-lightning/blob/79dc82c66cc38523abaf71a42fc5ebc8e06fb015/src/lightning/pytorch/core/hooks.py#L421-L423
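For comparison, the supported pattern is to compute the interval up front and pass it to the `Trainer` constructor. A minimal sketch, assuming the training-set size, batch size, and `max_epochs` are known before the `Trainer` is built (the dataset and numbers below are placeholders):

import torch
import pytorch_lightning as pl
from torch.utils.data import TensorDataset

# Placeholder training data; in practice use the real dataset.
train_ds = TensorDataset(torch.randn(200, 10), torch.randn(200, 1))
batch_size, max_epochs = 10, 1

# Rough estimate of total optimizer steps (ignores gradient accumulation).
steps_per_epoch = len(train_ds) // batch_size
total_steps = steps_per_epoch * max_epochs

trainer = pl.Trainer(
    max_epochs=max_epochs,
    val_check_interval=max(1, total_steps // 10),  # aim for ~10 validation runs
    check_val_every_n_epoch=None,
)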
That said, in this specific case it is possible to set the validation checking interval dynamically, but it requires setting not only `val_check_interval` but also the internal `val_check_batch`. A simple script:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TestModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)
        self.validation_steps = []

    def setup(self, stage: str):
        if stage == "fit":
            # Update both the public attribute and the internal counter used by the FitLoop.
            desired_interval = max(1, self.trainer.estimated_stepping_batches // 10)
            self.trainer.val_check_interval = desired_interval
            self.trainer.val_check_batch = desired_interval

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def train_dataloader(self):
        x, y = torch.randn(200, 10), torch.randn(200, 1)
        return DataLoader(TensorDataset(x, y), batch_size=10)

    def val_dataloader(self):
        x, y = torch.randn(50, 10), torch.randn(50, 1)
        return DataLoader(TensorDataset(x, y), batch_size=10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        if batch_idx == 0:  # Record only once per validation run
            self.validation_steps.append(self.trainer.global_step)
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)


def main():
    model = TestModule()
    trainer = pl.Trainer(
        max_epochs=1,
        val_check_interval=100,
        check_val_every_n_epoch=None,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)
    print("\nResults:")
    print(f"  Validation steps: {model.validation_steps}")
    print(f"  Total validations: {len(model.validation_steps)}")
    print(f"  Expected ~{trainer.estimated_stepping_batches // 10} step intervals")


if __name__ == "__main__":
    main()
Closing issue, but feel free to ping and reopen if necessary.