jonb377
See also: https://github.com/pytorch/xla/issues/6546

The optimizer state must be primed before it can be restored. Optimizer state isn't materialized until the first `optim.step` call, so to restore optimizer state before resuming...
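The priming requirement can be sketched in plain PyTorch (a minimal sketch, not the torch_xla checkpointing API; the model, optimizer, and checkpoint path are illustrative):

```python
import torch

# Adam's per-parameter state (exp_avg, exp_avg_sq) does not exist until
# the first step() call, so a freshly constructed optimizer has nothing
# to restore a checkpoint into.
model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters())
assert len(optim.state) == 0  # no state materialized yet

# "Prime" the optimizer: run one step with zero gradients so the state
# tensors are allocated. With zero grads, Adam's update is zero, so the
# parameters are left unchanged.
for p in model.parameters():
    p.grad = torch.zeros_like(p)
optim.step()
assert len(optim.state) > 0  # state now exists and can be restored into

# A checkpointed optimizer state_dict can now be loaded, e.g.
# optim.load_state_dict(torch.load("optim_ckpt.pt"))  # path illustrative
```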
## Fixes / Features
- `debug-dump-gcs` doesn't need to be mutually exclusive with an environment-specified `XLA_FLAGS`.

## Testing / Documentation
Testing details:
- `xpk workload create ... --debug-dump-gcs gs://foo/bar --env XLA_FLAGS=--xla_dump_to=/foo/bar` =>...
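A minimal sketch of the intended merge behavior (the helper name and flag values are hypothetical, not xpk's actual implementation): rather than rejecting the combination, the dump flag is appended to whatever `XLA_FLAGS` the user already supplied via `--env`.

```python
def merge_xla_flags(user_env, dump_dir):
    # Hypothetical helper: append the --xla_dump_to flag implied by
    # --debug-dump-gcs to any existing user-specified XLA_FLAGS, instead
    # of treating the two sources as mutually exclusive.
    env = dict(user_env)
    existing = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()
    return env

# User-specified flags survive alongside the dump flag (values illustrative).
merged = merge_xla_flags(
    {"XLA_FLAGS": "--xla_force_host_platform_device_count=8"},
    "/tmp/xla_dump",
)
print(merged["XLA_FLAGS"])
# -> --xla_force_host_platform_device_count=8 --xla_dump_to=/tmp/xla_dump
```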
Nightly tests are failing because `jax.distributed` is not initialized in the synchronous checkpointing case.