Error in loading checkpoint
Background
Linux x86, timesfm CPU version; jobs are submitted with Slurm. I have already made sure that the conda environment is activated in the SBATCH script, after the #SBATCH directives and before the Python code runs.
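As a sanity check inside the job itself, a few lines at the top of the Python script can confirm that the interpreter really comes from the activated env. A minimal sketch; the env name tfm_env is taken from the tracebacks below and may differ in your setup:

# Sanity check (sketch): verify the Slurm job is running the conda env's
# interpreter. The env name "tfm_env" is taken from the tracebacks below.
import os
import sys

print("interpreter:", sys.executable)
print("CONDA_DEFAULT_ENV:", os.environ.get("CONDA_DEFAULT_ENV"))
assert "tfm_env" in sys.prefix, "conda env does not appear to be active"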
Code that ran into error
import timesfm

tfm = timesfm.TimesFm(
    context_len=480,
    horizon_len=14,
    input_patch_len=32,  # fixed
    output_patch_len=128,  # fixed
    num_layers=20,  # fixed
    model_dims=1280,  # fixed
    backend="cpu",
)
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path, step=1100000)
Description
I downloaded the model checkpoint from a Hugging Face mirror site and stored it at this path:
/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint
I'm not sure what the right value of checkpoint_path is for tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path).
The questions I want to ask are:
- What should my_checkpoint_path be in my case? I tried all the plausible choices and none worked; the error messages are shown below.
- For those of you with a working setup, what is the size of the checkpoint? I used curl -L to download it from a mirror site to the server, and the size is 777 MB, which is odd: downloading from the same link directly to my local machine (a Mac) gives 814.3 MB.
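Since the two downloads disagree in size, the server copy may simply be truncated. One way to rule that out, assuming the checkpoint is also published on the Hugging Face Hub (the repo id below is my guess, not something confirmed in this post), is to fetch it with huggingface_hub instead of curl -L:

# Sketch: re-download via huggingface_hub and compare the result against
# the 814.3 MB copy on the Mac. The repo id is an assumption; substitute
# whatever repo your mirror serves.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="google/timesfm-1.0-200m",  # assumed repo id
    local_dir="/repo/timesfm_model",
)
print("downloaded to:", local_dir)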
Error messages
- When I use
my_checkpoint_path = "/repo/timesfm_model/checkpoints"
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)
the corresponding error message is:
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
ERROR:absl:For checkpoint version > 1.0, we require users to provide
`train_state_unpadded_shape_dtype_struct` during checkpoint
saving/restoring, to avoid potential silent bugs when loading
checkpoints to incompatible unpadded shapes of TrainState.
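Note that this attempt gets past path resolution and fails only while reading the checkpoint contents, which would also be consistent with a truncated 777 MB download. A cheap first check is to dump the on-disk tree with per-file sizes and diff it against the known-good copy; a minimal sketch:

# Sketch: list the checkpoint tree with per-file sizes so the server copy
# can be compared against the known-good download on the Mac.
import os

root_dir = "/repo/timesfm_model/checkpoints"

total = 0
for root, _, files in os.walk(root_dir):
    for name in sorted(files):
        path = os.path.join(root, name)
        size = os.path.getsize(path)
        total += size
        print(f"{size:>14,}  {os.path.relpath(path, root_dir)}")
print(f"total: {total / 2**20:.1f} MiB")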
- When I use
my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000"
# or my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000/state"
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)
the error message is:
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
Constructing model weights.
Constructed model weights in 2.28 seconds.
Restoring checkpoint from /repo/timesfm_model/checkpoints/checkpoint_1100000.
Traceback (most recent call last):
  File "/./bin/python_script/timesfm_pred.py", line 92, in <module>
    main()
  File "/./bin/python_script/timesfm_pred.py", line 72, in main
    tfm.load_from_checkpoint(checkpoint_path="/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000", step=1100000)
  File "repo/timesfm/src/timesfm.py", line 270, in load_from_checkpoint
    self._train_state = checkpoints.restore_checkpoint(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py", line 246, in restore_checkpoint
    output = checkpoint_manager.restore(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 568, in restore
    restored = self._manager.restore(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1054, in restore
    restore_directory = self._get_read_step_directory(step, directory)
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 811, in _get_read_step_directory
    return self._options.step_name_format.find_step(root_dir, step).path
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 66, in find_step
    raise ValueError(
ValueError: No step path found for step=1100000 with NameFormat=PaxStepNameFormat(checkpoint_type=<CheckpointType.FLAX: 'flax'>, use_digit_step_subdirectory=False) under /ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000
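Reading this traceback, find_step scans the directory passed as checkpoint_path for a step entry named checkpoint_1100000, so passing .../checkpoint_1100000 itself makes it look for a nested .../checkpoint_1100000/checkpoint_1100000, which does not exist. A plain-os probe (assuming nothing about orbax beyond what the traceback shows) makes this visible:

# Sketch: check which candidate checkpoint_path actually contains a
# "checkpoint_1100000" step entry, which is what the PaxStepNameFormat
# lookup in the traceback is searching for.
import os

candidates = [
    "/repo/timesfm_model/checkpoints",
    "/repo/timesfm_model/checkpoints/checkpoint_1100000",
    "/repo/timesfm_model/checkpoints/checkpoint_1100000/state",
]

for cand in candidates:
    entry = os.path.join(cand, "checkpoint_1100000")
    print(f"{os.path.exists(entry)!s:>5}  {entry}")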
- When I use
my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint"
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)
the error message is:
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
Constructing model weights.
Constructed model weights in 2.42 seconds.
Restoring checkpoint from /repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint.
Traceback (most recent call last):
  File "/./bin/python_script/timesfm_pred.py", line 92, in <module>
    main()
  File "/./bin/python_script/timesfm_pred.py", line 72, in main
    tfm.load_from_checkpoint(checkpoint_path="/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint", step=1100000)
  File "/ssd1/cache/hpc_t0/hy/repo/timesfm/src/timesfm.py", line 270, in load_from_checkpoint
    self._train_state = checkpoints.restore_checkpoint(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py", line 227, in restore_checkpoint
    checkpoint_manager = checkpoint_managers.OrbaxCheckpointManager(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 451, in __init__
    self._manager = _CheckpointManagerImpl(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 290, in __init__
    step = self.any_step()
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 370, in any_step
    any_step = ocp.utils.any_checkpoint_step(self.directory)
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/utils.py", line 714, in any_checkpoint_step
    for s in checkpoint_dir.iterdir():
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/etils/epath/gpath.py", line 156, in iterdir
    for f in self._backend.listdir(self._path_str):
  File "/envs/tfm_env/lib/python3.10/site-packages/etils/epath/backend.py", line 142, in listdir
    return [p for p in os.listdir(path) if not p.endswith('~')]
NotADirectoryError: [Errno 20] Not a directory: '/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint'
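This last failure is mechanical rather than semantic: os.listdir is being called on .../state/checkpoint, which is apparently a regular file (the FLAX checkpoint payload itself), not a directory. Classifying each tried path separates the three failure modes at a glance; a small sketch:

# Sketch: classify each tried path as dir / file / missing. The
# NotADirectoryError above implies .../state/checkpoint is a regular file,
# so it cannot serve as a checkpoint *directory* argument.
import os

for p in [
    "/repo/timesfm_model/checkpoints",
    "/repo/timesfm_model/checkpoints/checkpoint_1100000",
    "/repo/timesfm_model/checkpoints/checkpoint_1100000/state",
    "/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint",
]:
    kind = "dir" if os.path.isdir(p) else "file" if os.path.isfile(p) else "missing"
    print(f"{kind:>7}  {p}")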