pytorch-forecasting
pytorch-forecasting copied to clipboard
`plot_prediction` raises `ValueError` in `errorbar` when `max_prediction_length = 1`
- PyTorch-Forecasting version: 0.10.3
- PyTorch version: 1.12.1
- Python version: 3.10.6
- Operating System: Amazon Linux 2
Expected behavior
Running the TFT tutorial with `max_prediction_length` changed to 1 raises an error, whereas `max_prediction_length > 1` works fine. Version 0.10.1 works fine with `max_prediction_length = 1`.
Getting the following error:
ValueError Traceback (most recent call last) Input In [18], in <cell line: 2>() 1 # fit network ----> 2 trainer.fit( 3 tft, 4 train_dataloaders=train_dataloader, 5 val_dataloaders=val_dataloader, 6 )
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:696, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
677 r"""
678 Runs the full optimization routine.
679
(...)
693 datamodule: An instance of :class:~pytorch_lightning.core.datamodule.LightningDataModule.
694 """
695 self.strategy.model = model
--> 696 self._call_and_handle_interrupt(
697 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
698 )
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
648 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
649 else:
--> 650 return trainer_fn(*args, **kwargs)
651 # TODO(awaelchli): Unify both exceptions below, where KeyboardError doesn't re-raise
652 except KeyboardInterrupt as exception:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:735, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 731 ckpt_path = ckpt_path or self.resume_from_checkpoint 732 self._ckpt_path = self.__set_ckpt_path( 733 ckpt_path, model_provided=True, model_connected=self.lightning_module is not None 734 ) --> 735 results = self._run(model, ckpt_path=self.ckpt_path) 737 assert self.state.stopped 738 self.training = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1166, in Trainer._run(self, model, ckpt_path) 1162 self._checkpoint_connector.restore_training_state() 1164 self._checkpoint_connector.resume_end() -> 1166 results = self._run_stage() 1168 log.detail(f"{self.__class__.__name__}: trainer tearing down") 1169 self._teardown()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1252, in Trainer._run_stage(self) 1250 if self.predicting: 1251 return self._run_predict() -> 1252 return self._run_train()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1283, in Trainer._run_train(self) 1280 self.fit_loop.trainer = self 1282 with torch.autograd.set_detect_anomaly(self._detect_anomaly): -> 1283 self.fit_loop.run()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs) 198 try: 199 self.on_advance_start(*args, **kwargs) --> 200 self.advance(*args, **kwargs) 201 self.on_advance_end() 202 self._restarting = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:271, in FitLoop.advance(self) 267 self._data_fetcher.setup( 268 dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0) 269 ) 270 with self.trainer.profiler.profile("run_training_epoch"): --> 271 self._outputs = self.epoch_loop.run(self._data_fetcher)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:201, in Loop.run(self, *args, **kwargs) 199 self.on_advance_start(*args, **kwargs) 200 self.advance(*args, **kwargs) --> 201 self.on_advance_end() 202 self._restarting = False 203 except StopIteration:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:241, in TrainingEpochLoop.on_advance_end(self) 239 if should_check_val: 240 self.trainer.validating = True --> 241 self._run_validation() 242 self.trainer.training = True 244 # update plateau LR scheduler after metrics are logged
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:299, in TrainingEpochLoop._run_validation(self) 296 self.val_loop._reload_evaluation_dataloaders() 298 with torch.no_grad(): --> 299 self.val_loop.run()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs) 198 try: 199 self.on_advance_start(*args, **kwargs) --> 200 self.advance(*args, **kwargs) 201 self.on_advance_end() 202 self._restarting = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:155, in EvaluationLoop.advance(self, *args, **kwargs) 153 if self.num_dataloaders > 1: 154 kwargs["dataloader_idx"] = dataloader_idx --> 155 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) 157 # store batch level output per dataloader 158 self._outputs.append(dl_outputs)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs) 198 try: 199 self.on_advance_start(*args, **kwargs) --> 200 self.advance(*args, **kwargs) 201 self.on_advance_end() 202 self._restarting = False
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:143, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs) 140 self.batch_progress.increment_started() 142 # lightning module methods --> 143 output = self._evaluation_step(**kwargs) 144 output = self._evaluation_step_end(output) 146 self.batch_progress.increment_processed()
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:240, in EvaluationEpochLoop._evaluation_step(self, **kwargs) 229 """The evaluation step (validation_step or test_step depending on the trainer's state). 230 231 Args: (...) 237 the outputs of the step 238 """ 239 hook_name = "test_step" if self.trainer.testing else "validation_step" --> 240 output = self.trainer._call_strategy_hook(hook_name, *kwargs.values()) 242 return output
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1704, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs) 1701 return 1703 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"): -> 1704 output = fn(*args, **kwargs) 1706 # restore current_fx when nested context 1707 pl_module._current_fx_name = prev_fx_name
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:370, in Strategy.validation_step(self, *args, **kwargs) 368 with self.precision_plugin.val_step_context(): 369 assert isinstance(self.model, ValidationStep) --> 370 return self.model.validation_step(*args, **kwargs)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:420, in BaseModel.validation_step(self, batch, batch_idx) 418 x, y = batch 419 log, out = self.step(x, y, batch_idx) --> 420 log.update(self.create_log(x, y, out, batch_idx)) 421 return log
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:520, in TemporalFusionTransformer.create_log(self, x, y, out, batch_idx, **kwargs) 519 def create_log(self, x, y, out, batch_idx, **kwargs): --> 520 log = super().create_log(x, y, out, batch_idx, **kwargs) 521 if self.log_interval > 0: 522 log["interpretation"] = self._log_interpretation(out)
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:469, in BaseModel.create_log(self, x, y, out, batch_idx, prediction_kwargs, quantiles_kwargs) 467 self.log_metrics(x, y, out, prediction_kwargs=prediction_kwargs) 468 if self.log_interval > 0: --> 469 self.log_prediction( 470 x, out, batch_idx, prediction_kwargs=prediction_kwargs, quantiles_kwargs=quantiles_kwargs 471 ) 472 return {}
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:717, in BaseModel.log_prediction(self, x, out, batch_idx, **kwargs) 715 log_indices = [0] 716 for idx in log_indices: --> 717 fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs) 718 tag = f"{self.current_stage} prediction" 719 if self.training:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:711, in TemporalFusionTransformer.plot_prediction(self, x, out, idx, plot_attention, add_loss_to_title, show_future_observed, ax, **kwargs) 694 """ 695 Plot actuals vs prediction and attention 696 (...) 707 plt.Figure: matplotlib figure 708 """ 710 # plot prediction as normal --> 711 fig = super().plot_prediction( 712 x, 713 out, 714 idx=idx, 715 add_loss_to_title=add_loss_to_title, 716 show_future_observed=show_future_observed, 717 ax=ax, 718 **kwargs, 719 ) 721 # add attention on secondary axis 722 if plot_attention:
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:832, in BaseModel.plot_prediction(self, x, out, idx, add_loss_to_title, show_future_observed, ax, quantiles_kwargs, prediction_kwargs) 830 else: 831 quantiles = torch.tensor([[y_quantile[0, i]], [y_quantile[0, -i - 1]]]) --> 832 ax.errorbar( 833 x_pred, 834 y[[-n_pred]], 835 yerr=quantiles - y[-n_pred], 836 c=pred_color, 837 capsize=1.0, 838 ) 840 if add_loss_to_title is not False: 841 if isinstance(add_loss_to_title, bool):
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/matplotlib/__init__.py:1423, in _preprocess_data.
File ~/SageMaker/.persisted_conda/ksquant/lib/python3.10/site-packages/matplotlib/axes/_axes.py:3587, in Axes.errorbar(self, x, y, yerr, xerr, fmt, ecolor, elinewidth, capsize, barsabove, lolims, uplims, xlolims, xuplims, errorevery, capthick, **kwargs) 3584 res = np.zeros(err.shape, dtype=bool) # Default in case of nan 3585 if np.any(np.less(err, -err, out=res, where=(err == err))): 3586 # like err<0, but also works for timedelta and nan. -> 3587 raise ValueError( 3588 f"'{dep_axis}err' must not contain negative values") 3589 # This is like 3590 # elow, ehigh = np.broadcast_to(...) 3591 # return dep - elow * ~lolims, dep + ehigh * ~uplims 3592 # except that broadcast_to would strip units. 3593 low, high = dep + np.row_stack([-(1 - lolims), 1 - uplims]) * err
ValueError: 'yerr' must not contain negative values
Also experiencing this issue
same issue
I had the same issue and solved it by downgrading matplotlib from 3.7.1 to 3.4.3. Newer matplotlib versions reject negative values in the `yerr` parameter of `errorbar`, which expects non-negative distances below and above each point. `quantiles - y[-n_pred]` produces a negative lower error whenever the lower quantile lies below the prediction. So either pass the absolute value of the data that goes into `yerr`, or downgrade matplotlib.
Hope this will help
I don't understand why it has to plot during training. Can this be deactivated?