gluonts
Median prediction is plotted shifted one index to the left
Description
The median prediction is plotted one index too far to the left, while the 90% and 50% prediction intervals are plotted correctly. A possible cause is that the data is resampled to weekly frequency, which might cause a problem with the index.
To Reproduce
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def generate_single_ts(date_range, item_id=None) -> pd.DataFrame:
    """create sum of `n_f` sin/cos curves with random scale and phase."""
    n_f = 2
    period = np.array([24 / (i + 1) for i in range(n_f)]).reshape(1, n_f)
    scale = np.random.normal(1, 0.3, size=(1, n_f))
    phase = 2 * np.pi * np.random.uniform(size=(1, n_f))
    periodic_f = lambda x: scale * np.sin(np.pi * x / period + phase)
    t = np.arange(0, len(date_range)).reshape(-1, 1)
    target = periodic_f(t).sum(axis=1) + np.random.normal(0, 0.1, size=len(t))
    ts = pd.DataFrame({"target": target}, index=date_range)
    if item_id is not None:
        ts["item_id"] = item_id
    return ts
prediction_length, freq = 2, "1D"
T = 365 * prediction_length
date_range = pd.date_range("2021-01-01", periods=T, freq=freq)
ts = generate_single_ts(date_range)
ts = ts.resample('W', label='left').sum()
from gluonts.model.deepar import DeepAREstimator
from gluonts.mx import Trainer
from gluonts.evaluation import make_evaluation_predictions, Evaluator
prediction_length, freq = 4, "W"
estimator = DeepAREstimator(
    freq=freq, prediction_length=prediction_length, trainer=Trainer(epochs=1)
)
from gluonts.dataset.pandas import PandasDataset
train = PandasDataset(ts[:-prediction_length], target="target", freq=freq)
test = PandasDataset(ts, target="target", freq=freq)
predictor = estimator.train(train)
forecast_it, ts_it = make_evaluation_predictions(dataset=test, predictor=predictor)
forecasts = list(forecast_it)
tss = list(ts_it)
forecast_entry = forecasts[0]
ts_entry = tss[0]
fig, ax = plt.subplots(1, 1, figsize=(20, 8))
ts_entry[-prediction_length * 4:].plot(ax=ax)
forecast_entry.plot(prediction_intervals=(50, 90), color="g")
Error message or code output
Environment
- Operating system: Amazon Linux
- Python version: 3.8.12
- GluonTS version: 0.10.2
- MXNet version: 1.9.1
Thanks for the reproducible example!
This seems to be an issue when mixing Period and Timestamp columns. If I change the plotting of the true values, it works:
ts_entry[-prediction_length * 4:].to_timestamp().plot(ax=ax)
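To see what the workaround changes, here is a minimal pandas-only sketch. It assumes, as the comment above suggests, that the true values carry a PeriodIndex; the series and its values are made up for illustration and are not taken from the reproduction script.

```python
import pandas as pd

# Made-up weekly series indexed by periods (illustrative stand-in for ts_entry).
periods = pd.period_range("2021-01-03", periods=4, freq="W")
series = pd.Series([1.0, 2.0, 3.0, 4.0], index=periods)

# to_timestamp() swaps the PeriodIndex for a DatetimeIndex,
# using the start of each period by default.
converted = series.to_timestamp()

print(type(series.index).__name__)     # index type before conversion
print(type(converted.index).__name__)  # index type after conversion
```

After the conversion, both the true values and the forecast can be drawn against timestamp-based x-values, which seems to be what makes the two plots line up.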

I don't know (yet) why this happens -- but at least there is a workaround.
I was thinking about adding some more plot utilities to a) make generating these kinds of plots easier and b) avoid issues like this one.
It works! Thanks for your help.
Looking forward to more generalized plotting tools.
@jaheba as a fix, would it make sense to bake .to_timestamp in any utils (hopefully it’s a single place) we use to turn DataEntry into pandas?
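A rough sketch of what such a helper could look like, purely as an illustration: the name ensure_timestamp_index is hypothetical and is not part of GluonTS, and where it would be called inside the library is left open.

```python
import pandas as pd


def ensure_timestamp_index(obj):
    """Hypothetical helper (not a GluonTS API): return `obj` with a DatetimeIndex.

    Works on a Series or DataFrame; objects that do not have a
    PeriodIndex are returned unchanged.
    """
    if isinstance(obj.index, pd.PeriodIndex):
        return obj.to_timestamp()
    return obj
```

With something like this, the plotting call from the report could be written as ensure_timestamp_index(ts_entry[-prediction_length * 4:]).plot(ax=ax), without the caller needing to know which index type the entry carries.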
What's interesting is that inverting the order of plotting also solves the issue:
forecast_entry.plot(prediction_intervals=(50, 90), color="g")
ts_entry[-prediction_length * 4:].plot(ax=ax)