
Error in loading checkpoint

Open yanghou2000 opened this issue 9 months ago • 2 comments

Background

Linux x86, timesfm CPU version, jobs submitted via Slurm. I have already made sure the conda environment is activated after the SBATCH directives and before the Python code runs.

Code that ran into the error

tfm = timesfm.TimesFm(
    context_len=480,
    horizon_len=14,
    input_patch_len=32,  # fixed
    output_patch_len=128,  # fixed
    num_layers=20,  # fixed
    model_dims=1280,  # fixed
    backend="cpu",
)

tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path, step=1100000)

Description

I downloaded the model checkpoint from a Hugging Face mirror site and stored it at this path: /repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint. I'm not sure what the right value for checkpoint_path is in tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path).

The questions I want to ask are:

  1. What should my_checkpoint_path be in my case? I tried every choice I could think of and none of them worked; the error messages are shown below.
  2. For those of you with a working setup, what is the size of the checkpoint? I used curl -L to download it from a mirror site to the server and got 777 MB, which is odd because downloading from the same link directly to my local machine (a Mac) gives 814.3 MB.
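One way I can check whether the server-side download was truncated is to compare the size and checksum of the two copies. A minimal sketch (the path below is just my layout; adjust as needed):

```python
import hashlib
import os

def file_summary(path):
    """Return (size_in_bytes, sha256_hexdigest) so two downloads can be compared."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks to avoid loading the whole checkpoint into memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return os.path.getsize(path), h.hexdigest()

# Placeholder path; replace with the actual downloaded file.
ckpt_file = "/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint"
if os.path.exists(ckpt_file):
    size, digest = file_summary(ckpt_file)
    print(f"{size} bytes, sha256={digest}")
```

If the sizes or digests differ between the server and my Mac, the curl download was presumably incomplete.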

Error messages

  1. When I use
my_checkpoint_path = "/repo/timesfm_model/checkpoints"
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

the corresponding error message is:

WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
ERROR:absl:For checkpoint version > 1.0, we require users to provide
          `train_state_unpadded_shape_dtype_struct` during checkpoint
          saving/restoring, to avoid potential silent bugs when loading
          checkpoints to incompatible unpadded shapes of TrainState.
  2. When I use
my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000"
# or my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000/state"
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

the error message is:

WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
Constructing model weights.
Constructed model weights in 2.28 seconds.
Restoring checkpoint from /repo/timesfm_model/checkpoints/checkpoint_1100000.
Traceback (most recent call last):
  File "/./bin/python_script/timesfm_pred.py", line 92, in <module>
    main()
  File "/./bin/python_script/timesfm_pred.py", line 72, in main
    tfm.load_from_checkpoint(checkpoint_path="/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000", step=1100000)
  File "repo/timesfm/src/timesfm.py", line 270, in load_from_checkpoint
    self._train_state = checkpoints.restore_checkpoint(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py", line 246, in restore_checkpoint
    output = checkpoint_manager.restore(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 568, in restore
    restored = self._manager.restore(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1054, in restore
    restore_directory = self._get_read_step_directory(step, directory)
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 811, in _get_read_step_directory
    return self._options.step_name_format.find_step(root_dir, step).path
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 66, in find_step
    raise ValueError(
ValueError: No step path found for step=1100000 with NameFormat=PaxStepNameFormat(checkpoint_type=<CheckpointType.FLAX: 'flax'>, use_digit_step_subdirectory=False) under /ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000
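From the ValueError it looks like Orbax searches for a step subdirectory named checkpoint_<step> under whatever path it is given, so I wrote a quick sketch to dump the actual layout of a directory for comparison (the root below is a placeholder):

```python
import os

def list_tree(root, max_depth=3):
    """Collect paths under `root`, up to `max_depth` levels deep, so the
    actual on-disk layout can be compared with what the loader expects."""
    root = root.rstrip(os.sep)
    base_depth = root.count(os.sep)
    entries = []
    for dirpath, dirnames, filenames in os.walk(root):
        depth = dirpath.count(os.sep) - base_depth
        if depth >= max_depth:
            dirnames[:] = []  # prune: stop descending past max_depth
            continue
        for name in sorted(dirnames) + sorted(filenames):
            entries.append(os.path.join(dirpath, name))
    return entries

# Placeholder root; replace with the actual checkpoints directory.
for path in list_tree("/repo/timesfm_model/checkpoints"):
    print(path)
```

If checkpoint_1100000 shows up directly under the root I pass in, that suggests checkpoint_path should be the parent directory rather than the step directory itself.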
  3. When I use
my_checkpoint_path = "/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint" 
tfm.load_from_checkpoint(checkpoint_path=my_checkpoint_path)

the error message is:

WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
Constructing model weights.
Constructed model weights in 2.42 seconds.
Restoring checkpoint from /repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint.
Traceback (most recent call last):
  File "/./bin/python_script/timesfm_pred.py", line 92, in <module>
    main()
  File "/./bin/python_script/timesfm_pred.py", line 72, in main
    tfm.load_from_checkpoint(checkpoint_path="/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint", step=1100000)
  File "/ssd1/cache/hpc_t0/hy/repo/timesfm/src/timesfm.py", line 270, in load_from_checkpoint
    self._train_state = checkpoints.restore_checkpoint(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py", line 227, in restore_checkpoint
    checkpoint_manager = checkpoint_managers.OrbaxCheckpointManager(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 451, in __init__
    self._manager = _CheckpointManagerImpl(
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 290, in __init__
    step = self.any_step()
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py", line 370, in any_step
    any_step = ocp.utils.any_checkpoint_step(self.directory)
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/utils.py", line 714, in any_checkpoint_step
    for s in checkpoint_dir.iterdir():
  File "/miniconda3/envs/tfm_env/lib/python3.10/site-packages/etils/epath/gpath.py", line 156, in iterdir
    for f in self._backend.listdir(self._path_str):
  File "/envs/tfm_env/lib/python3.10/site-packages/etils/epath/backend.py", line 142, in listdir
    return [p for p in os.listdir(path) if not p.endswith('~')]
NotADirectoryError: [Errno 20] Not a directory: '/ssd1/cache/hpc_t0/hy/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint'
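If I read the NotADirectoryError correctly, the loader tried to list this path as a directory, but my download left it as a regular file. A small helper I use to double-check (the path below is a placeholder):

```python
import os

def classify(path):
    """Report whether a path is a directory, a regular file, or missing."""
    if os.path.isdir(path):
        return "directory"
    if os.path.isfile(path):
        return "file"
    return "missing"

# Placeholder path; replace with the path handed to load_from_checkpoint.
print(classify("/repo/timesfm_model/checkpoints/checkpoint_1100000/state/checkpoint"))
```

In my case this prints "file", which would explain why passing this path fails: the loader seems to expect a checkpoint directory tree, not the innermost file.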

yanghou2000 avatar May 20 '24 16:05 yanghou2000