TypeError when predicting with TFT
First-time user here. I am following this example: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/stallion.html and my code is this:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
...
# Hold out 40% of the distinct `ds` values for validation.
# NOTE(review): the cutoff is a count of unique timestamps but is applied below
# as a positional row slice; the two only agree if every timestamp has exactly
# one row -- confirm against the data.
validation_len = int(len(Y_df.ds.unique()) * 0.4)
training_cutoff = int(len(Y_df.ds.unique())) - validation_len
print('Len of training data is : ',len(Y_df[:training_cutoff]))
print('Len of val data is : ',len(Y_df[training_cutoff:]))

# Training windows: exactly 5 encoder steps followed by 1 prediction step,
# with the target "y" as the only (time-varying, unknown) real input.
train_dataset = TimeSeriesDataSet(
    data=Y_df[:training_cutoff],
    group_ids=["unique_id"],
    target="y",
    time_idx="time_idx",
    min_encoder_length=5,
    max_encoder_length=5,
    min_prediction_length=1,
    max_prediction_length=1,
    time_varying_unknown_reals=["y"],
)
# create validation set (predict=True) which means to predict the last max_prediction_length points in time
val_dataset = TimeSeriesDataSet.from_dataset(train_dataset, Y_df, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 4
train_dataloader = train_dataset.to_dataloader(train=True, batch_size=batch_size, num_workers=4)
val_dataloader = val_dataset.to_dataloader(train=False, batch_size=batch_size, num_workers=4)

# Sanity-check the frames and the first sample of each dataset.
# (Y_train_df / Y_val_df are not defined in this excerpt.)
print(Y_train_df.head())
print(Y_val_df.head())
print(train_dataset[0])
print(val_dataset[0])

# NOTE(review): the linked tutorial builds the model with
# TemporalFusionTransformer.from_dataset(train_dataset, ...), which passes the
# dataset's variable definitions to the network. Constructing the class
# directly, as here, may leave it without any input variables -- confirm
# whether that explains the StopIteration raised in the variable selection
# network during predict().
# NOTE(review): the tutorial pairs QuantileLoss() with output_size=7 (one
# output per default quantile); confirm that output_size=1 is intended.
model = TemporalFusionTransformer(
    output_size=1,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,  # the trailing comma was missing here, which made the call a SyntaxError
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    reduce_on_plateau_patience=4,
)

trainer = pl.Trainer(
    max_epochs=10,
    progress_bar_refresh_rate=1,
    log_every_n_steps=5,
    check_val_every_n_epoch=5,
)
trainer.fit(model, train_dataloader, val_dataloader)

# Persist the whole fitted model object.
torch.save(model, "model.pth")
which outputs:
Len of training data is : 5169
Len of val data is : 3446
ds y unique_id time_idx
0 2020-05-13 08:45:57.228000000 1575.6520 dummy 0
1 2020-05-13 08:46:58.343000064 1575.6520 dummy 1
2 2020-05-13 08:47:59.299000064 1527.7666 dummy 2
3 2020-05-13 08:49:00.236000000 1527.7666 dummy 3
4 2020-05-13 08:50:01.188999936 1477.7880 dummy 4
({'x_cat': tensor([], size=(6, 0), dtype=torch.int64), 'x_cont': tensor([[-0.0258],
[-0.0258],
[-0.1493],
[-0.1493],
[-0.2782],
[-0.2786]]), 'encoder_length': 5, 'decoder_length': 1, 'encoder_target': tensor([1575.6520, 1575.6520, 1527.7666, 1527.7666, 1477.7880]), 'encoder_time_idx_start': tensor(0), 'groups': tensor([0]), 'target_scale': array([1585.67038468, 387.81209598])}, (tensor([1477.6305]), None))
({'x_cat': tensor([], size=(6, 0), dtype=torch.int64), 'x_cont': tensor([[-1.5680],
[-1.4909],
[-1.4564],
[-1.4564],
[-1.3573],
[-1.4440]]), 'encoder_length': 5, 'decoder_length': 1, 'encoder_target': tensor([ 977.5710, 1007.4689, 1020.8758, 1020.8758, 1059.2924]), 'encoder_time_idx_start': tensor(8609), 'groups': tensor([0]), 'target_scale': array([1585.67038468, 387.81209598])}, (tensor([1025.6858]), None))
along with some deprecation warnings, and the training runs to completion.
However, when I try to do:
outputs = trainer.predict(model, val_dataloader)
it throws an error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_17/1941000738.py in <module>
----> 1 outputs = trainer.predict(model, val_dataloader)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in predict(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
1024 self.strategy.model = model or self.lightning_module
1025 return self._call_and_handle_interrupt(
-> 1026 self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
1027 )
1028
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
721 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
722 else:
--> 723 return trainer_fn(*args, **kwargs)
724 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
725 except KeyboardInterrupt as exception:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _predict_impl(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
1070 self._predicted_ckpt_path = self.ckpt_path # TODO: remove in v1.8
1071
-> 1072 results = self._run(model, ckpt_path=self.ckpt_path)
1073
1074 assert self.state.stopped
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1234 self._checkpoint_connector.resume_end()
1235
-> 1236 results = self._run_stage()
1237
1238 log.detail(f"{self.__class__.__name__}: trainer tearing down")
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
1320 return self._run_evaluate()
1321 if self.predicting:
-> 1322 return self._run_predict()
1323 return self._run_train()
1324
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_predict(self)
1379 self.predict_loop.trainer = self
1380 with torch.no_grad():
-> 1381 return self.predict_loop.run()
1382
1383 def _run_sanity_check(self) -> None:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
202 try:
203 self.on_advance_start(*args, **kwargs)
--> 204 self.advance(*args, **kwargs)
205 self.on_advance_end()
206 self._restarting = False
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/prediction_loop.py in advance(self, *args, **kwargs)
96
97 dl_predictions, dl_batch_indices = self.epoch_loop.run(
---> 98 dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions
99 )
100 self.predictions.append(dl_predictions)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
202 try:
203 self.on_advance_start(*args, **kwargs)
--> 204 self.advance(*args, **kwargs)
205 self.on_advance_end()
206 self._restarting = False
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py in advance(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions)
102 self.batch_progress.increment_ready()
103
--> 104 self._predict_step(batch, batch_idx, dataloader_idx)
105
106 def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py in _predict_step(self, batch, batch_idx, dataloader_idx)
130 self.batch_progress.increment_started()
131
--> 132 predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())
133
134 self.batch_progress.increment_processed()
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
1763
1764 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1765 output = fn(*args, **kwargs)
1766
1767 # restore current_fx when nested context
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py in predict_step(self, *args, **kwargs)
358 """
359 with self.precision_plugin.predict_step_context():
--> 360 return self.model.predict_step(*args, **kwargs)
361
362 def training_step_end(self, output):
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py in predict_step(self, batch, batch_idx, dataloader_idx)
1151 Predicted output
1152 """
-> 1153 return self(batch)
1154
1155 def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py in forward(self, x)
396 input dimensions: n_samples x time x variables
397 """
--> 398 encoder_lengths = x["encoder_lengths"]
399 decoder_lengths = x["decoder_lengths"]
400 x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1) # concatenate in time dimension
TypeError: tuple indices must be integers or slices, not str
If I instead try the following (regardless of the number of batches I've tried):
raw_predictions, x = model.predict(val_dataloader, mode="raw", return_x=True)
I then get a different error:
StopIteration Traceback (most recent call last)
/tmp/ipykernel_17/1197041394.py in <module>
1 # raw predictions are a dictionary from which all kind of information including quantiles can be extracted
----> 2 raw_predictions, x = model.predict(val_dataloader, mode="raw", return_x=True)
/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py in predict(self, data, mode, return_index, return_decoder_lengths, batch_size, num_workers, fast_dev_run, show_progress_bar, return_x, mode_kwargs, **kwargs)
1157
1158 # make prediction
-> 1159 out = self(x, **kwargs) # raw output is dictionary
1160
1161 lengths = x["decoder_lengths"]
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py in forward(self, x)
431 embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection(
432 embeddings_varying_encoder,
--> 433 static_context_variable_selection[:, :max_encoder_length],
434 )
435
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py in forward(self, x, context)
339 outputs = outputs.sum(dim=-1)
340 else: # for one input, do not perform variable selection but just encoding
--> 341 name = next(iter(self.single_variable_grns.keys()))
342 variable_embedding = x[name]
343 if name in self.prescalers:
StopIteration:
Any idea what I am missing please?
I'm getting the same thing. I was able to work around it by modifying the "predict_step" method to use just the first input of the tuple before calling the forward pass:
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Run the forward pass on the network inputs of a prediction batch.

    The dataloader yields ``(x, y)`` tuples, whereas ``forward`` indexes its
    argument with string keys, so only the first element of the tuple is
    passed on. A batch that is not a tuple/list is forwarded unchanged.

    Args:
        batch: Either an ``(x, y)`` tuple/list or the input ``x`` itself.
        batch_idx: Index of the batch (unused).
        dataloader_idx: Index of the dataloader (unused). Defaults to 0 so
            the hook also works when it is called with only
            ``batch, batch_idx``.

    Returns:
        Whatever ``self(x)`` returns.
    """
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    return self(x)
But I have no idea why training works just fine while predicting does not. I also hope the extra data (the rest of the tuple) isn't important, but I have no clue.