
Add optimizer priming for dist chkpt


See also: https://github.com/pytorch/xla/issues/6546

The optimizer state must be primed before it can be restored: optimizer state isn't materialized until the first `optim.step()` call, so a dummy step is needed to lay out the state before restoring a checkpoint and resuming training.
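A quick illustration of the underlying behavior in plain PyTorch (not code from this PR): a freshly constructed optimizer has an empty `state` dict until the first step, so there is nothing to restore a sharded checkpoint into.

```python
import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters())
print(len(optim.state))   # 0 -- no exp_avg / exp_avg_sq tensors exist yet

model(torch.randn(2, 4)).sum().backward()
optim.step()
print(len(optim.state))   # now populated with per-parameter state
```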

This PR introduces the `prime_optimizer` API, which runs a dummy optimizer step with zeroed gradients. The dummy gradients take their sharding from the corresponding parameters, so the materialized optimizer state ends up with the same sharding.
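A minimal sketch of the priming pattern described above. The function below is illustrative only; the module path, signature, and exact behavior of the real `prime_optimizer` API are defined by this PR, not by this sketch.

```python
import torch

def prime_optimizer(optimizer: torch.optim.Optimizer) -> None:
    """Run a dummy step with zeroed gradients so the optimizer state
    (e.g. Adam's exp_avg / exp_avg_sq buffers) is materialized and can
    then be overwritten by a checkpoint restore."""
    for group in optimizer.param_groups:
        for param in group['params']:
            # Zeroed gradients: for many optimizers this makes the dummy
            # step a (near) no-op on the parameter values. In the real API
            # the gradient sharding is copied from the parameter so the
            # resulting state carries the same sharding.
            param.grad = torch.zeros_like(param)
    optimizer.step()
    optimizer.zero_grad()
```

Typical use would be to call the priming step on a freshly built optimizer and only then load the distributed checkpoint, so the restored state has concrete, correctly sharded tensors to land in.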

jonb377 · Feb 20 '24