pytorch-lightning
Loading a model changes pytorch random state
Bug description
Loading a model from a checkpoint with model = TestLitModel.load_from_checkpoint(ckpt_path) changes the PyTorch random state, presumably because the model is first instantiated (which consumes random numbers for weight initialization) and only then has its weights restored from the checkpoint.
At least to me this was not obvious from just calling a classmethod.
One option would be to restore the random state to what it was before the load (sketched below); the other would be to add a hint to the documentation, in case I'm not the only one for whom this is not intuitively clear.
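As an illustration of the first option, here is a minimal user-side workaround sketch (not current library behavior; it assumes the TestLitModel and ckpt_path from the repro script below): snapshot the RNG state before the load and restore it afterwards, so the load is invisible to later random draws.

import torch
# snapshot the CPU (and, if present, CUDA) RNG state before the load
cpu_state = torch.get_rng_state()
cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
model = TestLitModel.load_from_checkpoint(ckpt_path)
# restore the state so subsequent random draws are unaffected by the load
torch.set_rng_state(cpu_state)
if cuda_states is not None:
    torch.cuda.set_rng_state_all(cuda_states)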
What version are you seeing the problem on?
master
How to reproduce the bug
import time
from typing import Any
import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, OnExceptionCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities import grad_norm
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler
# NOTE: create_rng_from_string is a project-local helper that seeds a NumPy
# Generator from a string; any deterministic stand-in (e.g. numpy.random.default_rng
# seeded with a hash of the string) works for this repro.
from src.main.ml.data.data_augmentation.helpers.random_numbers import create_rng_from_string

class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.simple_layer(input)

class TestBatchSampler(Sampler):
    def __init__(self, step=0):
        super().__init__()
        self.step = step

    def __len__(self) -> int:
        # effectively infinite; must be an int for len()
        return int(1e100)
        # return len(self.train_allfiles)

    def __iter__(self):  # -> Iterator[int]:
        return self

    def __next__(self):  # -> Iterator[int]:
        return_value = self.step
        self.step += 1
        return [return_value]

class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        print(idx)
        rng = create_rng_from_string(
            str(idx) + "_"
            + "random_choice_sampler")
        return torch.tensor(rng.random(self.in_dim), dtype=torch.float32)
        # return torch.randn(self.in_dim)

class TestDataModule(L.LightningDataModule):
    def __init__(self, start_step=0):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1
        self.start_step = start_step

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        train_dl = DataLoader(train_ds, batch_sampler=TestBatchSampler(step=self.start_step),
                              num_workers=4, shuffle=False)
        return train_dl

    def val_dataloader(self):
        val_ds = TestDataset(self.in_dim)
        val_dl = DataLoader(val_ds, batch_size=self.val_batch_size, num_workers=4, shuffle=False)
        return val_dl

class TestLitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.automatic_optimization = False

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        print(f"train_batch ended:{batch_idx}")

    def on_save_checkpoint(self, checkpoint):
        # my own fix for reproducible checkpointing
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        print("creating checkpoint")

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        print("validation step called")
        return torch.tensor(1.0)

    def training_step(self, batch, batch_idx):
        print(f"batch_idx: {batch_idx}")
        time.sleep(0.25)
        optimizer = self.optimizers()
        output = self.test_module_obj(batch)
        loss = output.sum()
        self.manual_backward(loss)
        optimizer.step()
        self.log_dict({"loss": loss, "batch_idx": batch_idx})
        norms = grad_norm(self.test_module_obj, norm_type=2)
        self.log_dict(norms)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.test_module_obj.parameters()
        )
        return optimizer

if __name__ == '__main__':
    test_data_loader = TestDataModule()
    test_lit_model = TestLitModel()
    checkpoint_dir = 'a_test_logs_2'
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        every_n_train_steps=10,
        save_top_k=-1,
    )
    exception_checkpoint_callback = OnExceptionCheckpoint(
        dirpath=checkpoint_dir,
        filename="error"
    )
    wandb_logger = WandbLogger(save_dir="/home/thillmann/projects/metrabs_refactoring/a_logs",
                               project="global_step_test",
                               name="global_step_test",
                               notes="global_step_test",
                               log_model="all")
    trainer = L.Trainer(
        log_every_n_steps=1,
        callbacks=[checkpoint_callback],
        max_epochs=-1,
        max_steps=25,
        val_check_interval=5,
        logger=wandb_logger,
    )
    # wandb_logger.watch(model=test_lit_model.test_module_obj, log="all", log_freq=1)

    ###########################################################
    # uncomment this to create checkpoint
    # trainer.fit(test_lit_model, test_data_loader)
    ###########################################################

    cpu_random_state = torch.get_rng_state()
    print("CPU Random State:", cpu_random_state)
    ckpt_path = f'{checkpoint_dir}/epoch=0-step=10.ckpt'
    print("CPU Random State equal:", torch.equal(cpu_random_state, torch.get_rng_state()))

    # plain torch.load leaves the RNG state untouched
    ckpt = torch.load(ckpt_path)
    global_step = ckpt['global_step']
    print("CPU Random State equal after pytorch load:", torch.equal(cpu_random_state, torch.get_rng_state()))

    # load_from_checkpoint changes it, presumably because the module is re-instantiated
    # before the weights are restored
    model = TestLitModel.load_from_checkpoint(ckpt_path)
    cpu_random_state_4 = torch.get_rng_state()
    print("CPU Random State equal after lightning load:", torch.equal(cpu_random_state, cpu_random_state_4))

    test_data_loader = TestDataModule(start_step=global_step)
    trainer.fit(test_lit_model,
                datamodule=test_data_loader,
                ckpt_path=f'{checkpoint_dir}/epoch=0-step=10.ckpt')
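Distilled down, the observation is just this (assuming a TestLitModel checkpoint already exists at ckpt_path):

state_before = torch.get_rng_state()
_ = torch.load(ckpt_path)                              # RNG state unchanged
model = TestLitModel.load_from_checkpoint(ckpt_path)   # RNG state changes
print(torch.equal(state_before, torch.get_rng_state()))  # prints False on my setup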
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response