[BUG] NBEATSModel | Progress bar is not working
Describe the bug
Hi Team,
I use the NBEATS model and set enable_progress_bar=True in pl_trainer_kwargs. However, when I start fitting the model, the progress bar doesn't work as expected, which makes it hard to tell whether model training has started or not.
To Reproduce

```python
model_nbeats = NBEATSModel(
    input_chunk_length=30,
    output_chunk_length=30,
    generic_architecture=True,
    num_stacks=10,
    num_blocks=1,
    num_layers=4,
    layer_widths=32,
    n_epochs=2,
    nr_epochs_val_period=1,
    batch_size=512,
    model_name="nbeats_run_acu",
    pl_trainer_kwargs={
        "log_every_n_steps": 4,
        "accelerator": "cpu",
        "gpus": None,
        "auto_select_gpus": False,
        "enable_progress_bar": True,
        "enable_model_summary": True,
    },
    work_dir="/path_for_work_dir/",
    log_tensorboard="/path_for_tensorboard",
    force_reset=True,
    show_warnings=True,
    random_state=42,
)

model_nbeats.fit(test_train, val_series=test_val, num_loader_workers=80)
```
Expected behavior
I was expecting to see the progress bar move so I could track the progress of each epoch, but I only get this:
System (please complete the following information):
- Python version: [e.g. 3.8.10]
- darts version: [e.g. 0.17.1]
My first guess would be that it comes from num_loader_workers=80. Can you try it once without using that kwarg?
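For reference, a minimal sketch of what that check could look like, reusing the model and series from the reproduction above (just an illustration, not a confirmed fix):

```python
# Same fit call as in the reproduction, but without num_loader_workers,
# so the default single-process DataLoader is used.
model_nbeats.fit(test_train, val_series=test_val)
```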
I tried it without that kwarg and it's the same situation. Do you know what else could cause this issue?
I set num_loader_workers=80 because of the size of my data and because I have 80 CPUs on my compute, so I followed the API reference to try to increase data loading efficiency.
Thanks for getting back to me on this issue.
Running without a GPU will be slow. To know if it is working, I would reduce your batch size from 512 to 2, which will reduce the time to report one iteration.
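A minimal sketch of that change, keeping everything else from the reproduction the same (the trimmed argument list here is only for brevity; batch_size is the one intended difference):

```python
from darts.models import NBEATSModel

# Hypothetical quick check: a tiny batch size so each iteration finishes
# fast and the progress bar (if it works) updates almost immediately.
model_nbeats = NBEATSModel(
    input_chunk_length=30,
    output_chunk_length=30,
    n_epochs=2,
    batch_size=2,  # was 512
    pl_trainer_kwargs={"accelerator": "cpu", "enable_progress_bar": True},
    random_state=42,
)
```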
Hey @gdevos010
Thanks for getting back to me. I set the batch size to 2 and still get a similar result:
It seems to be running, but the progress bar doesn't display correctly?
By the way, I'm running this on Azure databricks with Runtime version = 10.3 ML.
It might be that training is starting but each batch takes too long (even with a batch size of 2). Could you try setting max_samples_per_ts=1 when calling fit()?
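A minimal sketch of that suggestion, reusing the model and series from the reproduction above:

```python
# Cap the number of training samples drawn from each series so an epoch
# finishes quickly and the progress bar gets a chance to update.
model_nbeats.fit(
    test_train,
    val_series=test_val,
    max_samples_per_ts=1,
)
```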
I also experience this issue with the TCNModel
Same issue with TFTModel. It works when using RichProgressBar, but then I can't save the model to pickle anymore :/
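For context, a rough sketch of the RichProgressBar workaround mentioned above, passed through pl_trainer_kwargs (the pickling caveat still applies, and the other TFTModel arguments here are illustrative only):

```python
from pytorch_lightning.callbacks import RichProgressBar
from darts.models import TFTModel

# Replace the default progress bar with the Rich-based one.
# Caveat from this thread: the Rich console is not picklable, so saving
# the model with pickle can fail while this callback is attached.
model = TFTModel(
    input_chunk_length=30,
    output_chunk_length=30,
    pl_trainer_kwargs={"callbacks": [RichProgressBar()]},
)
```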
https://github.com/unit8co/darts/issues/962
I might have found a potential fix to the original issue: could anyone try out this solution?
Closing for now - please re-open if the issue persists.