pytorch-forecasting
CPU RAM consumption steadily increases during TFT training
- PyTorch-Forecasting version: 1.0.0
- PyTorch version: 2.0.1
- Python version: 3.10
- Operating System: Ubuntu 22.04
Expected behavior
TFT training on the UCI electricity dataset is set up per the examples using Lightning; the train/val datasets fit comfortably within system memory, so RAM usage should stay roughly constant during training.
Actual behavior
Throughout training, CPU RAM consumption grows steadily until it hits OOM (96 GB) and the kernel crashes. I have tried disabling logging everywhere I know how, in both PyTorch-Forecasting and Lightning, to no avail. GPU RAM consumption stays steady at ~1.2 GB.
Code to reproduce the problem
# Assumes df is a pandas DataFrame of the UCI electricity data with the columns
# referenced below (time_idx, demand, group_id, hour, weekday, day, month).
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

max_prediction_length = 24
max_encoder_length = 7 * 24  # one week of hourly history
batch_size = 64
num_workers = 8
split_time_idx = 30000

train_data = TimeSeriesDataSet(
    data=df[lambda x: x.time_idx < split_time_idx],
    time_idx="time_idx",
    target="demand",
    group_ids=["group_id"],
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["group_id"],
    time_varying_known_reals=["time_idx", "hour", "weekday", "day", "month"],
    time_varying_unknown_reals=["demand"],
    target_normalizer=GroupNormalizer(groups=["group_id"], transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    randomize_length=False,
)
val_data = TimeSeriesDataSet.from_dataset(
    train_data, df[lambda x: x.time_idx >= split_time_idx], stop_randomization=True, predict=False
)

train_dataloader = train_data.to_dataloader(train=True, batch_size=batch_size, num_workers=num_workers)
val_dataloader = val_data.to_dataloader(train=False, batch_size=batch_size, num_workers=num_workers)

tft = TemporalFusionTransformer.from_dataset(
    train_data,
    learning_rate=0.001,
    hidden_size=160,
    hidden_continuous_size=160,
    attention_head_size=4,
    dropout=0.1,
    output_size=7,        # stand-in for the original `output_size` variable (7 quantiles)
    loss=QuantileLoss(),  # stand-in for the original `quantile_loss` variable
    log_interval=-1,      # disable example plotting/logging
    reduce_on_plateau_patience=4,
)
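The snippet above stops before the training call. A minimal sketch of how training might then be launched with a Lightning 2.x Trainer, with logging and checkpointing disabled as mentioned above (max_epochs and gradient_clip_val here are assumed values, not taken from the original script):

# Sketch only: Lightning 2.x Trainer with logging and checkpointing turned off.
import lightning.pytorch as pl

trainer = pl.Trainer(
    max_epochs=30,               # assumed value
    accelerator="gpu",
    devices=1,
    gradient_clip_val=0.1,       # assumed value
    logger=False,                # no TensorBoard/CSV logger
    enable_checkpointing=False,  # no checkpoint writing
)
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)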
Could you find any answers?
In my experience with this problem, memory consumption stops increasing and drops slightly once an epoch completes. So in my case, if I raise the batch size to an arbitrarily large value, the problem goes away. This is more of a hack than a fix, though.
I've also found that if I set num_workers to 0 and leave the batch size as is, I don't run out of memory (sketch below). Again, this is more of a hack than a fix.
It "feels" like a memory leak. I'm just not sure if just the dataloader is the problem, or if it's happening during training. I'm also not sure if pytorch-forecasting is the problem or the bug exists with pytorch-lightning or torch itself.
I'd love to try running the TFT on the UCI electricity dataset. Did you do anything prior to the code you posted? Did you just download a CSV file, read it into pandas and store it into the variable df?
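For what it's worth, here is a rough sketch of one common way to turn the raw UCI file (LD2011_2014.txt) into the long-format df used above; I'm not claiming this is what the original post did, and the file path is a placeholder:

# Rough sketch of preparing the UCI electricity data (not necessarily what the
# original post used). LD2011_2014.txt is the raw file from the UCI repository.
import pandas as pd

raw = pd.read_csv("LD2011_2014.txt", sep=";", decimal=",", index_col=0, parse_dates=True)
raw = raw.resample("1h").mean()  # raw readings are at 15-minute resolution

# long format: one row per (timestamp, client), client column names become group_id
df = raw.stack().reset_index()
df.columns = ["timestamp", "group_id", "demand"]

# integer time index plus the calendar features used as known reals above
df["time_idx"] = ((df["timestamp"] - df["timestamp"].min()).dt.total_seconds() // 3600).astype(int)
df["hour"] = df["timestamp"].dt.hour
df["weekday"] = df["timestamp"].dt.weekday
df["day"] = df["timestamp"].dt.day
df["month"] = df["timestamp"].dt.month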
I have run with an earlier version of pytorch-forecasting and it works seamlessly. These problems could come from the upgrade of torch or pytorch-lightning.
@sairamtvv Can you please share the versions of torch, forecasting, lightning, and other relevant packages you are using?
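One quick way to grab those versions:

import lightning
import pytorch_forecasting
import torch

print("torch:", torch.__version__)
print("lightning:", lightning.__version__)
print("pytorch-forecasting:", pytorch_forecasting.__version__)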
@jon-huff I'm experiencing a similar issue with pytorch_forecasting version 1.0.0 (the latest version). Everything was working seamlessly two months ago, but when I recently ran the same code in the same environment, I noticed a significant increase in RAM consumption. Specifically, the whole RAM of my system (64 GB) gets consumed before even a single epoch completes.