[BUG] plot_prediction_actual_by_variable mixes up variables for time_varying_known_categoricals
Describe the bug
The category labels that plot_prediction_actual_by_variable puts on the axis for time_varying_known_categoricals do not match the values plotted at those positions. I made a simple example where a value increases each day throughout the week, then verified that the labeled actuals are not correct.
To Reproduce
I define a simple time series where sales are a noise distribution scaled by the day-of-week number (Monday=0 ... Sunday=6). I then fit a TFT, make a prediction, and evaluate the actual vs. prediction plot.
Actual vs prediction plot:
Actual actuals:
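# Imports
import numpy as np
import pandas as pd
from lightning.pytorch import Trainer
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE
# Generate synthetic dataset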
np.random.seed(42)
num_days = 300 # Number of days of data
hours_per_day = 24
# Create hourly timestamps
timestamps = pd.date_range(
    start="2024-01-01", periods=num_days * hours_per_day, freq="h"
)
# Simulate hourly order counts with Poisson distribution scaled by some time varying function.
hourly_orders = np.random.poisson(
    lam=(
        20 * (timestamps.hour / 23) * (timestamps.day / 31) * (timestamps.dayofweek / 6)
    ),
    size=len(timestamps),
)
# Aggregate to get daily total sales (sum of hourly orders per day)
daily_sales = np.add.reduceat(
    hourly_orders, np.arange(0, len(hourly_orders), hours_per_day)
)
# Expand daily sales to hourly timestamps (each day has the same total sales)
daily_sales_expanded = np.repeat(daily_sales, hours_per_day)
# Create dataframe
df = pd.DataFrame(
    {
        "time_idx": np.arange(len(timestamps)),  # Continuous index
        "datetime": timestamps,
        "group_id": 0,  # Single series, use a fixed group ID
        "orders_per_hour": hourly_orders,  # Input feature
        "total_sales": daily_sales_expanded,
    }
)
# Add some features for the model to learn.
# In this case the sales function is just noise scaled by hour, day of month and day of week,
# so with those features it should fit very well.
df["hour"] = df["datetime"].dt.hour
df["day_of_week"] = df["datetime"].dt.weekday
df["day_of_month"] = df["datetime"].dt.day
df["date"] = df["datetime"].dt.date
df["month_name"] = df["datetime"].dt.month_name().astype(str).astype("category")
df["day_name"] = df["datetime"].dt.day_name().astype(str).astype("category")
# Remove last day's incomplete target (copy so the column assignments below do not hit a view)
df = df[:-hours_per_day].copy()
# Convert numeric columns to float32
df["orders_per_hour"] = df["orders_per_hour"].astype(np.float32)
df["total_sales"] = df["total_sales"].astype(np.float32)
# NOTE: max_encoder_length and max_prediction_length are not defined in the original
# report; the two values below are assumptions chosen only so that the script runs.
max_encoder_length = 7 * hours_per_day  # assumed: one week of history
max_prediction_length = hours_per_day  # assumed: one day ahead
training_cutoff = (
    df["time_idx"].max() - max_prediction_length
)  # Last part reserved for validation
training = TimeSeriesDataSet(
    df[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="total_sales",  # Multi-target: ["orders_per_hour", "total_sales",]
    group_ids=["group_id"],  # Single time series
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[],
    time_varying_known_categoricals=["day_name", "month_name"],
    time_varying_known_reals=[
        "hour",
        "day_of_week",
    ],
    time_varying_unknown_reals=[
        "orders_per_hour",
        "total_sales",
    ],
    target_normalizer=GroupNormalizer(groups=[], transformation="softplus"),
)
validation = TimeSeriesDataSet.from_dataset(
    training, df, predict=True, stop_randomization=True
)
# Convert to PyTorch dataloaders
batch_size = min(32, len(df) // (max_encoder_length + max_prediction_length))
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size, num_workers=0
)
# Define TFT model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=1,
    loss=MAE(),
    log_interval=10,
    reduce_on_plateau_patience=4,
    optimizer="adam",
)
# Train TFT model
trainer = Trainer(max_epochs=1, accelerator="cpu")
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
predictions = tft.predict(
    val_dataloader, return_x=True, trainer_kwargs=dict(accelerator="cpu")
)
predictions_vs_actuals = tft.calculate_prediction_actual_by_variable(
    predictions.x, predictions.output
)
tft.plot_prediction_actual_by_variable(predictions_vs_actuals, name="day_name")
Expected behavior
The labels in the plot should follow the days of the week: Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday
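One way to get a ground truth to compare the plot against is to aggregate the target directly with pandas, independent of the model. The sketch below rebuilds only the synthetic daily sales from the reproduction script (same seed and rate function) and averages them per weekday. The names `daily`, `calendar_order` and `weekday_means` are introduced here for illustration and are not part of the original script.

```python
import numpy as np
import pandas as pd

np.random.seed(42)
num_days, hours_per_day = 300, 24
timestamps = pd.date_range(
    start="2024-01-01", periods=num_days * hours_per_day, freq="h"
)

# Same rate function as in the reproduction script.
lam = 20 * (timestamps.hour / 23) * (timestamps.day / 31) * (timestamps.dayofweek / 6)
hourly_orders = np.random.poisson(lam=np.asarray(lam), size=len(timestamps))

# One row per day: sum the 24 hourly counts of each day.
daily_sales = hourly_orders.reshape(num_days, hours_per_day).sum(axis=1)
daily = pd.DataFrame(
    {"date": timestamps[::hours_per_day], "total_sales": daily_sales}
)
daily["day_name"] = daily["date"].dt.day_name()

# groupby sorts the day names alphabetically, so reindex into calendar order.
calendar_order = [
    "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday",
]
weekday_means = (
    daily.groupby("day_name")["total_sales"].mean().reindex(calendar_order)
)

# Monday has dayofweek == 0, so its Poisson rate is 0 and its mean is exactly 0.
print(weekday_means)
```

If the plot were labeled correctly, the "Monday" position should show an actual of zero, since the rate function is zero for every Monday hour.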
Versions
pytorch_forecasting = "1.3.0"
The issue is that plot_prediction_actual_by_variable assumes self.hparams.embedding_labels is in the same order as the values in averages_actual_cat, but that is not the case.
See: https://github.com/sktime/pytorch-forecasting/blob/1a2d83c7a5e6769c13164eeae7f447002f61f254/pytorch_forecasting/models/base/_base_model.py#L2280
Does anyone know why this should be the case? The order of averages_actual_cat is the result of a groupby_apply over the categories in the input data.
https://github.com/sktime/pytorch-forecasting/blob/1a2d83c7a5e6769c13164eeae7f447002f61f254/pytorch_forecasting/models/base/_base_model.py#L2109
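The general pitfall can be shown without the library. The sketch below is not pytorch-forecasting's code: it assumes a hypothetical encoder that assigns integer codes to day names alphabetically, aggregates values by code, and then compares two ways of attaching labels. Pairing the aggregated values positionally with a label list in a different order gives wrong labels, while decoding each code through the inverse mapping gives the right ones. All names (`encoding`, `by_code`, `wrong`, `right`) are made up for this illustration.

```python
import numpy as np
import pandas as pd

days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
actual = np.arange(7.0)  # toy target: Monday=0.0 ... Sunday=6.0

# Hypothetical encoder (an assumption): codes assigned in alphabetical order.
encoding = {name: code for code, name in enumerate(sorted(days))}
codes = np.array([encoding[d] for d in days])

# Aggregating by code yields values ordered by code, not by calendar day.
by_code = pd.Series(actual).groupby(codes).mean()

# Wrong: zip the aggregated values with labels that are in a different order.
wrong = dict(zip(days, by_code.to_numpy()))

# Right: decode each code back to its label through the inverse mapping.
inverse = {code: name for name, code in encoding.items()}
right = {inverse[int(code)]: value for code, value in by_code.items()}

print("positional labels:", wrong)
print("decoded labels:   ", right)
```

With the positional pairing, "Monday" ends up showing Friday's value, because "Friday" is first alphabetically. Whether the library's mismatch arises in exactly this way is not established here; the sketch only shows why labels should be looked up by code rather than by position.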