NBEATS Optimal Hyperparameters for Large Datasets
Hi again!
Currently, I am training an NBEATS model on the following:
- 10,000+ time series of varying lengths (200 to 7,000+ timesteps), each with 4 past_covariates aligned on the same timesteps (quite a feat). I plan to use this model in a one-shot learning application.
Here are the model hyperparameters for the NBEATS model:
multiseries_multicov_nbeats = NBEATSModel(
    input_chunk_length=30,
    output_chunk_length=7,
    generic_architecture=True,
    num_stacks=10,
    num_blocks=3,
    num_layers=4,
    layer_widths=512,
    n_epochs=100,
    nr_epochs_val_period=1,
    batch_size=1600,
    model_name="nbeats_run",
    force_reset=True,
    pl_trainer_kwargs={
        "accelerator": "gpu",
        "gpus": [0],
    },
    save_checkpoints=True,
    random_state=42,
)

multiseries_multicov_nbeats.fit(
    series=time_series_train,
    past_covariates=covariates_train,
)
I have changed the following:
- batch_size: 800 --> 1600
- num_blocks: 1 --> 3
It currently takes ~55 hours to train a single model, so experimentation is slow. :cry:
- Are there any other hyperparameters worth tuning for datasets this large, in your experience?
- I assume grid search with NBEATS and multiple covariates is not possible yet, correct?
Any and all recommendations would be most appreciated :+1:!
Thank you
- Adding an early stop can help cut down the training time. Is there a validation set you can use?
Absolutely! I split all datasets into train/val. Thanks for the suggestion. I will supply the entire covariate series so they cover the validation period as well.
How does one accomplish early stopping with NBEATS?
You can use the PyTorch Lightning EarlyStopping callback:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# stop training once val_loss has not improved by at least 0.05
# for 5 consecutive epochs
my_stopper = EarlyStopping(
    monitor="val_loss",
    patience=5,
    min_delta=0.05,
    mode="min",
)
pl_trainer_kwargs = {"callbacks": [my_stopper]}

model = NBEATSModel(..., pl_trainer_kwargs=pl_trainer_kwargs)

model.fit(
    series=train,
    val_series=val,
    past_covariates=train_covariates,
    val_past_covariates=val_covariates,
)
This is great @gdevos010 Thank you!
I guess I will set up a dictionary of multiple hyperparameters to experiment with, something like the sketch below.
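A minimal sketch of what I have in mind, reusing the early stopping from above (the grid values are placeholders, not tuned recommendations, and time_series_val / covariates_val are my assumed names for the validation split):

from itertools import product

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from darts.models import NBEATSModel

# illustrative grid; the values are placeholders, not recommendations
param_grid = {
    "num_blocks": [1, 3],
    "layer_widths": [256, 512],
}

keys = list(param_grid)
for values in product(*(param_grid[k] for k in keys)):
    params = dict(zip(keys, values))
    # a fresh stopper per run, since callbacks hold state across fits
    stopper = EarlyStopping(monitor="val_loss", patience=5, min_delta=0.05, mode="min")
    model = NBEATSModel(
        input_chunk_length=30,
        output_chunk_length=7,
        generic_architecture=True,
        n_epochs=100,
        model_name="nbeats_" + "_".join(f"{k}{v}" for k, v in params.items()),
        force_reset=True,
        pl_trainer_kwargs={"accelerator": "gpu", "gpus": [0], "callbacks": [stopper]},
        **params,
    )
    model.fit(
        series=time_series_train,
        past_covariates=covariates_train,
        val_series=time_series_val,
        val_past_covariates=covariates_val,
    )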
I also recommend you take a look at the performance tuning recommendations for torch models here: https://unit8co.github.io/darts/userguide/torch_forecasting_models.html#performance-recommendations
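For instance, here is a rough sketch of two of the points from that guide, training on 32-bit data and parallelizing data loading (double-check TimeSeries.astype and the num_loader_workers argument to fit() against your darts version):

import numpy as np

# cast series and covariates to float32: roughly halves memory use,
# and the user guide recommends building torch models on 32-bit data
time_series_train = [s.astype(np.float32) for s in time_series_train]
covariates_train = [c.astype(np.float32) for c in covariates_train]

multiseries_multicov_nbeats.fit(
    series=time_series_train,
    past_covariates=covariates_train,
    num_loader_workers=4,  # parallel data loading; tune to your CPU count
)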
Awesome will do! Thank you @hrzn