torchtitan

Save RNG states during checkpointing for deterministic debugging

Open · wwwjn opened this issue 7 months ago · 0 comments

Context

Currently we do not save the random number generator (RNG) states in checkpoints. For models like Flux, part of the input is randomly generated. If we save a checkpoint at step x and later resume from it, the RNG states after loading are not the same as they were at step x in the original run. As a result, the generated noise (part of the Flux input) differs, and the loss is not deterministic across a save/resume.
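The effect can be reproduced with PyTorch's global CPU generator alone. In this minimal sketch, `make_noise` is a hypothetical stand-in for Flux's noise sampling (assumed here to draw from the global CPU RNG), and re-seeding simulates a resumed process that starts from the initial seed rather than from the RNG position it had at step x:

```python
import torch


def make_noise() -> torch.Tensor:
    # Hypothetical stand-in for Flux's noise sampling: draws from the global CPU RNG.
    return torch.randn(2, 4)


torch.manual_seed(0)
make_noise()  # earlier training steps advance the global RNG

# Checkpoint at step x: capture the global CPU RNG state.
saved_state = torch.get_rng_state()
noise_original = make_noise()  # noise the uninterrupted run uses at step x

# Resume WITHOUT restoring: the new process re-seeds, so the RNG is back at
# its initial position instead of where it was at step x.
torch.manual_seed(0)
noise_not_restored = make_noise()

# Resume WITH the saved state restored.
torch.set_rng_state(saved_state)
noise_restored = make_noise()

print(torch.equal(noise_original, noise_restored))      # True
print(torch.equal(noise_original, noise_not_restored))  # False
```

Only the run that restores the saved state sees the same step-x noise as the original run.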

Example Implementation

In train.py

    def state_dict(self) -> dict[str, Any]:
        # Save the training step and RNG states for reproducibility.
        device_module = utils.device_module
        self.device_rng_state = device_module.get_rng_state()
        self.cpu_rng_state = torch.get_rng_state()

        def compute_rng_hash(state: torch.Tensor) -> int:
            """Fingerprint an RNG state tensor (first 32 bytes only) for logging."""
            return int.from_bytes(state.cpu().numpy().tobytes()[:32], byteorder="big")

        logger.info(
            f"In trainer.state_dict(), read RNG states at step {self.step}: "
            f"CPU {compute_rng_hash(self.cpu_rng_state)} "
            f"device {compute_rng_hash(self.device_rng_state)}"
        )

        return {
            "step": self.step,
            "device_rng_states": self.device_rng_state,
            "cpu_rng_states": self.cpu_rng_state,
        }

    def load_state_dict(self, state_dict: dict[str, Any]):
        self.step = state_dict["step"]

        # Restore RNG states only if they exist in the state_dict
        # (checkpoints written before this change do not contain them).
        if "device_rng_states" not in state_dict or "cpu_rng_states" not in state_dict:
            return
        self.device_rng_state = state_dict["device_rng_states"]
        self.cpu_rng_state = state_dict["cpu_rng_states"]

        device_module = utils.device_module
        device_module.set_rng_state(self.device_rng_state)
        torch.set_rng_state(self.cpu_rng_state)

        def compute_rng_hash(state: torch.Tensor) -> int:
            """Fingerprint an RNG state tensor (first 32 bytes only) for logging."""
            return int.from_bytes(state.cpu().numpy().tobytes()[:32], byteorder="big")

        logger.info(
            f"Loaded RNG states at step {self.step}: "
            f"CPU {compute_rng_hash(self.cpu_rng_state)} "
            f"device {compute_rng_hash(self.device_rng_state)}"
        )
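The hash helper in the example above only looks at the first 32 bytes of each state tensor, so two states that differ beyond those bytes log the same value. When comparing saved and loaded states, a digest over the whole tensor is a stricter check. A minimal sketch, where `rng_state_digest` is a hypothetical helper name and only `hashlib` and `torch` are used:

```python
import hashlib

import torch


def rng_state_digest(state: torch.Tensor) -> str:
    """Hypothetical helper: SHA-256 over every byte of an RNG state tensor.

    Unlike a fingerprint of a fixed prefix, a difference anywhere in the
    state changes the digest.
    """
    # RNG states are uint8 tensors; tolist() yields ints in 0..255.
    raw = bytes(state.cpu().tolist())
    return hashlib.sha256(raw).hexdigest()[:16]


torch.manual_seed(0)
before = torch.get_rng_state()  # returns a copy of the current state
torch.rand(1)                   # consume some randomness
after = torch.get_rng_state()

print(rng_state_digest(before) == rng_state_digest(before.clone()))  # True
print(rng_state_digest(before) == rng_state_digest(after))           # False
```

The same helper can be applied to the device state returned by the device module's `get_rng_state()`.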

TODOs

The example implementation above is not yet correct: in testing, the RNG states logged when saving do not match the RNG states logged after loading.
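One way to narrow down the mismatch is to first check, outside of the checkpointer, that the raw state tensors survive a save/load round trip byte for byte and that restoring them reproduces the next random draw. This sketch uses `torch.save`/`torch.load` on an in-memory buffer as a stand-in for the real checkpoint backend (an assumption: torchtitan's checkpointer goes through a different save/load path and may behave differently), covers only the CPU generator, and `rng_round_trip_ok` is a hypothetical helper name:

```python
import io

import torch


def rng_round_trip_ok() -> bool:
    """Hypothetical check: does a saved CPU RNG state survive serialization exactly?"""
    torch.manual_seed(42)
    saved = torch.get_rng_state()

    # Serialize and deserialize, standing in for checkpoint save/load.
    buf = io.BytesIO()
    torch.save({"cpu_rng_states": saved}, buf)
    buf.seek(0)
    loaded = torch.load(buf, weights_only=True)["cpu_rng_states"]

    # 1) The bytes must match exactly.
    if not torch.equal(saved, loaded):
        return False

    # 2) Restoring the loaded state must reproduce the next draw.
    torch.set_rng_state(saved)
    expected = torch.randn(8)
    torch.manual_seed(0)  # simulate a fresh process at a different RNG position
    torch.set_rng_state(loaded)
    actual = torch.randn(8)
    return torch.equal(expected, actual)


print(rng_round_trip_ok())  # True
```

If a check like this passes but the trainer's logs still disagree, the difference is more likely introduced by the checkpoint path itself, or by code that consumes randomness between saving and logging, than by the state tensors.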

wwwjn, May 14 '25 19:05