
Loading a model changes pytorch random state

Open · heth27 opened this issue on Aug 15, 2024 · 0 comments

Bug description

Loading a model from a checkpoint with model = TestLitModel.load_from_checkpoint(ckpt_path) changes the PyTorch random state, presumably because the model is first initialized (which consumes random numbers for weight initialization) and the weights are then restored on top. At least to me, this was not obvious from just calling a classmethod.
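
A minimal illustration of the suspected cause (not taken from Lightning's internals, just demonstrating the hypothesis): merely constructing a module advances the global RNG, because weight initialization draws from it.

import torch
from torch import nn

torch.manual_seed(0)
state_before = torch.get_rng_state()

# constructing a layer draws from the global RNG to initialize its weights,
# which is presumably what load_from_checkpoint triggers internally
_ = nn.Linear(512, 16)

print(torch.equal(state_before, torch.get_rng_state()))  # False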

One option would be to restore the random state to what it was before the load; another would be to add a hint to the documentation, in case I'm not the only one for whom this is not intuitively clear.
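
As a user-side workaround (a sketch of the first option, not an official Lightning API), the load can be wrapped so the global CPU RNG state is snapshotted and restored; the hypothetical load_rng_safe helper below assumes the TestLitModel from the repro script.

import torch

def load_rng_safe(ckpt_path):
    # snapshot the global CPU RNG state; extend with
    # torch.cuda.get_rng_state_all()/torch.cuda.set_rng_state_all()
    # if the CUDA RNG state matters for your use case
    state = torch.get_rng_state()
    try:
        return TestLitModel.load_from_checkpoint(ckpt_path)
    finally:
        torch.set_rng_state(state)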

What version are you seeing the problem on?

master

How to reproduce the bug

import time
from typing import Any

import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, OnExceptionCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities import grad_norm
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler

# from src.main.ml.data.data_augmentation.helpers.random_numbers import create_rng_from_string
# ^ project-internal helper; minimal stand-in below so the script is self-contained
#   (assumption: it seeds a deterministic NumPy RNG from a string)
import hashlib

import numpy as np


def create_rng_from_string(seed_string: str) -> np.random.Generator:
    seed = int.from_bytes(hashlib.sha256(seed_string.encode()).digest()[:8], "little")
    return np.random.default_rng(seed)


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, x):
        return self.simple_layer(x)


class TestBatchSampler(Sampler):
    def __init__(self, step=0):
        super().__init__()
        self.step = step

    def __len__(self) -> int:
        # effectively unbounded; __len__ must return an int
        return int(1e100)

    def __iter__(self):
        # the sampler is its own iterator
        return self

    def __next__(self):
        # yield batches of a single, ever-increasing index
        return_value = self.step
        self.step += 1
        return [return_value]


class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        print(idx)
        rng = create_rng_from_string(
            str(idx) + "_"
            + "random_choice_sampler")
        return torch.tensor(rng.random(self.in_dim), dtype=torch.float32)
        # return torch.randn(self.in_dim)


class TestDataModule(L.LightningDataModule):
    def __init__(self, start_step=0):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1
        self.start_step = start_step

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        train_dl = DataLoader(train_ds, batch_sampler=TestBatchSampler(step=self.start_step),
                              num_workers=4)
        return train_dl

    def val_dataloader(self):
        val_ds = TestDataset(self.in_dim)
        val_dl = DataLoader(val_ds, batch_size=self.val_batch_size, num_workers=4, shuffle=False)
        return val_dl


class TestLitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.automatic_optimization = False

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        print(f"train_batch ended:{batch_idx}")

    def on_save_checkpoint(self, checkpoint):
        # my own fix for reproducible checkpointing: align the "completed"
        # progress counters with the "processed" ones
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = \
            checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        print("creating checkpoint")

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        print("validation step called")
        return torch.tensor(1.0)

    def training_step(self, batch, batch_idx):
        print(f"batch_idx: {batch_idx}")

        time.sleep(0.25)
        optimizer = self.optimizers()
        optimizer.zero_grad()  # manual optimization: clear stale grads before backward

        output = self.test_module_obj(batch)
        loss = output.sum()

        self.manual_backward(loss)
        optimizer.step()

        self.log_dict({"loss": loss, "batch_idx": batch_idx})
        norms = grad_norm(self.test_module_obj, norm_type=2)
        self.log_dict(norms)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.test_module_obj.parameters()
        )
        return optimizer


if __name__ == '__main__':
    test_data_loader = TestDataModule()
    test_lit_model = TestLitModel()

    checkpoint_dir = 'a_test_logs_2'

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        every_n_train_steps=10,
        save_top_k=-1,
    )
    exception_checkpoint_callback = OnExceptionCheckpoint(
        dirpath=checkpoint_dir,
        filename="error"
    )

    wandb_logger = WandbLogger(save_dir="/home/thillmann/projects/metrabs_refactoring/a_logs",
                               project="global_step_test",
                               name="global_step_test",
                               notes="global_step_test",
                               log_model="all")

    trainer = L.Trainer(
        log_every_n_steps=1,
        callbacks=[checkpoint_callback],
        max_epochs=-1,
        max_steps=25,
        val_check_interval=5,
        logger=wandb_logger,
    )

    # wandb_logger.watch(model=test_lit_model.test_module_obj, log="all", log_freq=1)

    ###########################################################

    # uncomment this to create checkpoint
    # trainer.fit(test_lit_model, test_data_loader)

    ###########################################################

    cpu_random_state = torch.get_rng_state()
    print("CPU Random State:", cpu_random_state)
    print("CPU Random State equal:", torch.equal(cpu_random_state, torch.get_rng_state()))

    ckpt_path = f'{checkpoint_dir}/epoch=0-step=10.ckpt'

    # plain torch.load should not touch the global RNG state
    ckpt = torch.load(ckpt_path)
    global_step = ckpt['global_step']
    print("CPU Random State equal after pytorch load:", torch.equal(cpu_random_state, torch.get_rng_state()))

    # load_from_checkpoint instantiates the model first, which presumably draws from the RNG
    model = TestLitModel.load_from_checkpoint(ckpt_path)
    print("CPU Random State equal after lightning load:", torch.equal(cpu_random_state, torch.get_rng_state()))

    test_data_loader = TestDataModule(start_step=global_step)

    trainer.fit(test_lit_model,
                datamodule=test_data_loader,
                ckpt_path=ckpt_path)

Error messages and logs

No response

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response
