pytorch-forecasting
Data leakage problem
```
!pip install lightning
!pip install pytorch-forecasting
!pip install torch
!pip install optuna==3.4
!pip install torch==2.0.1
```
```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import MAE, QuantileLoss

max_prediction_length = 6
max_encoder_length = 12
training_cutoff = df["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    df[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="total_sales",
    group_ids=["group_ids"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["group_ids"],
    time_varying_known_reals=[
        "retail_food_sales_vol_bil_ch_usd__Month_MA2",
        "retail_sales_food_services_bil_usd",
        "change_private_inv_minus_retail_trade_bil_ch_usd__Month_MA2",
        "month", "year", "day_of_week", "quarter", "12th_month",
    ],
    time_varying_unknown_reals=["total_sales"],
    predict_mode=False,
    allow_missing_timesteps=True,
)

validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)

batch_size = 10
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # log results to TensorBoard

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cpu",
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches; comment out when running the learning rate finder
    optimizer="Ranger",
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

best_model_path = trainer.checkpoint_callback.best_model_path
print(best_model_path)
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

predictions = best_tft.predict(val_dataloader, return_y=True, trainer_kwargs=dict(accelerator="cpu"))
MAE()(predictions.output, predictions.y)
```
**This code works fine, but when I try to forecast on unseen data as shown below, it gives an error.**
```python
params = training.get_parameters()
inference_ts_dataset = TimeSeriesDataSet.from_parameters(data=filtered_base_data, parameters=params, predict=True)
predictions = model.predict(inference_ts_dataset.to_dataloader(train=False), raw=True)
```
It gives the following error:
```
KeyError                                  Traceback (most recent call last)
~\AppData\Roaming\Python\Python39\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   3801         try:
-> 3802             return self._engine.get_loc(casted_key)
   3803         except KeyError as err:

~\AppData\Roaming\Python\Python39\site-packages\pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

~\AppData\Roaming\Python\Python39\site-packages\pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'total_sales'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_27744\1361187920.py in

~\AppData\Roaming\Python\Python39\site-packages\pytorch_forecasting\data\timeseries.py in from_parameters(cls, parameters, data, stop_randomization, predict, **update_kwargs)
   1198         parameters.update(update_kwargs)
   1199
-> 1200         new = cls(data, **parameters)
   1201         return new
   1202

~\AppData\Roaming\Python\Python39\site-packages\pytorch_forecasting\data\timeseries.py in __init__(self, data, time_idx, target, group_ids, weight, max_encoder_length, min_encoder_length, min_prediction_idx, min_prediction_length, max_prediction_length, static_categoricals, static_reals, time_varying_known_categoricals, time_varying_known_reals, time_varying_unknown_categoricals, time_varying_unknown_reals, variable_groups, constant_fill_strategy, allow_missing_timesteps, lags, add_relative_time_idx, add_target_scales, add_encoder_length, target_normalizer, categorical_encoders, scalers, randomize_length, predict_mode)
    474
    475         # preprocess data
--> 476         data = self._preprocess_data(data)
    477         for target in self.target_names:
    478             assert target not in self.scalers, "Target normalizer is separate and not in scalers."

~\AppData\Roaming\Python\Python39\site-packages\pytorch_forecasting\data\timeseries.py in _preprocess_data(self, data)
    742                 f"__target__{target}" not in data.columns
    743             ), f"__target__{target} is a protected column and must not be present in data"
--> 744             data[f"__target__{target}"] = data[target]
    745         if self.weight is not None:
    746             data["weight"] = data[self.weight]

~\AppData\Roaming\Python\Python39\site-packages\pandas\core\frame.py in __getitem__(self, key)
   3805         if self.columns.nlevels > 1:
   3806             return self._getitem_multilevel(key)
-> 3807         indexer = self.columns.get_loc(key)
   3808         if is_integer(indexer):
   3809             indexer = [indexer]

~\AppData\Roaming\Python\Python39\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   3802             return self._engine.get_loc(casted_key)
   3803         except KeyError as err:
-> 3804             raise KeyError(key) from err
   3805         except TypeError:
   3806             # If we have a listlike key, _check_indexing_error will raise

KeyError: 'total_sales'
```
I'm assuming `filtered_base_data` does not have the `total_sales` column. Since you are using the same parameters as `training` (in which `target` is set to `"total_sales"`) to create `inference_ts_dataset`, it complains because it can't find `total_sales` in `filtered_base_data`.

I don't think this necessarily means there is data leakage in `predict`.
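
One way around this, as a minimal sketch rather than an official recipe, is to give the inference frame the `total_sales` column that the training parameters expect: the encoder (lookback) rows should hold the real historical values, while the rows you want to forecast can hold any placeholder, since their true values are unknown and are not needed to produce predictions. The snippet reuses the names from the code above (`filtered_base_data`, `training`, `best_tft`) and assumes the rest of the frame already matches the training schema; note that raw network output is requested via `mode="raw"` rather than `raw=True`.

```python
# Sketch only: the parameters from `training` expect a "total_sales" column,
# so the inference frame must carry one. For meaningful forecasts, the lookback
# rows should contain the real history; the placeholder is only for the rows
# being predicted, whose true values are unknown at inference time.
if "total_sales" not in filtered_base_data.columns:
    filtered_base_data["total_sales"] = 0.0  # placeholder; fill history with real values where known

params = training.get_parameters()
inference_ts_dataset = TimeSeriesDataSet.from_parameters(
    data=filtered_base_data,
    parameters=params,
    predict=True,
    stop_randomization=True,
)
# mode="raw" returns the full network output (quantiles, attention, etc.)
raw_predictions = best_tft.predict(
    inference_ts_dataset.to_dataloader(train=False),
    mode="raw",
    trainer_kwargs=dict(accelerator="cpu"),
)
```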
@vishnu020 if that solved your problem, consider closing the issue.