gluonts
TemporalFusionTransformerEstimator raises an error when using static categorical features from a PandasDataset
Description
TFT seems to have a bug in version 0.11 and no longer works with features supplied through a PandasDataset. I have only tried static categorical features.
I have a long pandas dataframe with the following columns:
- timestamp
- item_id
- target
- static_cat_0
- static_cat_1
- static_cat_2
- static_cat_3
- static_cat_4
- static_cat_5
- static_cat_6
- static_cat_7
To Reproduce
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.model.tft import TemporalFusionTransformerEstimator
from gluonts.mx.trainer import Trainer

df = pd.read_parquet("../data/long_df_sample.parquet")

feat_static_cat = [
    "static_cat_0",
    "static_cat_1",
    "static_cat_2",
    "static_cat_3",
    "static_cat_4",
    "static_cat_5",
    "static_cat_6",
    "static_cat_7",
]
cardinalities = df[feat_static_cat].nunique().tolist()

train = PandasDataset.from_long_dataframe(
    df,
    item_id="item_id",
    timestamp="timestamp",
    freq="M",
    feat_static_cat=feat_static_cat,
)

estimator = TemporalFusionTransformerEstimator(
    freq="M",
    prediction_length=12,
    context_length=24,
    static_cardinalities={
        name: int(cardinality)
        for name, cardinality in zip(feat_static_cat, cardinalities)
    },
    trainer=Trainer(
        epochs=10,
    ),
)

predictor = estimator.train(train)
Error message or code output
KeyError Traceback (most recent call last)
Cell In [24], line 1
----> 1 predictor = estimator.train(train)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/model/estimator.py:238, in GluonEstimator.train(self, training_data, validation_data, shuffle_buffer_length, cache_data, **kwargs)
230 def train(
231 self,
232 training_data: Dataset,
(...)
236 **kwargs,
237 ) -> Predictor:
--> 238 return self.train_model(
239 training_data=training_data,
240 validation_data=validation_data,
241 shuffle_buffer_length=shuffle_buffer_length,
242 cache_data=cache_data,
243 ).predictor
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/model/estimator.py:215, in GluonEstimator.train_model(self, training_data, validation_data, from_predictor, shuffle_buffer_length, cache_data)
212 else:
213 copy_parameters(from_predictor.network, training_network)
--> 215 self.trainer(
216 net=training_network,
217 train_iter=training_data_loader,
218 validation_iter=validation_data_loader,
219 )
221 with self.trainer.ctx:
222 predictor = self.create_predictor(transformation, training_network)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/trainer/_base.py:410, in Trainer.__call__(self, net, train_iter, validation_iter)
405 curr_lr = trainer.learning_rate
406 logger.info(
407 f"Epoch[{epoch_no}] Learning rate is {curr_lr}"
408 )
--> 410 epoch_loss = loop(
411 epoch_no,
412 train_iter,
413 num_batches_to_use=self.num_batches_per_epoch,
414 )
416 should_continue = self.callbacks.on_train_epoch_end(
417 epoch_no=epoch_no,
418 epoch_loss=loss_value(epoch_loss),
419 training_network=net,
420 trainer=trainer,
421 )
423 if is_validation_available:
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/mx/trainer/_base.py:275, in Trainer.__call__.<locals>.loop(epoch_no, batch_iter, num_batches_to_use, is_training)
272 it = tqdm(batch_iter, total=num_batches_to_use)
273 any_batches = False
--> 275 for batch_no, batch in enumerate(it, start=1):
276 any_batches = True
278 # `batch` here is expected to be a dictionary whose fields
279 # should correspond 1-to-1 with the network inputs
280 # see below how `batch.values()` is fed into the network
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/tqdm/std.py:1195, in tqdm.__iter__(self)
1192 time = self._time
1194 try:
-> 1195 for obj in iterable:
1196 yield obj
1197 # Update and possibly print the progressbar.
1198 # Note: does not call self.update(1) for speed optimisation.
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/itertools.py:175, in IterableSlice.__iter__(self)
174 def __iter__(self):
--> 175 yield from itertools.islice(self.iterable, self.length)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:103, in TransformedDataset.__iter__(self)
102 def __iter__(self) -> Iterator[DataEntry]:
--> 103 yield from self.transformation(
104 self.base_dataset, is_train=self.is_train
105 )
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
121 def __call__(
122 self, data_it: Iterable[DataEntry], is_train: bool
123 ) -> Iterator:
--> 124 for data_entry in data_it:
125 try:
126 yield self.map_transform(data_entry.copy(), is_train)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/dataset/loader.py:37, in Batch.__call__(self, data, is_train)
36 def __call__(self, data, is_train):
---> 37 yield from batcher(data, self.batch_size)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/itertools.py:94, in batcher.<locals>.get_batch()
93 def get_batch():
---> 94 return list(itertools.islice(it, batch_size))
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
121 def __call__(
122 self, data_it: Iterable[DataEntry], is_train: bool
123 ) -> Iterator:
--> 124 for data_entry in data_it:
125 try:
126 yield self.map_transform(data_entry.copy(), is_train)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:178, in FlatMapTransformation.__call__(self, data_it, is_train)
174 def __call__(
175 self, data_it: Iterable[DataEntry], is_train: bool
176 ) -> Iterator:
177 num_idle_transforms = 0
--> 178 for data_entry in data_it:
179 num_idle_transforms += 1
180 for result in self.flatmap_transform(data_entry.copy(), is_train):
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/itertools.py:71, in Cyclic.__iter__(self)
69 at_least_one = False
70 while True:
---> 71 for el in self.iterable:
72 at_least_one = True
73 yield el
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:103, in TransformedDataset.__iter__(self)
102 def __iter__(self) -> Iterator[DataEntry]:
--> 103 yield from self.transformation(
104 self.base_dataset, is_train=self.is_train
105 )
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
121 def __call__(
122 self, data_it: Iterable[DataEntry], is_train: bool
123 ) -> Iterator:
--> 124 for data_entry in data_it:
125 try:
126 yield self.map_transform(data_entry.copy(), is_train)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
121 def __call__(
122 self, data_it: Iterable[DataEntry], is_train: bool
123 ) -> Iterator:
--> 124 for data_entry in data_it:
125 try:
126 yield self.map_transform(data_entry.copy(), is_train)
[... skipping similar frames: MapTransformation.__call__ at line 124 (19 times)]
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:124, in MapTransformation.__call__(self, data_it, is_train)
121 def __call__(
122 self, data_it: Iterable[DataEntry], is_train: bool
123 ) -> Iterator:
--> 124 for data_entry in data_it:
125 try:
126 yield self.map_transform(data_entry.copy(), is_train)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:128, in MapTransformation.__call__(self, data_it, is_train)
126 yield self.map_transform(data_entry.copy(), is_train)
127 except Exception as e:
--> 128 raise e
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:126, in MapTransformation.__call__(self, data_it, is_train)
124 for data_entry in data_it:
125 try:
--> 126 yield self.map_transform(data_entry.copy(), is_train)
127 except Exception as e:
128 raise e
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/_base.py:141, in SimpleTransformation.map_transform(self, data, is_train)
140 def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
--> 141 return self.transform(data)
File ~/kost-2.0/models/gluonts/.venv/lib/python3.8/site-packages/gluonts/transform/convert.py:127, in AsNumpyArray.transform(self, data)
126 def transform(self, data: DataEntry) -> DataEntry:
--> 127 value = np.asarray(data[self.field], dtype=self.dtype)
129 assert_data_error(
130 value.ndim == self.expected_ndim,
131 'Input for field "{self.field}" does not have the required'
(...)
135 self=self,
136 )
137 data[self.field] = value
KeyError: 'static_cat_0'
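The traceback ends in `AsNumpyArray.transform`, which reads the field by name (`data[self.field]`). A minimal stdlib sketch of why that lookup fails, assuming (as explained further down in this thread) that the dataset yields all static categories grouped under one `feat_static_cat` key while TFT looks each one up under its own column name. The entry contents and the `lookup_field` helper are hypothetical stand-ins, not GluonTS code.

```python
# Hypothetical data entry in the grouped layout: every static categorical
# value sits in a single vector under "feat_static_cat".
grouped_entry = {
    "start": "2020-01",
    "target": [1.0, 2.0, 3.0],
    "feat_static_cat": [3, 0, 1, 2, 0, 1, 4, 2],
}


def lookup_field(entry, field):
    # Stand-in for the by-name access `data[self.field]` in AsNumpyArray.
    return entry[field]


try:
    # TFT asks for the field named after the column in static_cardinalities.
    lookup_field(grouped_entry, "static_cat_0")
except KeyError as err:
    print("KeyError:", err)  # prints: KeyError: 'static_cat_0'
```

The grouped vector is still present; it is simply stored under a key that the TFT transformation never asks for.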
Environment
- Operating system: Ubuntu
- Python version: 3.8.10
- GluonTS version: 0.11
- MXNet version: 1.9.1
TFT seems to have a bug in version 0.11
Does the example work in 0.10?
I've also noticed that TFT doesn't work well with PandasDataset dynamic real features. I think this is because TFT expects the resulting ListDataset to include each feature by name, not grouped into `feat_dynamic_real`, etc.
RE: the working example in https://github.com/awslabs/gluonts/issues/1075: all features are specified by name in the ListDataset, not by feature type.
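Following that explanation, one possible workaround is to copy each grouped value back into a field named after its column, so the keys match those in `static_cardinalities`. The sketch below works on plain dicts; `ungroup_static_cats` is a hypothetical helper, and it assumes the grouped vector is ordered like the `feat_static_cat` column list passed to `PandasDataset.from_long_dataframe`. Storing each value as a one-element list is also an assumption about the shape TFT expects; nobody in this thread confirmed that TFT trains correctly on the result.

```python
from typing import Any, Dict, List


def ungroup_static_cats(entry: Dict[str, Any], names: List[str]) -> Dict[str, Any]:
    """Return a copy of `entry` with each grouped static categorical value
    stored under its own column name (hypothetical helper)."""
    values = entry["feat_static_cat"]
    if len(values) != len(names):
        raise ValueError(
            f"expected {len(names)} static categories, got {len(values)}"
        )
    out = dict(entry)  # shallow copy so the original entry is untouched
    del out["feat_static_cat"]  # drop the grouped field
    for name, value in zip(names, values):
        # Assumption: one-element list per named static feature.
        out[name] = [value]
    return out


names = [f"static_cat_{i}" for i in range(8)]
entry = {
    "start": "2020-01",
    "target": [1.0, 2.0],
    "feat_static_cat": [3, 0, 1, 2, 0, 1, 4, 2],
}
named = ungroup_static_cats(entry, names)
print(named["static_cat_0"])  # prints: [3]
```

Mapping this over the dataset's entries before training is an untested assumption here; the key point is that the field names must match the keys of `static_cardinalities`.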
TFT seems to have a bug in version 0.11
Does the example work in 0.10?
No, it doesn't work in 0.10 either.
As @esbraun pointed out, this is because TFT expects a different schema than the one PandasDataset provides.
Cc @jaheba: this fits into the broader schema story. Should we have a milestone or project to collect all issues related to it?
SelfAttentionEstimator gives me this error even though I set the parameter use_feat_static_cat=False: KeyError: 'feat_static_cat'
There are a couple of issues related to this model: https://github.com/awslabs/gluonts/issues/1976, https://github.com/awslabs/gluonts/issues/2160, https://github.com/awslabs/gluonts/issues/2416, https://github.com/awslabs/gluonts/issues/2466. All of them report something similar to this ticket. We have already started revisiting this model to fix these issues and will keep you updated.
@jmberutich we now have a PyTorch implementation of the Temporal Fusion Transformer (see #2536), which uses the same data interface as the other models, so it should not have this issue. It's in the dev branch until the next release; let us know if you have the chance to try it out!
Are there any fixes for it?