[TPU] Bug: Metadata path does not exist when using gs:
I am getting a number of errors from the folder-existence checks in _src/metadata/checkpoint.py when trying to create a gs: checkpoint on TPU (v6e, v2-alpha-tpuv6e), e.g.:
_src/metadata/checkpoint.py", line 45, in _sanitize_metadata_path
    raise FileNotFoundError(f'Path does not exist: {path}')
For some reason the error does not happen elsewhere (e.g. locally on a Mac).
Sample code:
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
path = "somebucket/somepath"
checkpointer = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
checkpointer.save(f"gs://{path}", (jnp.ones(10),), force=True)
print(checkpointer.restore(f"gs://{path}", (ocp.RestoreArgs(restore_type=np.ndarray),)))
Disabling those checks (patch attached) seems to resolve the issue, but that is obviously more of a band-aid than a fix.
This looks like an issue that may only surface on GCS: you have not observed it on a local filesystem, and we cannot reproduce it on other distributed filesystems. We will get back with an update when we have more info.
@bjenik is this a consistent issue or just flaky? I'm attempting to repro on a GCS environment but not getting any similar FileNotFoundErrors.
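To help tell consistent from flaky, a small probe along these lines could be run on the affected TPU VM. This is only a sketch: `probe_exists` is a hypothetical helper, not part of Orbax. It assumes `etils.epath` (the path abstraction Orbax builds on) is installed for `gs://` URLs, and falls back to `pathlib` for local paths otherwise. It does not reproduce the metadata check itself. It only repeats a plain existence check.

```python
import pathlib
import tempfile
import time

try:
    # etils.epath is the path abstraction Orbax builds on; it understands gs:// URLs.
    from etils import epath
    Path = epath.Path
except ImportError:
    # Fallback so the sketch still runs, for local paths only.
    Path = pathlib.Path


def probe_exists(path, attempts=3, delay_s=0.0):
    """Hypothetical helper: repeat an existence check and return each result."""
    results = []
    for _ in range(attempts):
        results.append(Path(path).exists())
        time.sleep(delay_s)
    return results


if __name__ == "__main__":
    # Local demonstration; on the TPU VM, pass the gs:// checkpoint directory instead.
    with tempfile.TemporaryDirectory() as tmp:
        print(probe_exists(tmp))
        print(probe_exists(f"{tmp}/missing"))
```

If the results for the checkpoint directory are mixed across attempts, that would point toward flakiness. A uniform list of `False` right after a save would point toward a consistent failure of the existence check.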
For me it happens consistently, every time. I have been using that patch since reporting the issue; I tried again just now without it to see whether a new version had fixed it, but unfortunately it is still happening. The overall setup should be fairly vanilla: a fresh TPU instance like you'd get when creating one in the console, with some basic packages installed:
pip install --upgrade 'jax[tpu]>0.3.0' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && pip install --upgrade tqdm optax orbax-checkpoint equinox google-cloud-storage gcsfs tensorstore==0.1.71
(tensorstore is pinned for an unrelated reason, but the error was happening even before I pinned it.)
Commit 5f6ef63 appears to address this. It now works for me without my patch.