
[Feature Request] Support Loading Moirai Model from Local Checkpoint for Inference

Open · wenzhaojie opened this issue 5 months ago · 1 comment

Is your feature request related to a problem? Please describe. I am frustrated that there is no standard or convenient way to load a Moirai model from a locally trained .ckpt checkpoint file for inference. The current official examples and documentation mainly show how to load models from the HuggingFace Hub using from_pretrained, but this fails when providing a local path (it raises a repo_id format error). Moreover, MoiraiModule does not have a load_from_checkpoint method like PyTorch Lightning, making local model restoration less user-friendly.

Describe the solution you'd like I would like an official and convenient way to load a Moirai model from a local checkpoint file for inference. For example, providing a method such as MoiraiForecast.load_from_checkpoint('path/to/last.ckpt') that works seamlessly for local files (similar to PyTorch Lightning or HuggingFace models).

Describe alternatives you've considered Currently, the code below does not work.

print(f"Loading local Moirai model from: {LOCAL_MODEL_CKPT}")
module = MoiraiModule.from_pretrained(
    pretrained_model_name_or_path=LOCAL_MODEL_CKPT,
    local_files_only=True,
)
model = MoiraiForecast(
    module=module,
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
)

predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)

However, this approach is not very convenient and requires careful tracking of all hyperparameters; a sketch of the manual reconstruction I have in mind follows.
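A minimal sketch of that manual reconstruction, assuming the checkpoint comes from a MoiraiPretrain run with the config shown further below, so the module weights sit under checkpoint["state_dict"] with a "module." prefix (adjust if your keys differ):

import torch

from uni2ts.distribution import (
    LogNormalOutput,
    MixtureOutput,
    NegativeBinomialOutput,
    NormalFixedScaleOutput,
    StudentTOutput,
)
from uni2ts.model.moirai import MoiraiModule

# Re-create the module with exactly the hyperparameters used for training
# (copied by hand from the training config).
module = MoiraiModule(
    distr_output=MixtureOutput(
        components=[
            StudentTOutput(),
            NormalFixedScaleOutput(),
            NegativeBinomialOutput(),
            LogNormalOutput(),
        ]
    ),
    d_model=384,
    num_layers=6,
    patch_sizes=(8, 16, 32, 64, 128),
    max_seq_len=512,
    attn_dropout_p=0.0,
    dropout_p=0.0,
    scaling=True,
)

# Extract the module weights from the Lightning checkpoint. The "module." prefix
# is an assumption about how the training LightningModule names its submodule.
ckpt = torch.load(LOCAL_MODEL_CKPT, map_location="cpu")
module.load_state_dict(
    {k.removeprefix("module."): v for k, v in ckpt["state_dict"].items() if k.startswith("module.")}
)

The resulting module can then be wrapped in MoiraiForecast exactly as in the snippet above.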

Additional context

  • A typical error when using from_pretrained with a local file:

    huggingface_hub.errors.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '.../last.ckpt'
    
  • My config (Hydra YAML) for the model looks like this:

    _target_: uni2ts.model.moirai.MoiraiPretrain
    module_kwargs:
      _target_: builtins.dict
      distr_output:
        _target_: uni2ts.distribution.MixtureOutput
        components:
          - _target_: uni2ts.distribution.StudentTOutput
          - _target_: uni2ts.distribution.NormalFixedScaleOutput
          - _target_: uni2ts.distribution.NegativeBinomialOutput
          - _target_: uni2ts.distribution.LogNormalOutput
      d_model: 384
      num_layers: 6
      patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]}
      max_seq_len: 512
      attn_dropout_p: 0.0
      dropout_p: 0.0
      scaling: true
    
  • It would be very helpful for reproducibility and downstream deployment to have a simple, robust local checkpoint loading interface, without the need to exactly reconstruct all training parameters manually; a rough sketch of how the existing config could be reused follows this list.
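For illustration, a sketch of what reusing the training config could look like today, assuming the YAML above is readable from disk (the path here is hypothetical) and that the custom as_tuple resolver it references simply turns a list into a tuple:

from hydra.utils import instantiate
from omegaconf import OmegaConf

from uni2ts.model.moirai import MoiraiModule

# Stand-in for the custom resolver the config references; I assume the training
# entry point registers the real one, and that it just converts a list to a tuple.
OmegaConf.register_new_resolver("as_tuple", lambda lst: tuple(lst), replace=True)

cfg = OmegaConf.load("conf/pretrain/model/moirai_small.yaml")  # hypothetical path
module_kwargs = instantiate(cfg.module_kwargs, _convert_="all")
module_kwargs["patch_sizes"] = tuple(module_kwargs["patch_sizes"])  # ensure a tuple
module = MoiraiModule(**module_kwargs)
# ...then load the checkpoint weights into `module` as in the sketch above.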

Thank you for considering this feature!

wenzhaojie · Jul 28 '25

Here's a snippet of how I did it, my friend. Feel free to adapt it for your own setup.


import torch

import uni2ts.distribution
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule  # adjust import path to your uni2ts version

# Rebuild the module with the hyperparameters it was trained with.
module = MoiraiMoEModule(
    distr_output=uni2ts.distribution.MixtureOutput(
        components=[
            uni2ts.distribution.StudentTOutput(),
            uni2ts.distribution.NormalFixedScaleOutput(),
            uni2ts.distribution.NegativeBinomialOutput(),
            uni2ts.distribution.LogNormalOutput(),
        ]
    ),
    d_model=384,
    d_ff=512,
    num_layers=6,
    patch_sizes=(8, 16, 32, 64, 128),
    max_seq_len=512,
    attn_dropout_p=0.0,
    dropout_p=0.0,
    scaling=True,
)

model = MoiraiMoEForecast(
    module=module,
    prediction_length=PDT,
    context_length=CTX,
    patch_size=32,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=0,
    past_feat_dynamic_real_dim=0,
)

ckpt_path="something_pytorch_ckpt.pt" 
state_dict = torch.load(ckpt_path)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# past_target, past_observed_target, and past_is_pad are input tensors prepared elsewhere.
forecast = model(
    past_target=past_target.to(device),
    past_observed_target=past_observed_target.to(device),
    past_is_pad=past_is_pad.to(device),
)
print('forecast.shape: ', forecast.shape)

kasrayazdani · Jul 28 '25