neural_prophet
add model.get_params and set_params
similar to scikit model.get_params
for scenario testing or for parameter-based save / load functions
see discussion #238
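For reference, the scikit-learn convention being requested looks roughly like the following. This is a minimal pure-Python sketch; NeuralProphet itself does not ship these methods, and the class name and parameters here are illustrative only:

```python
class EstimatorLike:
    """Schematic estimator following the scikit-learn get_params/set_params convention."""

    def __init__(self, growth="linear", n_changepoints=10):
        self.growth = growth
        self.n_changepoints = n_changepoints

    def get_params(self):
        # Return the user-set (hyper)parameters as a plain dict.
        return {"growth": self.growth, "n_changepoints": self.n_changepoints}

    def set_params(self, **params):
        # Update only known parameters; reject unknown keys, as scikit-learn does.
        for name, value in params.items():
            if not hasattr(self, name):
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self


m = EstimatorLike()
m.set_params(n_changepoints=25)
print(m.get_params())  # {'growth': 'linear', 'n_changepoints': 25}
```

Returning `self` from `set_params` allows call chaining, which is handy for the scenario testing mentioned above.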
Hello @ourownstory, can you assign this issue to me?
Hi @ourownstory, @anastasiia-tovstolis, with our ongoing migration to PyTorch Lightning (and I think also in plain PyTorch) you can access the model parameters using m.model.state_dict(). This yields the trained parameters such as bias, trend, etc. Further, you can use load_state_dict() to set the parameters of the model. Not sure if this is exactly what you are looking for, but it seems very related.
The output looks roughly as follows
{
    'bias': tensor([0.7926]),
    'trend_k0': tensor([-0.3043]),
    'trend_deltas': tensor([[1.6712, 0.7629, 1.5933, 0.4826, 1.2354, 1.0626, 1.3238, 1.1458, 1.6141, 0.1156, 1.1982]]),
    'ar_net.0.weight': tensor([[1.8130, 1.6902, 0.5619], [0.3293, 0.8004, 0.9695]]),
    'covar_nets.A.0.weight': tensor([[0.7283, 1.5743, 0.9238]])
}
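The round trip described above can be sketched without NeuralProphet or PyTorch installed. The stand-in class below mimics the state_dict()/load_state_dict() semantics of a torch.nn.Module (tensors replaced by plain lists, parameter names borrowed from the dump above purely for illustration):

```python
from collections import OrderedDict


class TinyModule:
    """Dependency-free stand-in for a torch.nn.Module's parameter access."""

    def __init__(self):
        # Fitted weights, keyed like the state_dict output shown above.
        self._params = OrderedDict(bias=[0.7926], trend_k0=[-0.3043])

    def state_dict(self):
        # Return a copy so callers can edit it safely before loading it back.
        return OrderedDict((k, list(v)) for k, v in self._params.items())

    def load_state_dict(self, state):
        # Overwrite the stored parameters with the supplied mapping.
        for key, value in state.items():
            if key not in self._params:
                raise KeyError(f"Unexpected key {key!r}")
            self._params[key] = list(value)


m = TinyModule()
state = m.state_dict()
state["bias"] = [0.5]      # tweak a parameter, e.g. for scenario testing
m.load_state_dict(state)   # write it back, as load_state_dict() would
print(m.state_dict()["bias"])  # [0.5]
```

With the real model, the same pattern would be `state = m.model.state_dict()` followed by `m.model.load_state_dict(state)`.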
Hi @ourownstory, I would like to clarify whether this issue is still relevant in light of @karl-richter's comment.
@anastasiia-tovstolis thanks for your contribution and great question. @Kevin-Chen0 also just brought up the topic in issue #821.
To bring all those discussion points together, my proposal for moving forward would be:
- We focus purely on getting parameters (remove the setter part, as this would introduce another redundant API for configuring the model; the current one is good - or please speak up if I'm missing any good arguments for setters)
- Find a clear naming for user-defined parameters (hyperparameters, or config/configuration) and for the fitted weights learned by the model (named state by pytorch, possibly parameters, yet that is confusing with hyperparameters) - would be glad to have your input here
- Implement two functions which expose the user-defined configuration and the fitted weights of our model to the user with a clear API - the second one would be kind of a wrapper for the pytorch function @karl-richter mentioned earlier - possibly we separate this into two pull requests (one for each function)
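The two proposed accessor functions could look roughly as follows. This is a hypothetical sketch under the naming discussed above; the class and method names are placeholders, not an actual NeuralProphet API:

```python
class ForecasterSketch:
    """Placeholder forecaster exposing the two proposed accessors."""

    def __init__(self, config, model_state):
        self.config = config        # user-defined configuration (hyperparameters)
        self._state = model_state   # fitted weights ("state" in pytorch terms)

    def get_configuration(self):
        # First proposed function: expose the user-defined configuration.
        return dict(self.config)

    def get_fitted_weights(self):
        # Second proposed function: a thin wrapper over what would be
        # m.model.state_dict() in the real model.
        return dict(self._state)


m = ForecasterSketch({"n_changepoints": 10}, {"bias": [0.7926]})
print(m.get_configuration())   # {'n_changepoints': 10}
print(m.get_fitted_weights())  # {'bias': [0.7926]}
```

Keeping the two concerns in separate methods (rather than one combined dict) matches the naming split between configuration and fitted weights proposed above.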
Happy to hear your thoughts @anastasiia-tovstolis @ourownstory @Kevin-Chen0 @karl-richter - hope we can move forward here quickly in the next few days :)
Also, so that we do not forget @Kevin-Chen0's contributions on the other issue:

Add m.summary() (from Keras), m.state_dict() (from PyTorch), and/or m.parameters() (also from PyTorch) methods to the NeuralProphet module. This makes it easier to show and port the NeuralProphet parameters, including for seasonality, lags, changepoints, and others.
Create m.state_dict() to display the NeuralProphet hyperparameters in dictionary format. It could also be called m.hyperparameters().
Yesterday I finished these steps; I hope it's well done :) @noxan @ourownstory @Kevin-Chen0 @karl-richter
Closing, thanks to the pull request by @anastasiia-tovstolis.