Median prediction is plotted shifted one index to the left

Open zhichenggeng opened this issue 2 years ago • 4 comments

Description

The median prediction is plotted one index to the left of where it should be, while the 50% and 90% prediction intervals are plotted correctly. A potential cause is that the data is resampled to weekly frequency, which might cause a problem with the index.
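
One way to probe the weekly-resampling hypothesis is to look at how left-labelled weekly bins map onto weekly periods. The following is a minimal pandas-only sketch, not a confirmed diagnosis; the 14-day range and the variable names are chosen here purely for illustration.

```python
import numpy as np
import pandas as pd

# Two weeks of daily data starting on Friday 2021-01-01 (illustrative range).
daily = pd.Series(
    np.ones(14), index=pd.date_range("2021-01-01", periods=14, freq="D")
)

# Same resampling call as in the report: weekly bins ending on Sunday,
# each labelled with the left edge of its bin.
weekly = daily.resample("W", label="left").sum()

# Interpret those labels as weekly periods ("W" means weeks ending on Sunday).
weekly_periods = weekly.index.to_period("W")

first_label = weekly.index[0]
first_period = weekly_periods[0]
print(first_label, first_period.start_time, first_period.end_time)
```

Here the first bin ends on 2021-01-03 but is labelled 2020-12-27, a Sunday, which as a `W` period is the week ending on that day. Whether this contributes to the shifted median line is not established in this thread.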

To Reproduce

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def generate_single_ts(date_range, item_id=None) -> pd.DataFrame:
    """create sum of `n_f` sin/cos curves with random scale and phase."""
    n_f = 2
    period = np.array([24 / (i + 1) for i in range(n_f)]).reshape(1, n_f)
    scale = np.random.normal(1, 0.3, size=(1, n_f))
    phase = 2 * np.pi * np.random.uniform(size=(1, n_f))
    periodic_f = lambda x: scale * np.sin(np.pi * x / period + phase)

    t = np.arange(0, len(date_range)).reshape(-1, 1)
    target = periodic_f(t).sum(axis=1) + np.random.normal(0, 0.1, size=len(t))
    ts = pd.DataFrame({"target": target}, index=date_range)
    if item_id is not None:
        ts["item_id"] = item_id
    return ts

prediction_length, freq = 2, "1D"
T = 365 * prediction_length
date_range = pd.date_range("2021-01-01", periods=T, freq=freq)
ts = generate_single_ts(date_range)

ts = ts.resample('W', label='left').sum()

from gluonts.model.deepar import DeepAREstimator
from gluonts.mx import Trainer
from gluonts.evaluation import make_evaluation_predictions, Evaluator

prediction_length, freq = 4, "W"
estimator = DeepAREstimator(
    freq=freq, prediction_length=prediction_length, trainer=Trainer(epochs=1)
)

from gluonts.dataset.pandas import PandasDataset

train = PandasDataset(ts[:-prediction_length], target="target", freq=freq)
test = PandasDataset(ts, target="target", freq=freq)

predictor = estimator.train(train)
forecast_it, ts_it = make_evaluation_predictions(dataset=test, predictor=predictor)

forecasts = list(forecast_it)
tss = list(ts_it)
forecast_entry = forecasts[0]
ts_entry = tss[0]

fig, ax = plt.subplots(1, 1, figsize=(20, 8))
ts_entry[-prediction_length * 4:].plot(ax=ax)
forecast_entry.plot(prediction_intervals=(50, 90), color="g")

Error message or code output

Environment

  • Operating system: Amazon Linux
  • Python version: 3.8.12
  • GluonTS version: 0.10.2
  • MXNet version: 1.9.1

zhichenggeng, Aug 09 '22 17:08

Thanks for the reproducible example!

This seems to happen when Period-indexed and Timestamp-indexed data are plotted on the same axes. If I convert the true values to timestamps before plotting, it works:

ts_entry[-prediction_length * 4:].to_timestamp().plot(ax=ax)

I don't know (yet) why this happens -- but at least there is a workaround.
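
To make the workaround concrete, here is a small pandas-only sketch of what `.to_timestamp()` does to a Period-indexed series. The values and variable names are made up for illustration, and the sketch does not explain the root cause of the shift.

```python
import pandas as pd

# A small weekly series with a PeriodIndex, similar in kind to `ts_entry`
# (values are invented for this example).
period_index = pd.period_range("2022-01-03", periods=4, freq="W")
observed = pd.Series([1.0, 2.0, 3.0, 4.0], index=period_index)

# to_timestamp() replaces each period with the timestamp at its start
# (the default), producing a DatetimeIndex.
observed_ts = observed.to_timestamp()

print(type(observed.index).__name__)
print(type(observed_ts.index).__name__)
```

After the conversion, the true values carry a Timestamp-based index, so they are no longer a Period-indexed series when drawn next to the forecast.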

I was thinking about adding more plotting utilities to a) make generating these kinds of plots easier and b) avoid issues like this one.

jaheba, Aug 09 '22 18:08

It works! Thanks for your help.

Looking forward to more generalized plotting tools.

zhichenggeng, Aug 09 '22 18:08

@jaheba as a fix, would it make sense to bake .to_timestamp() into whatever utils (hopefully it's a single place) we use to turn a DataEntry into pandas?
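
A rough sketch of what such a utility might look like, assuming an entry is a plain dict with a `pd.Period` under "start" and an array under "target". The function name `entry_to_timestamp_series` is hypothetical and is not an existing GluonTS API.

```python
import numpy as np
import pandas as pd


def entry_to_timestamp_series(entry: dict) -> pd.Series:
    """Hypothetical helper: build a Timestamp-indexed series from an entry."""
    start = entry["start"]  # assumed to be a pd.Period
    target = np.asarray(entry["target"])
    # The frequency is taken from the `start` period.
    index = pd.period_range(start=start, periods=len(target))
    # Convert once, in one place, so callers always get timestamps.
    return pd.Series(target, index=index).to_timestamp()


example = {"start": pd.Period("2021-01-03", freq="W"), "target": [0.5, 1.5, 2.5]}
series = entry_to_timestamp_series(example)
print(series.index)
```

Doing the conversion in a single helper would mean plotting code never has to decide between Period and Timestamp indexes itself.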

lostella, Aug 24 '22 07:08

What's interesting is that inverting the order of plotting also solves the issue:

forecast_entry.plot(prediction_intervals=(50, 90), color="g")
ts_entry[-prediction_length * 4:].plot(ax=ax)

lostella, Aug 24 '22 12:08