TemporalFusionTransformerEstimator raises KeyError when using static categorical features with PandasDataset

Open jmberutich opened this issue 2 years ago • 8 comments

Description

TFT seems to have a bug in version 0.11 and no longer works with features when using a PandasDataset. I have only tried using static categorical features.

I have a long-format pandas DataFrame with the following columns:

  • timestamp
  • item_id
  • target
  • static_cat_0
  • static_cat_1
  • static_cat_2
  • static_cat_3
  • static_cat_4
  • static_cat_5
  • static_cat_6
  • static_cat_7
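
Since the parquet file used in the reproduction is not available, a stand-in with the same columns can be generated. The sketch below is only an illustration: the helper name `make_long_rows`, the number of items, the date range, and the category values are all made up, and the real data may look quite different.

```python
import random
from datetime import datetime

# Column names taken from the list above.
STATIC_COLUMNS = [f"static_cat_{i}" for i in range(8)]


def make_long_rows(num_items=3, num_months=36, num_categories=4, seed=0):
    """Build long-format rows: one row per (item_id, timestamp) pair.

    Hypothetical helper for illustration only; sizes and values are arbitrary.
    """
    rng = random.Random(seed)
    rows = []
    for item in range(num_items):
        # Static features: one value per item, repeated on every row of that item.
        static_values = {
            name: rng.randrange(num_categories) for name in STATIC_COLUMNS
        }
        for month in range(num_months):
            rows.append(
                {
                    "timestamp": datetime(2018 + month // 12, month % 12 + 1, 1),
                    "item_id": f"item_{item}",
                    "target": rng.random(),
                    **static_values,
                }
            )
    return rows


rows = make_long_rows()
```

Passing `rows` to `pandas.DataFrame(rows)` would give a dataframe with the columns listed above.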

To Reproduce

import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.model.tft import TemporalFusionTransformerEstimator
from gluonts.mx.trainer import Trainer


df = pd.read_parquet("../data/long_df_sample.parquet")

feat_static_cat = [
    "static_cat_0",
    "static_cat_1",
    "static_cat_2",
    "static_cat_3",
    "static_cat_4",
    "static_cat_5",
    "static_cat_6",
    "static_cat_7",
]

cardinalities = df[feat_static_cat].nunique().tolist()

train = PandasDataset.from_long_dataframe(
    df,
    item_id="item_id",
    timestamp="timestamp",
    freq="M",
    feat_static_cat=feat_static_cat
)

estimator = TemporalFusionTransformerEstimator(
    freq="M",
    prediction_length=12,
    context_length=24,
    static_cardinalities={
        name: int(cardinality)
        for name, cardinality in zip(feat_static_cat, cardinalities)
    },
    trainer=Trainer(
        epochs=10,
    ),
)

predictor = estimator.train(train)

Error message or code output

KeyError                                  Traceback (most recent call last)
Cell In [24], line 1
----> 1 predictor = estimator.train(train)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/model/estimator.py:238, in GluonEstimator.train(self, training_data, validation_data, shuffle_buffer_length, cache_data, **kwargs)
    230 def train(
    231     self,
    232     training_data: Dataset,
   (...)
    236     **kwargs,
    237 ) -> Predictor:
--> 238     return self.train_model(
    239         training_data=training_data,
    240         validation_data=validation_data,
    241         shuffle_buffer_length=shuffle_buffer_length,
    242         cache_data=cache_data,
    243     ).predictor

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/model/estimator.py:215, in GluonEstimator.train_model(self, training_data, validation_data, from_predictor, shuffle_buffer_length, cache_data)
    212 else:
    213     copy_parameters(from_predictor.network, training_network)
--> 215 self.trainer(
    216     net=training_network,
    217     train_iter=training_data_loader,
    218     validation_iter=validation_data_loader,
    219 )
    221 with self.trainer.ctx:
    222     predictor = self.create_predictor(transformation, training_network)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/trainer/_base.py:410, in Trainer.__call__(self, net, train_iter, validation_iter)
    405 curr_lr = trainer.learning_rate
    406 logger.info(
    407     f"Epoch[{epoch_no}] Learning rate is {curr_lr}"
    408 )
--> 410 epoch_loss = loop(
    411     epoch_no,
    412     train_iter,
    413     num_batches_to_use=self.num_batches_per_epoch,
    414 )
    416 should_continue = self.callbacks.on_train_epoch_end(
    417     epoch_no=epoch_no,
    418     epoch_loss=loss_value(epoch_loss),
    419     training_network=net,
    420     trainer=trainer,
    421 )
    423 if is_validation_available:

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/trainer/_base.py:275, in Trainer.__call__.<locals>.loop(epoch_no, batch_iter, num_batches_to_use, is_training)
    272 it = tqdm(batch_iter, total=num_batches_to_use)
    273 any_batches = False
--> 275 for batch_no, batch in enumerate(it, start=1):
    276     any_batches = True
    278     # `batch` here is expected to be a dictionary whose fields
    279     # should correspond 1-to-1 with the network inputs
    280     # see below how `batch.values()` is fed into the network

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/tqdm/std.py:1195, in tqdm.__iter__(self)
   1192 time = self._time
   1194 try:
-> 1195     for obj in iterable:
   1196         yield obj
   1197         # Update and possibly print the progressbar.
   1198         # Note: does not call self.update(1) for speed optimisation.

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/itertools.py:175, in IterableSlice.__iter__(self)
    174 def __iter__(self):
--> 175     yield from itertools.islice(self.iterable, self.length)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:103, in TransformedDataset.__iter__(self)
    102 def __iter__(self) -> Iterator[DataEntry]:
--> 103     yield from self.transformation(
    104         self.base_dataset, is_train=self.is_train
    105     )

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
    121 def __call__(
    122     self, data_it: Iterable[DataEntry], is_train: bool
    123 ) -> Iterator:
--> 124     for data_entry in data_it:
    125         try:
    126             yield self.map_transform(data_entry.copy(), is_train)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/dataset/loader.py:37, in Batch.__call__(self, data, is_train)
     36 def __call__(self, data, is_train):
---> 37     yield from batcher(data, self.batch_size)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/itertools.py:94, in batcher.<locals>.get_batch()
     93 def get_batch():
---> 94     return list(itertools.islice(it, batch_size))

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
    121 def __call__(
    122     self, data_it: Iterable[DataEntry], is_train: bool
    123 ) -> Iterator:
--> 124     for data_entry in data_it:
    125         try:
    126             yield self.map_transform(data_entry.copy(), is_train)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:178, in FlatMapTransformation.__call__(self, data_it, is_train)
    174 def __call__(
    175     self, data_it: Iterable[DataEntry], is_train: bool
    176 ) -> Iterator:
    177     num_idle_transforms = 0
--> 178     for data_entry in data_it:
    179         num_idle_transforms += 1
    180         for result in self.flatmap_transform(data_entry.copy(), is_train):

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/itertools.py:71, in Cyclic.__iter__(self)
     69 at_least_one = False
     70 while True:
---> 71     for el in self.iterable:
     72         at_least_one = True
     73         yield el

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:103, in TransformedDataset.__iter__(self)
    102 def __iter__(self) -> Iterator[DataEntry]:
--> 103     yield from self.transformation(
    104         self.base_dataset, is_train=self.is_train
    105     )

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
    121 def __call__(
    122     self, data_it: Iterable[DataEntry], is_train: bool
    123 ) -> Iterator:
--> 124     for data_entry in data_it:
    125         try:
    126             yield self.map_transform(data_entry.copy(), is_train)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
    121 def __call__(
    122     self, data_it: Iterable[DataEntry], is_train: bool
    123 ) -> Iterator:
--> 124     for data_entry in data_it:
    125         try:
    126             yield self.map_transform(data_entry.copy(), is_train)

    [... skipping similar frames: MapTransformation.__call__ at line 124 (19 times)]

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
    121 def __call__(
    122     self, data_it: Iterable[DataEntry], is_train: bool
    123 ) -> Iterator:
--> 124     for data_entry in data_it:
    125         try:
    126             yield self.map_transform(data_entry.copy(), is_train)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:128, in MapTransformation.__call__(self, data_it, is_train)
    126     yield self.map_transform(data_entry.copy(), is_train)
    127 except Exception as e:
--> 128     raise e

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:126, in MapTransformation.__call__(self, data_it, is_train)
    124 for data_entry in data_it:
    125     try:
--> 126         yield self.map_transform(data_entry.copy(), is_train)
    127     except Exception as e:
    128         raise e

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:141, in SimpleTransformation.map_transform(self, data, is_train)
    140 def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
--> 141     return self.transform(data)

File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/convert.py:127, in AsNumpyArray.transform(self, data)
    126 def transform(self, data: DataEntry) -> DataEntry:
--> 127     value = np.asarray(data[self.field], dtype=self.dtype)
    129     assert_data_error(
    130         value.ndim == self.expected_ndim,
    131         'Input for field "{self.field}" does not have the required'
   (...)
    135         self=self,
    136     )
    137     data[self.field] = value

KeyError: 'static_cat_0'
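
The last frame of the traceback fails on `data[self.field]`, so the data entry reaching `AsNumpyArray` has no key called `static_cat_0`. One way to confirm this before training is to compare the keys of a data entry against the names passed in `static_cardinalities`. The sketch below is a rough diagnostic; `missing_fields` is a hypothetical helper, and the example entry is invented to show static categories stored under a single grouped key.

```python
def missing_fields(entry, required):
    """Return the names in `required` that are not keys of `entry`."""
    return [name for name in required if name not in entry]


# Hypothetical data entry: static categories live under one grouped key.
entry = {
    "start": "2020-01",
    "target": [1.0, 2.0, 3.0],
    "feat_static_cat": [3, 1],
}

# Names the estimator was configured with (the keys of static_cardinalities).
required = ["static_cat_0", "static_cat_1"]

absent = missing_fields(entry, required)
print(absent)
```

With a real dataset, `next(iter(train))` could supply `entry` instead of the hand-written dictionary.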

Environment

  • Operating system: Ubuntu
  • Python version: 3.8.10
  • GluonTS version: 0.11
  • MXNet version: 1.9.1

jmberutich avatar Nov 04 '22 12:11 jmberutich

TFT seems to have a bug in version 0.11

Does the example work in 0.10?

lostella avatar Nov 04 '22 22:11 lostella

I've also noticed TFT doesn't play well with PandasDataset dynamic real features. I think it's because TFT expects the resulting ListDataset to include the features by name, not grouped into feat_dynamic_real, etc.

See the working example in https://github.com/awslabs/gluonts/issues/1075: there, all features are specified by name in the ListDataset, not by feature type.
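
The difference between the two layouts can be sketched with plain dictionaries. The helper below, `expand_grouped_field`, is hypothetical and not part of GluonTS. It takes an entry whose static categories are grouped under `feat_static_cat` and produces one key per feature name. Storing each value as a length-1 list is an assumption; the thread does not confirm the exact shape TFT expects.

```python
def expand_grouped_field(entry, names, grouped_field="feat_static_cat"):
    """Return a copy of `entry` with the grouped field split into named fields.

    Hypothetical helper: `names[i]` becomes the key for the i-th grouped value.
    """
    values = list(entry[grouped_field])
    if len(values) != len(names):
        raise ValueError(
            f"expected {len(names)} values in {grouped_field!r}, got {len(values)}"
        )
    # Copy everything except the grouped field, leaving the input untouched.
    expanded = {key: value for key, value in entry.items() if key != grouped_field}
    for name, value in zip(names, values):
        # Assumption: each named feature is kept as a one-element sequence.
        expanded[name] = [value]
    return expanded


grouped_entry = {
    "start": "2020-01",
    "target": [1.0, 2.0],
    "feat_static_cat": [3, 1],
}
named_entry = expand_grouped_field(grouped_entry, ["static_cat_0", "static_cat_1"])
```

Entries converted this way could then be collected into a ListDataset, as in the linked example, though whether that is enough for training to succeed is not verified here.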

esbraun avatar Nov 06 '22 21:11 esbraun

TFT seems to have a bug in version 0.11

Does the example work in 0.10?

No, it doesn't work in 0.10 either.

jmberutich avatar Nov 07 '22 08:11 jmberutich

As @esbraun pointed out, this is because TFT expects a different schema than what PandasDataset provides.

Cc @jaheba this is something that fits in the whole schema story. Should we have a milestone or project to collect all issues related to it?

lostella avatar Nov 09 '22 10:11 lostella

SelfAttentionEstimator gives me a similar error even though I set the parameter use_feat_static_cat=False: KeyError: 'feat_static_cat'

qiniuweihe avatar Dec 15 '22 08:12 qiniuweihe

There are several issues related to this model: https://github.com/awslabs/gluonts/issues/1976, https://github.com/awslabs/gluonts/issues/2160, https://github.com/awslabs/gluonts/issues/2416, https://github.com/awslabs/gluonts/issues/2466. All of them report something similar to this ticket. We have already started revisiting this model to fix all of these issues. We will keep you updated.

melopeo avatar Dec 15 '22 10:12 melopeo

@jmberutich we now have a PyTorch implementation of the Temporal Fusion Transformer, see #2536, which has the same data interface as the other models, so it should not have this issue. It's in the dev branch until the next release; let us know if you get the chance to try it out!
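
If the PyTorch estimator follows the grouped `feat_static_cat` layout used by other models, the name-keyed cardinalities from the original snippet would probably need to become a plain list in column order. The sketch below shows that conversion only. `cardinalities_in_order` is a hypothetical helper, the cardinality values are invented, and the commented-out estimator call is an assumption about the new interface that this thread does not confirm.

```python
def cardinalities_in_order(cardinality_by_name, column_order):
    """Turn a {name: cardinality} mapping into a list ordered like `column_order`."""
    missing = [name for name in column_order if name not in cardinality_by_name]
    if missing:
        raise KeyError(f"no cardinality given for: {missing}")
    return [int(cardinality_by_name[name]) for name in column_order]


feat_static_cat = ["static_cat_0", "static_cat_1", "static_cat_2"]
# Invented values; in the original snippet they come from df[...].nunique().
cardinality_by_name = {"static_cat_1": 7, "static_cat_0": 4, "static_cat_2": 12}

static_cardinalities = cardinalities_in_order(cardinality_by_name, feat_static_cat)

# Assumed usage (import path and list-valued parameter not confirmed in this thread):
# from gluonts.torch.model.tft import TemporalFusionTransformerEstimator
# estimator = TemporalFusionTransformerEstimator(
#     freq="M",
#     prediction_length=12,
#     context_length=24,
#     static_cardinalities=static_cardinalities,
# )
```

Ordering by the same list that is passed as `feat_static_cat` to the dataset keeps each cardinality aligned with its column.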

lostella avatar Jan 17 '23 13:01 lostella

Are there any fixes for it?

baniasbaabe avatar Apr 03 '23 14:04 baniasbaabe