Add `predict` method to pipelines

Open iKintosh opened this issue 2 years ago • 0 comments

🚀 Feature Request

Add a `predict` method to pipelines.

Motivation

Introduce more flexible ways of forecasting to our architecture.

Proposal

This task is blocked by #783.

Add a method to `AbstractPipeline`:

def predict(
    self,
    ts: Optional[TSDataset] = None,
    start_timestamp: Optional[pd.Timestamp] = None,
    end_timestamp: Optional[pd.Timestamp] = None,
    prediction_interval: bool = False,
    quantiles: Sequence[float] = (0.025, 0.975),
) -> TSDataset:
    """
    Make predictions in a given range.

    Parameters
    ----------
    ts:
        dataset with context; if it isn't provided, ``self.ts`` is used
    start_timestamp:
        first timestamp of the prediction range to return; should be >= the first timestamp in ``ts``;
        each segment is expected to begin at or before ``start_timestamp``
    end_timestamp:
        last timestamp of the prediction range to return
    prediction_interval:
        If True, returns the prediction interval for the forecast
    quantiles:
        Levels of the prediction distribution. By default, 2.5% and 97.5% are taken to form a 95% prediction interval

    Returns
    -------
    :
        Dataset with predictions in the ``[start_timestamp, end_timestamp]`` range.
    """

Details:

  1. `ts` doesn't contain all the features; transforms will be applied during `predict`.
  2. If `ts` isn't set, check for the presence of `self.ts` and try to use it.
  3. If segments start at different timestamps, correct behavior is guaranteed only for `start_timestamp` >= the beginning of every segment.
  4. If `start_timestamp` isn't set, use the first timestamp in `ts`.
  5. If `end_timestamp` isn't set, use the last timestamp in `ts`.
  6. Check that `end_timestamp` >= `start_timestamp`.
  7. If the model fails during prediction (e.g. because of NaNs), we should let `predict` fail as well.
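
The defaulting and validation rules in details 2 to 6 can be sketched with standard-library types. This is only an illustration under assumptions: `ToyDataset` and `resolve_range` are hypothetical names that don't exist in ETNA, `datetime` stands in for `pd.Timestamp`, and the hard constraint is read as `start_timestamp` >= the first timestamp in `ts`.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Optional, Tuple


@dataclass
class ToyDataset:
    """Hypothetical stand-in for TSDataset (not an ETNA class)."""

    segment_starts: Dict[str, datetime]  # first timestamp of every segment
    last_timestamp: datetime  # last timestamp shared by all segments


def resolve_range(
    ts: Optional[ToyDataset],
    stored_ts: Optional[ToyDataset],
    start_timestamp: Optional[datetime] = None,
    end_timestamp: Optional[datetime] = None,
) -> Tuple[ToyDataset, datetime, datetime]:
    """Apply details 2-6; stored_ts plays the role of self.ts."""
    # Detail 2: fall back to the stored dataset when ts is not passed.
    if ts is None:
        if stored_ts is None:
            raise ValueError("ts is not set and there is no stored dataset")
        ts = stored_ts

    first_timestamp = min(ts.segment_starts.values())

    # Details 4 and 5: default to the full range covered by ts.
    if start_timestamp is None:
        start_timestamp = first_timestamp
    if end_timestamp is None:
        end_timestamp = ts.last_timestamp

    # Assumed reading of the constraint: the range cannot start before ts does.
    if start_timestamp < first_timestamp:
        raise ValueError("start_timestamp is earlier than the first timestamp in ts")

    # Detail 6: the range must not be reversed.
    if end_timestamp < start_timestamp:
        raise ValueError("end_timestamp is earlier than start_timestamp")

    return ts, start_timestamp, end_timestamp
```

Detail 3 is deliberately not enforced in this sketch: a `start_timestamp` that falls before the beginning of some segment is accepted, it just isn't covered by the guarantee.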

Add an implementation for `BasePipeline`. It should:

  1. Check the constraint on `ts` and `start_timestamp`.
  2. Check the `start_timestamp` and `end_timestamp` values.
  3. If `prediction_interval=True`, check that the pipeline supports prediction intervals and validate `quantiles`.
  4. Call the `_predict` method to get the actual predictions.
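
The four steps above could be laid out roughly as follows. This is a toy sketch, not ETNA code: `SketchPipeline`, the `supports_prediction_interval` flag and `validate_quantiles` are hypothetical names, `ts` is simplified to a sorted sequence of timestamps, and raising `NotImplementedError` from the empty `_predict` is an assumption about what "empty" means.

```python
from typing import Any, Sequence


def validate_quantiles(quantiles: Sequence[float]) -> None:
    """Hypothetical helper: each level must lie strictly between 0 and 1."""
    for quantile in quantiles:
        if not 0 < quantile < 1:
            raise ValueError(f"Quantile should be in (0, 1), got {quantile}")


class SketchPipeline:
    """Toy base class, not ETNA's BasePipeline."""

    # Hypothetical flag; a concrete pipeline would derive it from its model.
    supports_prediction_interval = False

    def predict(
        self,
        ts: Sequence[Any],
        start_timestamp: Any,
        end_timestamp: Any,
        prediction_interval: bool = False,
        quantiles: Sequence[float] = (0.025, 0.975),
    ) -> Any:
        # Step 1: constraint between ts and start_timestamp.
        if start_timestamp < ts[0]:
            raise ValueError("start_timestamp is earlier than the first timestamp in ts")
        # Step 2: the requested range must not be reversed.
        if end_timestamp < start_timestamp:
            raise ValueError("end_timestamp is earlier than start_timestamp")
        # Step 3: interval support and quantile validation, only when requested.
        if prediction_interval:
            if not self.supports_prediction_interval:
                raise NotImplementedError("This pipeline doesn't support prediction intervals")
            validate_quantiles(quantiles)
        # Step 4: delegate the actual work to the subclass hook.
        return self._predict(
            ts=ts,
            start_timestamp=start_timestamp,
            end_timestamp=end_timestamp,
            prediction_interval=prediction_interval,
            quantiles=quantiles,
        )

    def _predict(self, ts, start_timestamp, end_timestamp, prediction_interval, quantiles):
        # Placeholder until the real prediction logic is added.
        raise NotImplementedError("_predict is not implemented yet")
```

Keeping all validation in the public method means a concrete pipeline only has to override `_predict` and can assume its arguments are already consistent.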

Pipelines that support prediction intervals:

  1. Pipeline with models that support prediction intervals.
  2. Other cases can be added in the future.

Add an empty implementation of `BasePipeline._predict`; the actual logic will be added in the future.

Test cases

For all concrete pipelines add tests:

  1. Check that `predict` works correctly if `ts` isn't set and `self.ts` is present.
  2. Check that `predict` fails if `ts` isn't set and `self.ts` isn't present.
  3. Check that `predict` fails if the constraint between `ts` and `start_timestamp` is broken.
  4. Check that `predict` works correctly if `start_timestamp` or `end_timestamp` isn't set.
  5. Check that `predict` fails if `end_timestamp` < `start_timestamp`.
  6. Check that `_predict` is called.
  7. Check that if `prediction_interval=True`, the check for pipeline support of intervals is called.
  8. Check that if `prediction_interval=True`, the function for quantiles validation is called.
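
Some of these cases (1, 2, 6, 7 and 8) could be checked with `unittest.mock`, shown here against a toy pipeline so the example stays self-contained. `ToyPipeline`, `_check_interval_support` and `_validate_quantiles` are hypothetical names chosen for illustration; real tests would target the concrete ETNA pipelines instead.

```python
import unittest
from unittest.mock import patch


class ToyPipeline:
    """Hypothetical minimal pipeline used only to demonstrate the tests."""

    def __init__(self, ts=None):
        self.ts = ts

    def predict(self, ts=None, prediction_interval=False, quantiles=(0.025, 0.975)):
        ts = self.ts if ts is None else ts
        if ts is None:
            raise ValueError("ts is not set and self.ts is not present")
        if prediction_interval:
            self._check_interval_support()
            self._validate_quantiles(quantiles)
        return self._predict(ts=ts, prediction_interval=prediction_interval, quantiles=quantiles)

    def _check_interval_support(self):
        pass

    def _validate_quantiles(self, quantiles):
        pass

    def _predict(self, ts, prediction_interval, quantiles):
        return ts


class PredictContractTest(unittest.TestCase):
    def test_uses_stored_ts(self):
        # Case 1: ts is omitted, self.ts is present.
        self.assertEqual(ToyPipeline(ts=[1, 2, 3]).predict(), [1, 2, 3])

    def test_fails_without_any_ts(self):
        # Case 2: neither ts nor self.ts is available.
        with self.assertRaises(ValueError):
            ToyPipeline().predict()

    def test_private_predict_is_called(self):
        # Case 6: the public method delegates to _predict.
        with patch.object(ToyPipeline, "_predict") as mocked:
            ToyPipeline(ts=[1]).predict()
        mocked.assert_called_once()

    def test_interval_checks_are_called(self):
        # Cases 7 and 8: both checks run when intervals are requested.
        with patch.object(ToyPipeline, "_check_interval_support") as support, \
                patch.object(ToyPipeline, "_validate_quantiles") as validate:
            ToyPipeline(ts=[1]).predict(prediction_interval=True, quantiles=(0.1, 0.9))
        support.assert_called_once()
        validate.assert_called_once_with((0.1, 0.9))
```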

Alternatives

No response

Additional context

No response

Checklist

  • [x] I discussed this issue with ETNA Team

iKintosh, May 31 '22 11:05