Add optimizer priming for dist chkpt
See also: https://github.com/pytorch/xla/issues/6546
The optimizer state must be primed before it can be restored. Optimizer state isn't materialized until the first optim.step call, so a dummy step is needed before the state can be restored and training resumed.

This PR introduces the prime_optimizer API, which runs a dummy optimizer step with zeroed gradients. The gradient sharding is copied from the corresponding parameters so that the resulting optimizer state has the same sharding.
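A minimal usage sketch of priming before a distributed checkpoint restore. The import path for prime_optimizer, the checkpoint path, and the use of torch.distributed.checkpoint for the restore are assumptions for illustration; see the PR and linked issue for the actual API surface.

```python
import torch
import torch.distributed.checkpoint as dist_cp
import torch_xla.core.xla_model as xm
# Assumed import path for the new API introduced in this PR.
from torch_xla.experimental.distributed_checkpoint import prime_optimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.01)

# Optimizer state is empty until the first optim.step call, so run a dummy
# step with zeroed gradients to materialize it with the expected sharding.
prime_optimizer(optim)

# With the state materialized, the checkpoint loader can restore into it.
# The storage reader and path below are placeholders.
state_dict = {"model": model.state_dict(), "optim": optim.state_dict()}
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/path/to/checkpoint"),
)
model.load_state_dict(state_dict["model"])
optim.load_state_dict(state_dict["optim"])
```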