pytorch-forecasting

TypeError when predicting with TFT

Open · gsamaras opened this issue 3 years ago · 1 comment

First-time user here. I am following this example: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/stallion.html, and my code is:

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger
    import torch
    
    from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
    from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
    ...

    # Hold out the last 40% of the distinct timestamps for validation.
    # NOTE(review): Y_df (columns ds, y, unique_id, time_idx) is built in the
    # elided section above. The positional slices below assume one row per
    # timestamp, which holds for the single "dummy" series shown; confirm
    # before using this with several series.
    n_time_steps = len(Y_df.ds.unique())
    validation_len = int(n_time_steps * 0.4)
    training_cutoff = n_time_steps - validation_len
    train_df = Y_df[:training_cutoff]

    print('Len of training data is : ', len(train_df))
    print('Len of val data is : ', len(Y_df[training_cutoff:]))
    train_dataset = TimeSeriesDataSet(
        data=train_df,
        group_ids=["unique_id"],
        target="y",
        time_idx="time_idx",
        min_encoder_length=5,
        max_encoder_length=5,
        min_prediction_length=1,
        max_prediction_length=1,
        time_varying_unknown_reals=["y"],
    )

    # create validation set (predict=True) which means to predict the last max_prediction_length points in time
    val_dataset = TimeSeriesDataSet.from_dataset(train_dataset, Y_df, predict=True, stop_randomization=True)

    # create dataloaders for model
    batch_size = 4
    train_dataloader = train_dataset.to_dataloader(train=True, batch_size=batch_size, num_workers=4)
    val_dataloader = val_dataset.to_dataloader(train=False, batch_size=batch_size, num_workers=4)

    print(Y_train_df.head())
    print(Y_val_df.head())
    print(train_dataset[0])
    print(val_dataset[0])

    # Build the model from the training dataset rather than through the bare
    # constructor. from_dataset() copies the dataset's variable configuration
    # (here the time-varying real "y") into the model. Without it, the
    # variable-selection network has no variables, which is the empty
    # single_variable_grns behind the StopIteration raised by predict().
    loss = QuantileLoss()
    model = TemporalFusionTransformer.from_dataset(
        train_dataset,
        # one output per predicted quantile, so the output layer matches the loss
        output_size=len(loss.quantiles),
        learning_rate=0.03,
        hidden_size=16,
        attention_head_size=1,
        dropout=0.1,
        hidden_continuous_size=8,
        loss=loss,
        reduce_on_plateau_patience=4,
    )

    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=1,
        log_every_n_steps=5,
        check_val_every_n_epoch=5,
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    torch.save(model, "model.pth")
   

which outputs:

Len of training data is :  5169
Len of val data is :  3446
                             ds          y unique_id  time_idx
0 2020-05-13 08:45:57.228000000  1575.6520     dummy         0
1 2020-05-13 08:46:58.343000064  1575.6520     dummy         1
2 2020-05-13 08:47:59.299000064  1527.7666     dummy         2
3 2020-05-13 08:49:00.236000000  1527.7666     dummy         3
4 2020-05-13 08:50:01.188999936  1477.7880     dummy         4
    
({'x_cat': tensor([], size=(6, 0), dtype=torch.int64), 'x_cont': tensor([[-0.0258],
        [-0.0258],
        [-0.1493],
        [-0.1493],
        [-0.2782],
        [-0.2786]]), 'encoder_length': 5, 'decoder_length': 1, 'encoder_target': tensor([1575.6520, 1575.6520, 1527.7666, 1527.7666, 1477.7880]), 'encoder_time_idx_start': tensor(0), 'groups': tensor([0]), 'target_scale': array([1585.67038468,  387.81209598])}, (tensor([1477.6305]), None))
({'x_cat': tensor([], size=(6, 0), dtype=torch.int64), 'x_cont': tensor([[-1.5680],
        [-1.4909],
        [-1.4564],
        [-1.4564],
        [-1.3573],
        [-1.4440]]), 'encoder_length': 5, 'decoder_length': 1, 'encoder_target': tensor([ 977.5710, 1007.4689, 1020.8758, 1020.8758, 1059.2924]), 'encoder_time_idx_start': tensor(8609), 'groups': tensor([0]), 'target_scale': array([1585.67038468,  387.81209598])}, (tensor([1025.6858]), None))

along with some deprecation warnings, and training runs to completion.

However, when I try to run:

# Run Lightning's prediction loop over the validation dataloader.
outputs = trainer.predict(model=model, dataloaders=val_dataloader)

