torchtitan
Save RNG states during checkpointing for deterministic debugging
Context
Currently we are not saving the random number generator (RNG) state. For models like Flux, the inputs are randomly generated. If we save a checkpoint at step=x and later load it at step=x, the RNG states will not match, so the generated noise (part of the Flux input) is not the same and the loss is not deterministic.
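For reference, the underlying idea is just a round trip of torch.get_rng_state() / torch.set_rng_state(): restoring a saved state makes subsequent random draws identical. A minimal standalone sketch (not torchtitan code):

import torch

# Capture the CPU RNG state, draw a sample, then restore the state and
# draw again: the two samples are identical.
saved_state = torch.get_rng_state()
noise_a = torch.randn(4)

torch.set_rng_state(saved_state)
noise_b = torch.randn(4)

assert torch.equal(noise_a, noise_b)  # deterministic after restoring RNG state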
Example Implementation
In train.py
def state_dict(self) -> dict[str, Any]:
    # Save training step and RNG states for reproducibility
    device_module = utils.device_module
    self.device_rng_state = device_module.get_rng_state()
    self.cpu_rng_state = torch.get_rng_state()

    def compute_rng_hash(state: torch.Tensor) -> int:
        """Compute a short hash of the given RNG state tensor for logging."""
        return int.from_bytes(state.cpu().numpy().tobytes()[:32], "little")

    logger.info(
        f"In trainer.state_dict(), read RNG states at step {self.step}: "
        f"CPU {compute_rng_hash(self.cpu_rng_state)} "
        f"device {compute_rng_hash(self.device_rng_state)}"
    )
    return {
        "step": self.step,
        "device_rng_states": self.device_rng_state,
        "cpu_rng_states": self.cpu_rng_state,
    }
def load_state_dict(self, state_dict: dict[str, Any]):
    self.step = state_dict["step"]
    self.device_rng_state = state_dict["device_rng_states"]
    self.cpu_rng_state = state_dict["cpu_rng_states"]

    # Restore RNG states from the state_dict
    device_module = utils.device_module
    device_module.set_rng_state(self.device_rng_state)
    torch.set_rng_state(self.cpu_rng_state)

    def compute_rng_hash(state: torch.Tensor) -> int:
        """Compute a short hash of the given RNG state tensor for logging."""
        return int.from_bytes(state.cpu().numpy().tobytes()[:32], "little")

    logger.info(
        f"Loaded RNG states at step {self.step}: "
        f"CPU {compute_rng_hash(self.cpu_rng_state)} "
        f"device {compute_rng_hash(self.device_rng_state)}"
    )
TODOs
The example implementation above is not "correct": based on testing, the RNG state that is saved and the one that is loaded back do not match.
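One possible direction (not verified against the failing round trip above): keep the RNG states in a dedicated stateful object instead of caching them on the trainer, so the checkpointer reads them exactly at save time and writes them back exactly at load time. A rough sketch assuming torch.distributed.checkpoint's Stateful protocol; the class name RNGState is hypothetical:

from typing import Any

import torch
from torch.distributed.checkpoint.stateful import Stateful


class RNGState(Stateful):
    """Hypothetical wrapper that owns only the RNG states, so the checkpointer
    reads them at save time and writes them back at load time."""

    def state_dict(self) -> dict[str, Any]:
        sd = {"cpu_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            sd["device_rng_state"] = torch.cuda.get_rng_state()
        return sd

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        torch.set_rng_state(state_dict["cpu_rng_state"])
        if "device_rng_state" in state_dict and torch.cuda.is_available():
            torch.cuda.set_rng_state(state_dict["device_rng_state"])

Such an object would then need to be registered with the checkpointing state dict alongside the trainer; how that plugs into torchtitan's CheckpointManager is left open here.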