
CPU RAM consumption steadily increases during TFT training

Open jon-huff opened this issue 1 year ago • 5 comments

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 2.0.1
  • Python version: 3.10
  • Operating System: Ubuntu 22.04

Expected behavior

TFT training on the UCI electricity dataset is set up per the examples using Lightning; the train/val datasets fit well within system memory.

Actual behavior

Throughout training, CPU RAM consumption steadily grows until it hits OOM (96 GB) and the kernel crashes. I have tried disabling logging everywhere I know how within PyTorch-Forecasting and Lightning, to no avail. GPU RAM consumption stays steady at ~1.2 GB.

Code to reproduce the problem

# Imports assumed for this snippet (not shown in the original post)
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

max_prediction_length = 24
max_encoder_length = 7 * 24
batch_size = 64
num_workers = 8
split_time_idx = 30000

# df holds the preprocessed electricity data (its construction is not shown in the post)
# Training dataset: one-week encoder window, one-day prediction horizon
train_data = TimeSeriesDataSet(data=df[lambda x: x.time_idx < split_time_idx],
                               time_idx='time_idx',
                               target='demand',
                               group_ids=['group_id'],
                               min_encoder_length=max_encoder_length,
                               max_encoder_length=max_encoder_length,
                               min_prediction_length=max_prediction_length,
                               max_prediction_length=max_prediction_length,
                               static_categoricals=['group_id'],
                               time_varying_known_reals=['time_idx', 'hour', 'weekday', 'day', 'month'],
                               time_varying_unknown_reals=['demand'],
                               target_normalizer=GroupNormalizer(groups=['group_id'], transformation='softplus'),
                               add_relative_time_idx=True,
                               add_target_scales=True,
                               randomize_length=False)

# Validation dataset reuses the training dataset's parameters on the held-out time range
val_data = TimeSeriesDataSet.from_dataset(train_data, df[lambda x: x.time_idx >= split_time_idx],
                                          stop_randomization=True, predict=False)

train_dataloader = train_data.to_dataloader(train=True, batch_size=batch_size, num_workers=num_workers)
val_dataloader = val_data.to_dataloader(train=False, batch_size=batch_size, num_workers=num_workers)

# output_size and quantile_loss were undefined in the original snippet; a default
# QuantileLoss (7 quantiles) is assumed here so the code runs as posted
quantile_loss = QuantileLoss()
output_size = 7

tft = TemporalFusionTransformer.from_dataset(train_data,
                                             learning_rate=.001,
                                             hidden_size=160,
                                             hidden_continuous_size=160,
                                             attention_head_size=4,
                                             dropout=.1,
                                             output_size=output_size,
                                             loss=quantile_loss,
                                             log_interval=-1,
                                             reduce_on_plateau_patience=4)
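
The snippet stops at model construction; training was presumably launched with a Lightning Trainer roughly along these lines (a minimal sketch, not part of the original post, with logging and checkpointing disabled as described above):

import lightning.pytorch as pl  # or: import pytorch_lightning as pl, depending on the installed stack

trainer = pl.Trainer(max_epochs=30,
                     accelerator='gpu',
                     devices=1,
                     gradient_clip_val=0.1,
                     logger=False,                # logging disabled, as attempted above
                     enable_checkpointing=False)
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)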

jon-huff avatar Aug 31 '23 14:08 jon-huff

Could you find any answers?

sairamtvv avatar Sep 29 '23 09:09 sairamtvv

In my experience with this problem, memory consumption stops increasing and drops slightly once an epoch completes. So in my case, raising the batch size to an arbitrarily large value makes the problem go away. This is more of a hack than a fix, though.

I've also found that if I set num_workers to 0 and leave the batch size as is, I don't run out of memory (see the sketch below). Again, this is more of a hack than a fix.
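
For reference, applying that second workaround to the dataloaders from the snippet above would look roughly like this (single-process loading, batch size unchanged):

# Workaround sketch: build the dataloaders without worker processes
train_dataloader = train_data.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = val_data.to_dataloader(train=False, batch_size=batch_size, num_workers=0)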

It "feels" like a memory leak. I'm just not sure if just the dataloader is the problem, or if it's happening during training. I'm also not sure if pytorch-forecasting is the problem or the bug exists with pytorch-lightning or torch itself.

I'd love to try running the TFT on the UCI electricity dataset. Did you do anything prior to the code you posted? Did you just download a CSV file, read it into pandas and store it into the variable df?
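
In other words, something roughly like this? (A hypothetical sketch of the preprocessing I have in mind; the file name, hourly resampling, and melt to long format are my guesses, and only the column names come from your snippet.)

import pandas as pd

# Hypothetical preprocessing (not from the original post): the raw UCI file
# LD2011_2014.txt is semicolon-separated with comma decimals, one column per meter
raw = pd.read_csv('LD2011_2014.txt', sep=';', decimal=',', index_col=0, parse_dates=True)
raw.index.name = 'timestamp'
raw = raw.resample('1H').mean()  # 15-minute readings -> hourly

# Long format: one row per (meter, hour), matching the column names used above
df = raw.reset_index().melt(id_vars='timestamp', var_name='group_id', value_name='demand')
df['time_idx'] = ((df['timestamp'] - df['timestamp'].min()).dt.total_seconds() // 3600).astype(int)
df['hour'] = df['timestamp'].dt.hour
df['weekday'] = df['timestamp'].dt.weekday
df['day'] = df['timestamp'].dt.day
df['month'] = df['timestamp'].dt.month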

YojoNick avatar Sep 29 '23 12:09 YojoNick

I have run with an earlier version of pytorch-forecasting and it works seamlessly. These problems could come from an upgrade of torch or pytorch-lightning.

sairamtvv avatar Sep 29 '23 16:09 sairamtvv

@sairamtvv Can you please share the versions of torch, pytorch-forecasting, lightning, and other relevant packages you are using?

e-platini avatar Oct 05 '23 12:10 e-platini

@jon-huff .... I'm experiencing a similar issue with pytorch_forecasting version 1.0.0 (the latest version). Everything was working fine two months ago, but when I recently ran the same code in the same environment, I noticed a significant increase in RAM consumption. Specifically, the entire RAM of my system (64 GB) gets consumed before even a single epoch completes.
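
One way to confirm whether the growth happens in the main training process (rather than in dataloader workers) is to log the resident set size each epoch, e.g. with a small callback (a sketch assuming psutil is installed and a Lightning 2.x-style callback API):

import os
import psutil
import lightning.pytorch as pl  # or: import pytorch_lightning as pl, depending on the installed stack

class RamMonitor(pl.Callback):
    """Print the main process's resident set size after every training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3
        print(f"epoch {trainer.current_epoch}: CPU RSS = {rss_gb:.2f} GB")

# usage: pass callbacks=[RamMonitor()] to the Trainer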

AshiHydro avatar Dec 23 '23 13:12 AshiHydro