it throws an error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_17/1941000738.py in <module>
----> 1 outputs = trainer.predict(model, val_dataloader)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in predict(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
   1024         self.strategy.model = model or self.lightning_module
   1025         return self._call_and_handle_interrupt(
-> 1026             self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
   1027         )
   1028 

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    721                 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    722             else:
--> 723                 return trainer_fn(*args, **kwargs)
    724         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    725         except KeyboardInterrupt as exception:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _predict_impl(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
   1070         self._predicted_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
   1071 
-> 1072         results = self._run(model, ckpt_path=self.ckpt_path)
   1073 
   1074         assert self.state.stopped

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1234         self._checkpoint_connector.resume_end()
   1235 
-> 1236         results = self._run_stage()
   1237 
   1238         log.detail(f"{self.__class__.__name__}: trainer tearing down")

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
   1320             return self._run_evaluate()
   1321         if self.predicting:
-> 1322             return self._run_predict()
   1323         return self._run_train()
   1324 

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_predict(self)
   1379         self.predict_loop.trainer = self
   1380         with torch.no_grad():
-> 1381             return self.predict_loop.run()
   1382 
   1383     def _run_sanity_check(self) -> None:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    202             try:
    203                 self.on_advance_start(*args, **kwargs)
--> 204                 self.advance(*args, **kwargs)
    205                 self.on_advance_end()
    206                 self._restarting = False

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/prediction_loop.py in advance(self, *args, **kwargs)
     96 
     97         dl_predictions, dl_batch_indices = self.epoch_loop.run(
---> 98             dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions
     99         )
    100         self.predictions.append(dl_predictions)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    202             try:
    203                 self.on_advance_start(*args, **kwargs)
--> 204                 self.advance(*args, **kwargs)
    205                 self.on_advance_end()
    206                 self._restarting = False

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py in advance(self, dataloader_iter, dataloader_idx, dl_max_batches, num_dataloaders, return_predictions)
    102         self.batch_progress.increment_ready()
    103 
--> 104         self._predict_step(batch, batch_idx, dataloader_idx)
    105 
    106     def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py in _predict_step(self, batch, batch_idx, dataloader_idx)
    130         self.batch_progress.increment_started()
    131 
--> 132         predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())
    133 
    134         self.batch_progress.increment_processed()

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
   1763 
   1764         with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1765             output = fn(*args, **kwargs)
   1766 
   1767         # restore current_fx when nested context

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py in predict_step(self, *args, **kwargs)
    358         """
    359         with self.precision_plugin.predict_step_context():
--> 360             return self.model.predict_step(*args, **kwargs)
    361 
    362     def training_step_end(self, output):

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py in predict_step(self, batch, batch_idx, dataloader_idx)
   1151             Predicted output
   1152         """
-> 1153         return self(batch)
   1154 
   1155     def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py in forward(self, x)
    396         input dimensions: n_samples x time x variables
    397         """
--> 398         encoder_lengths = x["encoder_lengths"]
    399         decoder_lengths = x["decoder_lengths"]
    400         x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1)  # concatenate in time dimension

TypeError: tuple indices must be integers or slices, not str

If I instead try the following (regardless of the number of batches I've tried):

# raw predictions are a dictionary from which all kind of information including quantiles can be extracted
raw_predictions, x = model.predict(data=val_dataloader, mode="raw", return_x=True)

I then get a different error:

StopIteration                             Traceback (most recent call last)
/tmp/ipykernel_17/1197041394.py in <module>
      1 # raw predictions are a dictionary from which all kind of information including quantiles can be extracted
----> 2 raw_predictions, x = model.predict(val_dataloader, mode="raw", return_x=True)

/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py in predict(self, data, mode, return_index, return_decoder_lengths, batch_size, num_workers, fast_dev_run, show_progress_bar, return_x, mode_kwargs, **kwargs)
   1157 
   1158                 # make prediction
-> 1159                 out = self(x, **kwargs)  # raw output is dictionary
   1160 
   1161                 lengths = x["decoder_lengths"]

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py in forward(self, x)
    431         embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection(
    432             embeddings_varying_encoder,
--> 433             static_context_variable_selection[:, :max_encoder_length],
    434         )
    435 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py in forward(self, x, context)
    339             outputs = outputs.sum(dim=-1)
    340         else:  # for one input, do not perform variable selection but just encoding
--> 341             name = next(iter(self.single_variable_grns.keys()))
    342             variable_embedding = x[name]
    343             if name in self.prescalers:

StopIteration: 

Does anyone have an idea what I am missing?

— gsamaras, Aug 04 '22 13:08

I'm getting the same error. I was able to work around it by overriding the `predict_step` method so that it passes only the first element of the batch tuple to the forward pass:

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Run the forward pass on only the input part of a batch.

    Lightning's default ``predict_step`` calls ``self(batch)`` with the whole
    batch. A TimeSeriesDataSet dataloader, however, yields ``(x, y)`` tuples,
    and ``forward()`` indexes into ``x`` as a dict (``x["encoder_lengths"]``,
    ...). That mismatch is the source of "tuple indices must be integers or
    slices, not str". ``y`` is presumably ``(target, weight)``, as in the
    printed dataset samples, and is not needed to make a prediction.

    Args:
        batch: an ``(x, y)`` tuple/list from the dataloader, or a bare ``x``.
        batch_idx: index of the batch within the epoch (unused).
        dataloader_idx: index of the dataloader (unused). It defaults to 0 so
            the hook still works if Lightning does not pass it (verify for
            your Lightning version).

    Returns:
        Whatever ``self(x)`` returns, i.e. the model's raw forward output.
    """
    # Keep only the model inputs; pass through batches that are already bare.
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    return self(x)

But I have no idea why training works fine while prediction does not. I also hope the discarded extra data isn't important, but I have no clue.

— ChosunOne, Mar 09 '23 22:03