[FIX] Unify API
This is a large refactoring PR that is open for discussion. Its main goal is to unify the API across different model types and to unify the loss functions across different loss types.
Refactoring:
- Fuses `BaseWindows`, `BaseMultivariate` and `BaseRecurrent` into `BaseModel`, removing the need for separate classes and unifying the model API across different model types. Instead, this PR introduces two model attributes, `RECURRENT` (`True`/`False`) and `MULTIVARIATE` (`True`/`False`), yielding four possible model options. We currently have a model for every combination except a recurrent multivariate model (e.g. a multivariate LSTM), but adding one is now relatively simple (see the sketch after this list). In addition, this change allows models to be recurrent or not, and multivariate or not, on the fly, based on user input, which also makes future modelling work easier.
- Unifies the model API across all models, adding missing input variables to all model types.
- Refactors losses, among other things removing unnecessary `domain_map` functions.
- Moves `loss.domain_map` out of the models and into `BaseModel`.
- Moves `RevINMultivariate`, used by `TSMixer`, `TSMixerx` and `RMoK`, to `common.modules`.
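To make the new dispatch concrete, here is a minimal, hypothetical sketch of the flag-based design. Only the `RECURRENT` and `MULTIVARIATE` attribute names come from this PR; the class layout and method names below are illustrative assumptions, not the actual implementation.

```python
# Illustrative sketch of the flag-based dispatch; names other than
# RECURRENT and MULTIVARIATE are assumptions, not the PR's code.
class BaseModel:
    RECURRENT = False      # True: unroll predictions step by step over time
    MULTIVARIATE = False   # True: model all series jointly

    def training_step(self, batch):
        # Dispatch on the two flags instead of relying on separate
        # BaseWindows / BaseMultivariate / BaseRecurrent parent classes.
        if self.RECURRENT:
            return self._recurrent_step(batch)
        return self._windows_step(batch)

    def _recurrent_step(self, batch):
        raise NotImplementedError

    def _windows_step(self, batch):
        raise NotImplementedError


class MultivariateLSTM(BaseModel):
    # The previously missing combination (recurrent + multivariate) becomes
    # a matter of setting both flags on a single subclass.
    RECURRENT = True
    MULTIVARIATE = True
```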
Features:
- All losses are compatible with all types of models (e.g. univariate/multivariate, direct/recurrent), or appropriate protections have been added.
- `DistributionLoss` now supports the use of `quantiles` in `predict`, allowing easy quantile retrieval for all `DistributionLoss`es (usage sketch after this list).
- Mixture losses (`GMM`, `PMM` and `NBMM`) now support learned weights for weighted mixture distribution outputs.
- Mixture losses now support the use of `quantiles` in `predict`, allowing easy quantile retrieval.
- Improved stability of `ISQF` by adding softplus protection around some parameters instead of using `.abs`.
- Unified API for any quantile or any confidence level during `predict`, for both point and distribution losses.
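Below is a hedged usage sketch of quantile retrieval at predict time. The `NeuralForecast`, `NHITS`, `DistributionLoss` and `AirPassengersDF` entry points already exist in the library; passing `quantiles` to `predict` is the new behaviour described above, so the exact argument and output column names may differ from the final implementation.

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersDF

# Fit a single model with a distributional output head.
model = NHITS(h=12, input_size=24,
              loss=DistributionLoss(distribution='Normal'),
              max_steps=100)
nf = NeuralForecast(models=[model], freq='M')
nf.fit(df=AirPassengersDF)

# New in this PR: request arbitrary quantiles at predict time instead of
# fixing confidence levels when the loss is constructed.
forecasts = nf.predict(quantiles=[0.1, 0.5, 0.9])
print(forecasts.head())
```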
Bug fixes:
- `MASE` loss now works.
- Added various protections around invalid parameter combinations (e.g. regarding losses).
- `StudentT`: increased the default degrees of freedom to 3 to reduce unbounded-variance issues.
- All models are now tested with a test function on the AirPassengers dataset; most models previously had `eval: false` on their examples while having no other tests, so most models were effectively not tested at all.
- `IQLoss` did not produce monotonic quantiles; now it does, by quantiling the quantiles (illustrated after this list).
- When training with both a conformal and a non-conformal method, the latter was also cross-validated to compute conformity scores. This redundant training step is removed in this PR.
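As a toy illustration of the "quantiling the quantiles" fix (this is not the PR's code): crossing quantile predictions become monotone by taking empirical quantiles of the predicted values themselves.

```python
import torch

# Predicted quantiles for levels 0.1, 0.5, 0.9 that cross (q0.1 > q0.5).
quantile_levels = torch.tensor([0.1, 0.5, 0.9])
raw_preds = torch.tensor([12.0, 10.0, 15.0])

# Quantiling the quantiles: empirical quantiles of the predictions are
# monotone by construction.
monotone_preds = torch.quantile(raw_preds, quantile_levels)
print(monotone_preds)  # tensor([10.4000, 12.0000, 14.4000])
```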
Breaking changes:
- All recurrent models were rewritten to remove the quadratic (in the sequence dimension) space complexity. As a result, recurrent models saved with a previous version cannot be loaded into this version.
- Recurrent models now require an `input_size` to be given (example after this list).
- `TCN` and `DRNN` are now windows models, not recurrent models.
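A short usage example of the `input_size` requirement; the model and dataset entry points already exist in the library, and the hyperparameter values here are arbitrary.

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import LSTM
from neuralforecast.utils import AirPassengersDF

# input_size is now a required argument for recurrent models such as LSTM.
model = LSTM(h=12, input_size=24, max_steps=100)
nf = NeuralForecast(models=[model], freq='M')
nf.fit(df=AirPassengersDF)
preds = nf.predict()
```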
Tests:
- Added `common._model_checks.py`, which includes a model testing function. This function runs on every model, ensuring that every model is tested on push (hedged sketch below).
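For reference, a hedged sketch of how the per-model check might be invoked. The module path comes from this PR, but the function name `check_model` and its signature are assumptions.

```python
from neuralforecast.common._model_checks import check_model  # function name is an assumption
from neuralforecast.models import NHITS

# Runs the standard fit/predict checks on AirPassengers for a single model.
check_model(NHITS)
```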
Todo:
- [x] Test models on speed/scaling compared to the current implementation across a set of datasets.
- [x] Make sure the docstrings of all multivariate models are updated to reflect the additional inputs.