pytorch-forecasting
Calling Trainer.fit() multiple times with different dataloaders on the same TFT model
PyTorch-Forecasting version: 0.9.2
PyTorch version: 1.12.0
PyTorch-Lightning version: 1.5.10
Python version: 3.8.12
Operating System: Windows 10
I have a very large data set that does not fit into a single training/validation dataloader, because my GPU does not have enough VRAM.
My idea was to call trainer.fit() multiple times on the same TFT model, each time with a different pair of dataloaders. However, this is easier said than done, as trainer.fit() does not seem to allow a second run. I already tried the solution from here.
Is there another way to do this?
I thought about dropping the PyTorch Lightning Trainer and coding the training/validation loop myself. However, I am not sure how to wire the optimizer, the loss function and the running training loss into such a loop correctly. The Trainer normally handles all of this for me, and the PyTorch-Forecasting documentation does not really cover setting up and calling the optimizer and loss function manually.
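To make the question more concrete, here is a minimal sketch of the kind of manual loop I have in mind, written in plain PyTorch. Everything in it is an assumption for illustration: `nn.Linear` stands in for the TFT model, `nn.MSELoss` for the loss, and `make_chunk_loader` and `train_on_chunks` are hypothetical helpers, not part of PyTorch-Forecasting. The real TFT dataloaders yield a different batch structure, so the inner loop would need adapting.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def make_chunk_loader(seed, n=64, batch_size=16):
    # Hypothetical stand-in for one chunk of the large data set.
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(n, 4, generator=g)
    y = x.sum(dim=1, keepdim=True)
    return DataLoader(TensorDataset(x, y), batch_size=batch_size,
                      shuffle=True, generator=g)


def train_on_chunks(model, chunk_loaders, epochs=2, lr=5e-2):
    # One optimizer for the whole run, so its state carries over
    # from one dataloader to the next.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()  # assumption: placeholder loss
    history = []  # mean training loss per (epoch, chunk)
    for _ in range(epochs):
        model.train()
        for loader in chunk_loaders:
            running, batches = 0.0, 0
            for x, y in loader:
                # Only the current batch is moved to the GPU.
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                optimizer.step()
                running += loss.item()
                batches += 1
            history.append(running / batches)
    return history


torch.manual_seed(0)
model = nn.Linear(4, 1)  # assumption: stand-in for the TFT model
loaders = [make_chunk_loader(seed) for seed in range(3)]
history = train_on_chunks(model, loaders, epochs=5)
```

Here `history` holds one averaged training loss per chunk and epoch, which is roughly what I mean by train_loss_value. What I cannot tell is how to replace the placeholder model, loss and batch unpacking with their TFT equivalents.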
Thanks for the help.