
How to save a Pytorch-Forecasting model after training it

Open c3-varun opened this issue 1 year ago • 3 comments

  • PyTorch-Forecasting version: pytorch_forecasting-0.10.3
  • PyTorch version: torch-1.13.1
  • Python version: 3.8.13
  • Operating System: Red Hat Enterprise Linux 8.6 (Ootpa)

Other libraries: PrettyTable-3.6.0 autopage-0.5.1 cliff-4.2.0 cmaes-0.9.1 cmd2-2.4.3 colorlog-6.7.0 optuna-2.10.1 pandas-1.5.3 pbr-5.11.1 pyperclip-1.8.2 pytorch-lightning-1.9.4 scikit-learn-1.1.3 scipy-1.10.1 stevedore-5.0.0

Expected behavior

I executed torch.save(tft, 'Baselining-48-720-1-720.pth') to save my model together with its weights, and I expected the file to be written.

I was able to save just the weights with torch.save(tft.state_dict(), 'Baselining-48-720-1-720.pth'), but that does not save the network itself.
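
For reference, the state_dict route does round-trip, but only if the network is rebuilt first. A minimal sketch, assuming training is the TimeSeriesDataSet used for fitting and model_kwargs is a hypothetical dict holding the same hyperparameters passed to from_dataset below:

import torch
from pytorch_forecasting import TemporalFusionTransformer

# rebuild a network with the same architecture, then load the saved weights into it
# (model_kwargs is a placeholder for the from_dataset arguments shown further down)
tft_restored = TemporalFusionTransformer.from_dataset(training, **model_kwargs)
tft_restored.load_state_dict(torch.load('Baselining-48-720-1-720.pth'))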

Actual behavior

I got the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input> in <module>
----> 1 torch.save(tft, 'Baselining-48-720-1-720.pth')

~/.local/lib/python3.8/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
    421     if _use_new_zipfile_serialization:
    422         with _open_zipfile_writer(f) as opened_zipfile:
--> 423             _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    424             return
    425     else:

~/.local/lib/python3.8/site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
    633     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    634     pickler.persistent_id = persistent_id
--> 635     pickler.dump(obj)
    636     data_value = data_buf.getvalue()
    637     zip_file.write_record('data.pkl', data_value, len(data_value))

AttributeError: Can't pickle local object 'TupleOutputMixIn.to_network_output.<locals>.Output'

Code to reproduce the problem

My model is a TFT and I trained it following this tutorial (https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html):

from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# "training" (the TimeSeriesDataSet) and "quantiles" are defined earlier in the workflow
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=len(quantiles),
    loss=QuantileLoss(quantiles=quantiles),
    log_interval=10,  # log every 10 batches; adjust or remove when running the learning rate finder
    reduce_on_plateau_patience=4,
)

The following line failed:

torch.save(tft, 'Baselining-48-720-1-720.pth')

c3-varun avatar Apr 04 '23 21:04 c3-varun

I am using a Windows machine, and saving the model tft as a pickle file worked for me. I was able to load the model and then run the .predict method. The required imports just need to be present in the environment where the model is loaded.

I saved the model after getting the best model, i.e.:

best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

Once we have the best model, we can save the network and the weights using pickle:

import pickle

with open("tft.pkl",'wb') as f:
    pickle.dump(best_tft)

To load the model:

import pickle

with open("tft.pkl",'rb') as f:
    model=pickle.load(f)
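
A minimal usage sketch after loading, assuming a validation dataloader named val_dataloader (a hypothetical name here) has already been built:

# run inference with the unpickled model
predictions = model.predict(val_dataloader)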

GhoulMac avatar Oct 17 '23 11:10 GhoulMac

I'm currently using pytorch-forecasting 1.0.0 and have the same problem when trying to pickle a model like TemporalFusionTransformer.

The problem seems to be that one of its superclasses, TupleOutputMixIn, defines a local class (Output) inside its to_network_output() method, and pickle cannot serialize locally defined classes.
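
A minimal, self-contained sketch (not pytorch-forecasting code) that reproduces the same kind of error, just to illustrate why pickle refuses locally defined classes:

import pickle

def make_output(value):
    # Output is defined inside the function, so its qualified name contains "<locals>"
    # and pickle cannot locate it by module path
    class Output:
        def __init__(self, value):
            self.value = value
    return Output(value)

pickle.dumps(make_output(1))
# AttributeError: Can't pickle local object 'make_output.<locals>.Output'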

ivanightingale avatar Feb 01 '24 21:02 ivanightingale

Has anybody found a solution for this? I'm currently encountering the same issue when trying to write the output to disk on predict batch end for a Temporal Fusion Transformer.
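
One possible workaround, sketched under the assumption that the offending output behaves like a namedtuple (as the traceback suggests), is to convert it to a plain dict of tensors before writing it to disk, so that no locally defined class has to be pickled:

import torch

# hedged workaround sketch: turn the namedtuple-like network output into a plain dict
# before saving; _asdict() is the standard namedtuple method
def save_network_output(output, path):
    torch.save(dict(output._asdict()), path)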

bhecker avatar Jul 12 '24 17:07 bhecker