
add model.get_params and set_params

Open ourownstory opened this issue 4 years ago • 1 comments

similar to scikit model.get_params

for scenario testing or for parameter-based save / load functions
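As a rough sketch of the requested API, here is what a scikit-learn-style `get_params` / `set_params` pair could look like. The class and parameter names below are illustrative only, not the actual NeuralProphet implementation:

```python
# Hypothetical sketch of the requested API; class and parameter names
# are illustrative, not the actual NeuralProphet implementation.
class Model:
    def __init__(self, n_changepoints=10, seasonality_mode="additive"):
        self.n_changepoints = n_changepoints
        self.seasonality_mode = seasonality_mode

    def get_params(self):
        # Mirror scikit-learn: return the user-set hyperparameters as a dict.
        return {
            "n_changepoints": self.n_changepoints,
            "seasonality_mode": self.seasonality_mode,
        }

    def set_params(self, **params):
        # Mirror scikit-learn: update known hyperparameters, reject unknown ones.
        for key, value in params.items():
            if not hasattr(self, key):
                raise ValueError(f"unknown parameter: {key}")
            setattr(self, key, value)
        return self

m = Model()
m.set_params(n_changepoints=25)
print(m.get_params())
# {'n_changepoints': 25, 'seasonality_mode': 'additive'}
```

Such a dict could then be fed back into the constructor, which is what enables the parameter-based save/load scenario mentioned above.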

ourownstory avatar Jan 20 '21 23:01 ourownstory

see discussion #238

ourownstory avatar Jan 21 '21 00:01 ourownstory

Hello @ourownstory, Can you assign this issue to me?

anastasiia-tovstolis avatar Oct 14 '22 16:10 anastasiia-tovstolis

Hi @ourownstory, @anastasiia-tovstolis, with our ongoing migration to PyTorch Lightning (and I think also in plain PyTorch) you can access the model parameters using m.model.state_dict(). This yields the trained parameters such as bias, trend, etc. Further, you can use load_state_dict() to set the parameters of the model. Not sure if this is exactly what you are looking for, but it seems very related.

The output looks roughly as follows

{
    'bias': tensor([0.7926]),
    'trend_k0': tensor([-0.3043]),
    'trend_deltas': tensor([[1.6712, 0.7629, 1.5933, 0.4826, 1.2354, 1.0626, 1.3238, 1.1458, 1.6141, 0.1156, 1.1982]]),
    'ar_net.0.weight': tensor([[1.8130, 1.6902, 0.5619], [0.3293, 0.8004, 0.9695]]),
    'covar_nets.A.0.weight': tensor([[0.7283, 1.5743, 0.9238]])
}
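The save/restore round trip described above can be sketched in plain PyTorch (using a stand-in nn.Linear module, not NeuralProphet itself):

```python
import torch
import torch.nn as nn

# Minimal sketch of the state_dict() / load_state_dict() round trip,
# shown on a stand-in module rather than a NeuralProphet model.
net = nn.Linear(3, 2)

# Snapshot the parameters (clone, so later in-place edits don't alias them).
saved = {name: tensor.clone() for name, tensor in net.state_dict().items()}

# Mutate the weights, then restore them from the snapshot.
with torch.no_grad():
    net.weight.add_(1.0)
net.load_state_dict(saved)

assert torch.equal(net.weight, saved["weight"])
```

With NeuralProphet the same calls would go through m.model, i.e. m.model.state_dict() and m.model.load_state_dict(...).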

karl-richter avatar Oct 16 '22 21:10 karl-richter

Hi @ourownstory, I would like to clarify whether this issue is still relevant, given the comment from @karl-richter?

anastasiia-tovstolis avatar Oct 17 '22 07:10 anastasiia-tovstolis

@anastasiia-tovstolis thanks for your contribution and great question. @Kevin-Chen0 also just brought up the topic in issue #821.

To bring all those discussion points together, my proposal for moving forward would be:

  1. We focus purely on getting parameters (dropping the setter part, as it would introduce another redundant API for configuring the model; the current one is good, or please speak up if I'm missing any good arguments for setters).
  2. Find clear naming for the user-defined parameters (hyperparameters or config/configuration) and for the fitted weights learned by the model (named "state" by PyTorch; possibly "parameters", though that is easily confused with hyperparameters). I would be glad to have your input here.
  3. Implement two functions that expose the user-defined configuration and the fitted weights of our model through a clear API. The second would essentially be a wrapper around the PyTorch function @karl-richter mentioned earlier. Possibly we separate this into two pull requests (one per function).
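A possible shape for the second accessor (step 3): a thin wrapper around PyTorch's state_dict that returns plain Python lists instead of tensors. The function name is hypothetical, not the final NeuralProphet API:

```python
import torch.nn as nn

# Hypothetical sketch of the fitted-weights accessor from step 3.
# The function name and return format are assumptions, not the final API.
def get_fitted_weights(module: nn.Module) -> dict:
    """Return the learned weights as plain Python lists, keyed by name."""
    return {
        name: tensor.detach().cpu().numpy().tolist()
        for name, tensor in module.state_dict().items()
    }

# Example on a stand-in module; with NeuralProphet this would wrap m.model.
weights = get_fitted_weights(nn.Linear(3, 2))
# keys: 'weight' (a 2x3 nested list) and 'bias' (a list of 2 floats)
```

Returning plain lists keeps the output JSON-serializable, which fits the parameter-based save/load use case from the original issue.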

Happy to hear your thoughts @anastasiia-tovstolis @ourownstory @Kevin-Chen0 and @karl-richter, and I hope we can move forward here quickly in the next few days :)

noxan avatar Nov 30 '22 02:11 noxan

Also, so that we do not forget @Kevin-Chen0's contributions on the other issue:

Add m.summary() (from Keras), m.state_dict() (from PyTorch), and/or m.parameters() (also from PyTorch) methods into the NeuralProphet module.

This makes it easier to show and port the NeuralProphet parameters, including for seasonality, lags, changepoints, and others.

Create m.state_dict() to display the NeuralProphet hyperparameters in dictionary format. It could also be called m.hyperparameters().

noxan avatar Nov 30 '22 02:11 noxan

Yesterday I finished these steps; I hope it's well done :) @noxan @ourownstory @Kevin-Chen0 @karl-richter

anastasiia-tovstolis avatar Dec 19 '22 18:12 anastasiia-tovstolis

Closing, thanks to the pull request from @anastasiia-tovstolis.

noxan avatar Feb 24 '23 05:02 noxan