darts

Feat/conformal prediction

Open • dennisbader opened this pull request 1 year ago • 2 comments

Checklist before merging this PR:

  • [x] Mentioned all issues that this PR fixes or addresses.
  • [x] Summarized the updates of this PR under Summary.
  • [ ] Added an entry under Unreleased in the Changelog.

Fixes #1704, fixes #2161.

Short Summary

  • Adds the first two Conformal Prediction Models: ConformalNaiveModel and ConformalQRModel (read more below).
  • Adds 3 new quantile interval metrics (plus their aggregated versions):
    • Interval Winkler Score iws(), and Mean Interval Winkler Score miws() (time-aggregated) (source)
    • Interval Coverage ic() (binary: 1 if the observation lies within the quantile interval, else 0), and Mean Interval Coverage mic() (time-aggregated)
    • Interval Non-Conformity Score for Quantile Regression incs_qr(), and Mean ... mincs_qr() (time-aggregated) (source)
  • Adds support for overlap_end=True in ForecastingModel.residuals(). This computes historical forecasts and residuals that can extend beyond the end of the target series. As a result, all returned residuals have the same length per forecast; the last residuals contain missing values if their forecasts extend further into the future than the end of the target series.
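
For reference, the interval metrics can be sketched in plain Python. This is a minimal sketch, not Darts' implementation: it assumes a central interval given by a lower quantile q_low and an upper quantile q_high, uses the standard Winkler score with miscoverage rate alpha = 1 - (q_high - q_low), and the function names winkler_score, interval_coverage and mean_interval_coverage are hypothetical.

```python
def winkler_score(y, lower, upper, q_low, q_high):
    """Winkler score of one observation y for the interval [lower, upper].

    Assumes a central interval, so the miscoverage rate is alpha = 1 - (q_high - q_low).
    Lower is better: the interval width plus a penalty of 2/alpha per unit of miss.
    """
    alpha = 1.0 - (q_high - q_low)
    width = upper - lower
    if y < lower:
        return width + 2.0 / alpha * (lower - y)
    if y > upper:
        return width + 2.0 / alpha * (y - upper)
    return width


def interval_coverage(y, lower, upper):
    """1.0 if the observation lies inside the interval, else 0.0."""
    return 1.0 if lower <= y <= upper else 0.0


def mean_interval_coverage(ys, lowers, uppers):
    """Time-aggregated coverage: the fraction of observations inside their interval."""
    hits = [interval_coverage(y, lo, hi) for y, lo, hi in zip(ys, lowers, uppers)]
    return sum(hits) / len(hits)


# 80% interval (quantiles 0.1 and 0.9) of [4, 6] at three time steps
print(winkler_score(5, 4, 6, 0.1, 0.9))  # inside: only the width counts
print(winkler_score(7, 4, 6, 0.1, 0.9))  # above: width plus 2/alpha times the miss
print(mean_interval_coverage([5, 7, 3], [4, 4, 4], [6, 6, 6]))
```

The Winkler score rewards narrow intervals but penalizes misses in proportion to how far the observation falls outside, while the coverage metric only counts hits and misses.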

Summary

Adds the first conformal prediction models to Darts. Conformal models can be applied to any of Darts' global forecasting models, as long as the model has been fitted beforehand. In general, a conformal model produces one forecast/prediction as follows:

  • Extract a calibration set: The number of calibration examples from the most recent past to use for one conformal prediction can be defined at model creation with parameter cal_length. To make your life simpler, we support two modes:
    • Automatic extraction of the calibration set from the past of your input series (series, past_covariates, ...). This is the default mode, and the predict/forecasting/backtest/... API is identical to that of any other forecasting model.
    • Supply a fixed calibration set with parameters cal_series, cal_past_covariates, ... .
  • Generate historical forecasts on the calibration set (using the forecasting model)
  • Compute the errors/non-conformity scores (specific to each conformal model) on these historical forecasts
  • Compute the quantile values from the errors / non-conformity scores (using the desired quantiles set at model creation with parameter quantiles).
  • Compute the conformal prediction: Add the calibrated intervals to (or adjust the existing intervals of) the forecasting model's predictions.
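
The five steps above can be sketched for a single forecast step with the symmetric naive approach (absolute errors as non-conformity scores). This is a simplified, hypothetical sketch in plain Python rather than Darts' code: the historical forecasts on the calibration set are passed in as plain lists, the function names are made up, and the quantile uses the finite-sample corrected rank ceil((n + 1) * level) common in split conformal prediction; whether Darts uses exactly this correction is an assumption.

```python
import math


def conformal_quantile(scores, level):
    """Finite-sample corrected empirical quantile used in split conformal prediction."""
    n = len(scores)
    # rank ceil((n + 1) * level); the small epsilon guards against float round-off
    k = math.ceil((n + 1) * level - 1e-9)
    if k > n:
        return math.inf  # too few calibration examples for this coverage
    return sorted(scores)[k - 1]


def naive_symmetric_interval(cal_forecasts, cal_actuals, median_forecast,
                             q_low=0.1, q_high=0.9, cal_length=None):
    # 1. keep only the most recent `cal_length` calibration examples
    if cal_length is not None:
        cal_forecasts = cal_forecasts[-cal_length:]
        cal_actuals = cal_actuals[-cal_length:]
    # 2./3. non-conformity scores: absolute errors of the historical forecasts
    scores = [abs(a - f) for f, a in zip(cal_forecasts, cal_actuals)]
    # 4. quantile of the scores at the nominal coverage of the interval
    width = conformal_quantile(scores, q_high - q_low)
    # 5. add the calibrated interval around the model's median forecast
    return median_forecast - width, median_forecast + width


# historical forecasts of 0 against observations 1..9, new median forecast of 100
print(naive_symmetric_interval([0.0] * 9, list(range(1, 10)), 100.0))
print(naive_symmetric_interval([0.0] * 9, list(range(1, 10)), 100.0, cal_length=4))
```

With too short a calibration set the sketch returns an infinite interval, which mirrors why conformal models need a long enough calibration set.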

Notes:

  • When computing historical_forecasts(), backtest(), residuals(), ..., the above is applied to each forecast.
  • For multi-horizon forecasts, the above is applied to each step in the horizon separately.
  • The implementation focuses on efficiency, relying mostly on "vectorized" operations.
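
Calibrating each horizon step separately means that the scores of the h-step-ahead forecasts only calibrate step h, so intervals can widen with the horizon. A minimal, hypothetical sketch in plain Python with naive absolute-error scores; the finite-sample corrected quantile rule is an assumption about Darts' exact behavior, and the function names are made up.

```python
import math


def conformal_quantile(scores, level):
    n = len(scores)
    k = math.ceil((n + 1) * level - 1e-9)  # epsilon guards against float round-off
    return math.inf if k > n else sorted(scores)[k - 1]


def per_step_widths(cal_forecasts, cal_actuals, level):
    """Calibrate each horizon step separately.

    cal_forecasts[i][h] is the (h+1)-step-ahead forecast of calibration example i,
    cal_actuals[i][h] the matching observation.
    """
    horizon = len(cal_forecasts[0])
    widths = []
    for h in range(horizon):
        # non-conformity scores of step h only
        scores = [abs(a[h] - f[h]) for f, a in zip(cal_forecasts, cal_actuals)]
        widths.append(conformal_quantile(scores, level))
    return widths


forecasts = [[0, 0], [0, 0], [0, 0], [0, 0]]
actuals = [[1, 2], [1, 4], [1, 6], [1, 8]]
print(per_step_widths(forecasts, actuals, level=0.5))  # [1, 6]: step 2 gets a wider interval
```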

Input Support

All added conformal models support the following input (depending on the fitted forecasting model):

  • uni/multivariate target series
  • past/future/static covariates
  • single/multiple series

Forecast/Output Support

All models support the following prediction modes:

  • single/multi-horizon forecasts. For multi-horizon, the calibration process is repeated per step in the forecast horizon.
  • single/multiple quantile intervals: any number of quantile intervals, as long as the quantiles are centered around the median (e.g., quantiles=[0.05, 0.2, 0.5, 0.8, 0.95]).
  • historical forecasts with expanding or rolling calibration sets, controlled by parameter cal_length (a rolling calibration set makes the algorithm adaptive).
  • direct quantile predictions using predict_likelihood_parameters=True, num_samples=1 in all prediction methods.
  • sampled predictions from these quantile predictions using num_samples>>1 in all prediction methods.
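
The difference between expanding and rolling calibration sets in historical forecasts can be sketched as follows: each historical forecast at time t is calibrated only with the non-conformity scores of earlier forecasts, either all of them (expanding) or the most recent cal_length (rolling). This is a hypothetical plain-Python sketch, not Darts' implementation; the function name historical_widths and the finite-sample corrected quantile rule are assumptions.

```python
import math


def conformal_quantile(scores, level):
    n = len(scores)
    k = math.ceil((n + 1) * level - 1e-9)  # epsilon guards against float round-off
    return math.inf if k > n else sorted(scores)[k - 1]


def historical_widths(scores, level, cal_length=None, min_cal=1):
    """Interval width for each historical forecast t, using only scores before t."""
    widths = []
    for t in range(len(scores)):
        past = scores[:t]  # no look-ahead: only scores of earlier forecasts
        if cal_length is not None:
            past = past[-cal_length:]  # rolling window of the most recent scores
        if len(past) < min_cal:
            widths.append(None)  # not enough calibration data yet
            continue
        widths.append(conformal_quantile(past, level))
    return widths


# large errors early on, small errors later
scores = [10, 10, 10, 10, 1, 1, 1, 1]
print(historical_widths(scores, 0.5)[-1])                # expanding: still dominated by old errors
print(historical_widths(scores, 0.5, cal_length=3)[-1])  # rolling: adapts to the recent errors
```

The rolling window lets the interval shrink once the forecasting model's recent errors become small, at the cost of estimating the quantile from fewer scores.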

Requirements to use a conformal model:

  • Any pre-trained GlobalForecastingModel (global baselines, all regression models, all torch models)
  • A long enough calibration set, depending on the forecast horizon n. It must be possible to generate at least n + cal_length historical forecasts from the calibration input series.
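
One way to read the n + cal_length requirement, under assumptions not stated in this PR (stride 1, each forecast needs input_length past target values, and historical forecasts may extend past the series end as with overlap_end=True): the last n historical forecasts reach past the end of the series and have incomplete residuals, so only the remaining cal_length forecasts provide complete n-step calibration scores. The helper names below are hypothetical.

```python
def num_historical_forecasts(series_length, input_length):
    """Forecasts can start at indices input_length..series_length (inclusive).

    A forecast starting at index series_length predicts only values after the series end.
    """
    return max(0, series_length - input_length + 1)


def num_complete_forecasts(series_length, input_length, n):
    """Forecasts whose n steps all have observed targets (complete residuals)."""
    return max(0, series_length - input_length - n + 1)


def min_calibration_series_length(input_length, n, cal_length):
    """Shortest series that yields n + cal_length historical forecasts."""
    return input_length + n + cal_length - 1


length = min_calibration_series_length(input_length=12, n=3, cal_length=20)
print(length, num_historical_forecasts(length, 12), num_complete_forecasts(length, 12, 3))  # 34 23 20
```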

Added Algorithms

Added two algorithms, each with two symmetry modes:

  • ConformalNaiveModel: Adds calibrated intervals around the median forecast from the forecasting model.
    • symmetric=True:
      • The lower and upper interval bounds are calibrated by the same magnitude.
      • Non-conformity scores: uses metric ae() (absolute error) to compute the non-conformity scores
    • symmetric=False
      • The lower and upper interval bounds are calibrated separately
      • Non-conformity scores: uses metric err() (error) to compute the non-conformity scores of the upper bounds, and -err() for the lower bounds.
  • ConformalQRModel (Conformalized Quantile Regression, source): Calibrates the quantile predictions from a probabilistic forecasting model.
    • symmetric=True:
      • The lower and upper interval bounds are calibrated by the same magnitude.
      • Non-conformity scores: uses metric incs_qr(symmetric=True) (Quantile Regression Non-Conformity Score) to compute the non-conformity scores
    • symmetric=False
      • The lower and upper interval bounds are calibrated separately
      • Non-conformity scores: uses metric incs_qr(symmetric=False) (Quantile Regression Non-Conformity Score) to compute the non-conformity scores for the upper and lower bound separately.
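
The four variants above can share one calibration routine, sketched here in plain Python for a single interval and forecast step. This is a hypothetical sketch, not Darts' implementation. Assumptions: err() is the actual value minus the forecast; the symmetric score is the standard Conformalized Quantile Regression score max(q_low - y, y - q_high); the asymmetric variants calibrate each bound at level 1 - alpha/2; and the quantile uses the finite-sample corrected rank ceil((n + 1) * level). Passing the median forecast m as both the lower and the upper quantile turns these scores into the naive ones: max(m - y, y - m) is the absolute error ae(), while y - m and m - y are err() and -err().

```python
import math


def conformal_quantile(scores, level):
    n = len(scores)
    k = math.ceil((n + 1) * level - 1e-9)  # epsilon guards against float round-off
    return math.inf if k > n else sorted(scores)[k - 1]


def calibrate_interval(cal_low, cal_high, cal_actuals, low, high, alpha, symmetric=True):
    """Calibrate the interval [low, high] of a new forecast.

    cal_low / cal_high: the model's lower / upper quantile forecasts on the calibration set.
    alpha: miscoverage rate of the interval (e.g. 0.2 for quantiles 0.1 and 0.9).
    """
    # positive when the observation falls below the lower / above the upper bound
    low_scores = [lo - y for lo, y in zip(cal_low, cal_actuals)]
    high_scores = [y - hi for hi, y in zip(cal_high, cal_actuals)]
    if symmetric:
        # one score per example; both bounds move by the same amount
        scores = [max(s_lo, s_hi) for s_lo, s_hi in zip(low_scores, high_scores)]
        q = conformal_quantile(scores, 1 - alpha)
        return low - q, high + q
    # each bound is calibrated separately with half of the miscoverage budget
    q_low = conformal_quantile(low_scores, 1 - alpha / 2)
    q_high = conformal_quantile(high_scores, 1 - alpha / 2)
    return low - q_low, high + q_high


# Naive model: the median forecast serves as both quantiles; it under-predicts by 1..9
median = [0.0] * 9
actuals = [float(v) for v in range(1, 10)]
print(calibrate_interval(median, median, actuals, 100.0, 100.0, alpha=0.2))
print(calibrate_interval(median, median, actuals, 100.0, 100.0, alpha=0.2, symmetric=False))

# CQR: quantile forecasts [0, 10] that missed the last observation by 2
print(calibrate_interval([0.0] * 5, [10.0] * 5, [5.0, 5.0, 5.0, 5.0, 12.0], 0.0, 10.0, alpha=0.2))
```

In the naive example, the symmetric mode widens the interval equally on both sides, while the asymmetric mode shifts it upward because the model consistently under-predicts.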

dennisbader • Oct 03 '24 13:10

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.



Codecov Report

Attention: Patch coverage is 97.60000% with 12 lines in your changes missing coverage. Please review.

Project coverage is 94.20%. Comparing base (412d983) to head (0b2c1cf). Report is 1 commit behind head on master.

Files with missing lines                        Patch %   Lines
darts/models/forecasting/conformal_models.py    98.49%    5 Missing
darts/utils/utils.py                            89.13%    5 Missing
darts/utils/historical_forecasts/utils.py       96.07%    2 Missing
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2552      +/-   ##
==========================================
+ Coverage   94.15%   94.20%   +0.05%     
==========================================
  Files         139      140       +1     
  Lines       14992    15437     +445     
==========================================
+ Hits        14116    14543     +427     
- Misses        876      894      +18     


codecov[bot] • Oct 15 '24 16:10