neuralforecast
TypeError: Module.load_state_dict() got an unexpected keyword argument 'assign'
What happened + What you expected to happen
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /Users/leo/web3/LLM/langchain/mlts/nf_iTransformer.py:47 in
Versions / Dependencies
Name: neuralforecast
Version: 1.7.1
Summary: Time series forecasting suite using deep learning models
Home-page: https://github.com/Nixtla/neuralforecast/
Author: Nixtla
Author-email: [email protected]
License: Apache Software License 2.0
Reproduction script
nf = NeuralForecast.load(path='./checkpoints/test_run/')
Issue Severity
None
Hi, do you have a piece of standalone code that I can run to reproduce this error? That would help me debug.
From the limited information, it seems the checkpoint you are loading may be of the wrong datatype, or there may be a version issue with your PyTorch installation (i.e. the checkpoint was saved with a different version than the one neuralforecast is using). But this is a bit of a guess :)
from neuralforecast.auto import AutoTSMixer, AutoTSMixerx
from ray.tune.search.hyperopt import HyperOptSearch
from ray import tune
from neuralforecast.losses.numpy import mse, mae
import matplotlib.pyplot as plt
import pandas as pd

from datasetsforecast.long_horizon import LongHorizon
from neuralforecast.core import NeuralForecast
from neuralforecast.models import TSMixer, TSMixerx, NHITS, MLPMultivariate, iTransformer
from neuralforecast.losses.pytorch import MSE, MAE

# Change this to your own data to try the model
Y_df, X_df, _ = LongHorizon.load(directory='./', group='ETTm2')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])

# X_df contains the exogenous features, which we add to Y_df
X_df['ds'] = pd.to_datetime(X_df['ds'])
Y_df = Y_df.merge(X_df, on=['unique_id', 'ds'], how='left')

# We make validation and test splits
n_time = len(Y_df.ds.unique())
val_size = int(.2 * n_time)
test_size = int(.2 * n_time)
horizon = 96
input_size = 512

models = [
    TSMixerx(h=horizon,
             input_size=input_size,
             n_series=7,
             max_steps=10,
             val_check_steps=10,
             early_stop_patience_steps=5,
             scaler_type='identity',
             dropout=0.7,
             valid_loss=MAE(),
             random_seed=12345678,
             futr_exog_list=['ex_1', 'ex_2', 'ex_3', 'ex_4'],
             ),
]

nf = NeuralForecast(
    models=models,
    freq='15min')

Y_hat_df = nf.cross_validation(df=Y_df,
                               val_size=val_size,
                               test_size=test_size,
                               n_windows=None
                               )
nf.save(path='./checkpoints/test_run/',
        model_index=None,
        overwrite=True,
        save_dataset=True)
nf = NeuralForecast.load(path='./checkpoints/test_run/')

Y_hat_df = Y_hat_df.reset_index()

for model in models:
    mae_model = mae(Y_hat_df['y'], Y_hat_df[f'{model}'])
    mse_model = mse(Y_hat_df['y'], Y_hat_df[f'{model}'])
    print(f'{model} horizon {horizon} - MAE: {mae_model:.3f}')
    print(f'{model} horizon {horizon} - MSE: {mse_model:.3f}')
Thanks - I have zero issues executing that code. So my response is similar to #987, i.e.
Can you give more details about the machine config (OS, Python) you are using? How are you running this script?
If I had to guess, it's a package conflict issue, so I would create a new virtual environment, install neuralforecast in that environment, and try rerunning the script.
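To make the machine config easier to share, something like the snippet below could be run in the same environment as the failing script. It is only a sketch: the helper name `describe_environment` and the default package list are my own choices for illustration, not anything neuralforecast provides. It uses only the standard library, so it works even when the packages in question are missing or broken.

```python
import platform
import sys
from importlib import metadata


def describe_environment(packages=("neuralforecast", "torch", "pytorch-lightning")):
    """Collect the Python version, OS string and installed package versions.

    `packages` holds distribution names as pip knows them; the defaults are
    an assumption about what is relevant to this issue.
    """
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages instead of raising, so the output is
            # still useful in a half-configured environment.
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

Pasting the printed lines into the issue should show whether the torch version in the active environment differs from the one you expect.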
@LeonTing1010 Hi, environment-cpu.yml says pytorch should be >=2.0.0, but in 2.0.0 and 2.0.1 the method at https://github.com/pytorch/pytorch/blame/v2.0.0/torch/nn/modules/module.py#L1969
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
has no keyword argument 'assign', so you should change the requirement from pytorch>=2.0.0 to pytorch>=2.1.0.
This issue has been automatically closed because it has been awaiting a response for too long. When you have time to work with the maintainers to resolve this issue, please post a new comment and it will be re-opened. If the issue has been locked for editing by the time you return to it, please open a new issue and reference this one.
Reopening to remove the argument in versions that don't support it.
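For anyone hitting this before a fix lands, one possible way to pass 'assign' only where it is supported is to inspect the method signature at call time. The snippet below is a sketch of that idea, not the change made in neuralforecast: `accepts_kwarg` and `load_state_dict_compat` are hypothetical helper names, and the two small classes are stand-ins that mimic the older and newer `load_state_dict` signatures so the example runs without torch installed.

```python
import inspect


def accepts_kwarg(func, name):
    """Return True if `func` declares a parameter `name` or accepts **kwargs."""
    params = inspect.signature(func).parameters
    if name in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def load_state_dict_compat(module, state_dict, strict=True, assign=False):
    """Call module.load_state_dict, forwarding `assign` only if it is accepted."""
    kwargs = {"strict": strict}
    if accepts_kwarg(module.load_state_dict, "assign"):
        kwargs["assign"] = assign
    return module.load_state_dict(state_dict, **kwargs)


# Stand-ins for illustration only; a real torch.nn.Module would be used instead.
class OldStyleModule:
    """Mimics the signature quoted above (no 'assign')."""

    def load_state_dict(self, state_dict, strict=True):
        return ("old", strict)


class NewStyleModule:
    """Mimics a signature that does take 'assign'."""

    def load_state_dict(self, state_dict, strict=True, assign=False):
        return ("new", strict, assign)


if __name__ == "__main__":
    print(load_state_dict_compat(OldStyleModule(), {}, assign=True))
    print(load_state_dict_compat(NewStyleModule(), {}, assign=True))
```

Checking the signature rather than comparing version strings avoids parsing torch's version, at the cost of silently ignoring `assign=True` on installations that cannot honor it.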