[Feature Request] Support Loading Moirai Model from Local Checkpoint for Inference
Is your feature request related to a problem? Please describe.
I am frustrated that there is no standard or convenient way to load a Moirai model from a locally trained .ckpt checkpoint file for inference.
The current official examples and documentation mainly show how to load models from the HuggingFace Hub using from_pretrained, but this fails when providing a local path (it raises a repo_id format error).
Moreover, MoiraiModule does not have a load_from_checkpoint method like PyTorch Lightning, making local model restoration less user-friendly.
Describe the solution you'd like
I would like an official and convenient way to load a Moirai model from a local checkpoint file for inference.
For example, providing a method such as MoiraiForecast.load_from_checkpoint('path/to/last.ckpt') that works seamlessly for local files (similar to PyTorch Lightning or HuggingFace models).
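For context on what such a method would have to do: a PyTorch Lightning `.ckpt` typically nests the weights under a `state_dict` key, with each tensor key prefixed by the attribute name the module was stored under, so a `load_from_checkpoint`-style helper essentially has to unwrap and re-key that mapping. A minimal sketch of the key remapping, using plain dicts in place of real tensors (the `module.` prefix here is an assumption about how the weights were saved, not the actual uni2ts layout):

```python
def extract_module_state_dict(checkpoint: dict, prefix: str = "module.") -> dict:
    """Pull the wrapped module's weights out of a Lightning-style checkpoint.

    Lightning checkpoints nest weights under "state_dict", with keys
    prefixed by the attribute name the module was stored under.
    """
    state_dict = checkpoint["state_dict"]
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

# Toy checkpoint with strings standing in for tensors:
ckpt = {
    "epoch": 10,
    "state_dict": {
        "module.encoder.weight": "w0",
        "module.encoder.bias": "b0",
    },
}
print(extract_module_state_dict(ckpt))
# {'encoder.weight': 'w0', 'encoder.bias': 'b0'}
```

The resulting dict could then be passed to `module.load_state_dict(...)` on a freshly constructed `MoiraiModule`.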
Describe alternatives you've considered
Currently, the code below does not work.
```python
print(f"Loading local Moirai model from: {LOCAL_MODEL_CKPT}")
module = MoiraiModule.from_pretrained(
    pretrained_model_name_or_path=LOCAL_MODEL_CKPT,
    local_files_only=True,
)
model = MoiraiForecast(
    module=module,
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
)
predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)
```
However, this is not very convenient and requires careful tracking of all hyperparameters.
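One way to reduce the manual hyperparameter tracking is to read the constructor arguments back from the Hydra config that was saved with the run, rather than retyping them. A rough sketch of the idea, using a hard-coded dict in place of the parsed YAML and treating nested `_target_` entries (like `distr_output`) as the parts that still need separate instantiation:

```python
def module_kwargs_from_config(module_kwargs: dict) -> dict:
    """Turn the saved Hydra `module_kwargs` section into plain constructor
    kwargs: drop Hydra's `_target_` marker and any nested config dicts
    that must be instantiated separately (e.g. `distr_output`)."""
    return {
        key: value
        for key, value in module_kwargs.items()
        if key != "_target_" and not isinstance(value, dict)
    }

# Stand-in for the parsed YAML from the training run's config:
config = {
    "_target_": "builtins.dict",
    "distr_output": {"_target_": "uni2ts.distribution.MixtureOutput"},
    "d_model": 384,
    "num_layers": 6,
    "max_seq_len": 512,
}
print(module_kwargs_from_config(config))
# {'d_model': 384, 'num_layers': 6, 'max_seq_len': 512}
```

In practice, `hydra.utils.instantiate` on the full `module_kwargs` node resolves the nested `_target_` entries for you; the sketch only illustrates the shape of the data involved.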
Additional context
- A typical error when using `from_pretrained` with a local file:
  ```
  huggingface_hub.errors.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '.../last.ckpt'
  ```
- My config (Hydra YAML) for the model looks like this:
  ```yaml
  _target_: uni2ts.model.moirai.MoiraiPretrain
  module_kwargs:
    _target_: builtins.dict
    distr_output:
      _target_: uni2ts.distribution.MixtureOutput
      components:
        - _target_: uni2ts.distribution.StudentTOutput
        - _target_: uni2ts.distribution.NormalFixedScaleOutput
        - _target_: uni2ts.distribution.NegativeBinomialOutput
        - _target_: uni2ts.distribution.LogNormalOutput
    d_model: 384
    num_layers: 6
    patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]}
    max_seq_len: 512
    attn_dropout_p: 0.0
    dropout_p: 0.0
    scaling: true
  ```
- It would be very helpful for reproducibility and downstream deployment to have a simple, robust local checkpoint loading interface, without the need to exactly reconstruct all training parameters manually.
Thank you for considering this feature!
Here's a snippet of how I did it, my friend. Feel free to adapt it for your own setup.
```python
import torch
import uni2ts.distribution
# Adjust the import path to your uni2ts version if needed
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule

module = MoiraiMoEModule(
    distr_output=uni2ts.distribution.MixtureOutput([
        uni2ts.distribution.StudentTOutput(),
        uni2ts.distribution.NormalFixedScaleOutput(),
        uni2ts.distribution.NegativeBinomialOutput(),
        uni2ts.distribution.LogNormalOutput(),
    ]),
    d_model=384,
    d_ff=512,
    num_layers=6,
    patch_sizes=(8, 16, 32, 64, 128),
    max_seq_len=512,
    attn_dropout_p=0.0,
    dropout_p=0.0,
    scaling=True,
)
model = MoiraiMoEForecast(
    module=module,
    prediction_length=PDT,
    context_length=CTX,
    patch_size=32,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=0,
    past_feat_dynamic_real_dim=0,
)
ckpt_path = "something_pytorch_ckpt.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict)  # apply the loaded weights to the model
model.to(device)
model.eval()
forecast = model(
    past_target=past_target.to(device),
    past_observed_target=past_observed_target.to(device),
    past_is_pad=past_is_pad.to(device),
)
print("forecast.shape:", forecast.shape)
```
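When adapting a snippet like the one above, it is worth verifying that the checkpoint keys actually line up with the freshly constructed module before trusting the forecasts; `torch.nn.Module.load_state_dict(strict=False)` reports this as `(missing_keys, unexpected_keys)`. A tiny stand-in for that check with plain key lists:

```python
def diff_keys(model_keys, ckpt_keys):
    """Report keys present in the model but not the checkpoint, and vice
    versa -- the same information torch's load_state_dict(strict=False)
    returns as (missing_keys, unexpected_keys)."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    return missing, unexpected

missing, unexpected = diff_keys(
    ["encoder.weight", "encoder.bias"],
    ["encoder.weight", "head.weight"],
)
print(missing)     # ['encoder.bias']
print(unexpected)  # ['head.weight']
```

Any non-empty result here usually means the checkpoint was saved with a different prefix or module layout than the model being loaded into.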