pytorch-forecasting
Multiple Target Prediction Plotting Bug
Description
When plot_prediction() is called in pytorch-forecasting with multiple targets, the function reuses the same axes for every target. As a result, the plots for different targets overlap on one figure instead of appearing on separate figures. This pull request fixes that.
Issue
#1314
The faulty code part
def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:
    # ...
    # for each target, plot
    figs = []
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):
        # ...
        # create figure
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()
        # ...
        figs.append(fig)
    return figs
Expected behavior: Each target should be plotted on a separate figure.
Actual behavior: All targets are plotted on the same figure, resulting in overlapped plots.
Solution
In the snippet above, ax should be reset for each target in the loop. However, after the first iteration ax is no longer None, so the same axes object is reused for every subsequent target. The proposed fix records once, before the loop, whether the caller actually supplied an axes:
def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:
    # ...
    # for each target, plot
    figs = []
    ax_provided = ax is not None
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):
        # ...
        # create figure
        if (ax is None) or (not ax_provided):
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()
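The control flow of the fix can be sketched with plain-Python stand-ins (the figure/axes strings below are hypothetical placeholders for plt.subplots() and ax.get_figure(); make_figures is not part of the library):

```python
def make_figures(targets, ax=None):
    """Sketch of the fixed loop: a fresh figure per target unless
    the caller passed in an axes object."""
    ax_provided = ax is not None  # record the caller's intent once
    figs = []
    for target in targets:
        if (ax is None) or (not ax_provided):
            # stand-in for: fig, ax = plt.subplots()
            fig, ax = f"fig:{target}", f"ax:{target}"
        else:
            # stand-in for: fig = ax.get_figure()
            fig = "caller-figure"
        figs.append(fig)
    return figs

# Without the ax_provided flag, the second iteration would see
# ax != None and silently reuse the first target's axes; with it,
# each target gets its own figure.
print(make_figures(["t1", "t2"]))             # ['fig:t1', 'fig:t2']
print(make_figures(["t1", "t2"], ax="mine"))  # ['caller-figure', 'caller-figure']
```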
Bonus
Corrected a mistake in the documentation: the encoder's log1p transformation was incorrectly called logp1.
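For context on why the name matters: log1p computes log(1 + x) and is more accurate than the naive expression for very small x. A quick standard-library illustration (using math.log1p, not the encoder itself):

```python
import math

x = 1e-17  # so small that 1 + x rounds to exactly 1.0 in float64
naive = math.log(1 + x)   # loses the signal entirely: 0.0
stable = math.log1p(x)    # keeps it: approximately 1e-17
print(naive, stable)
```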
#1247
Is this repo not maintained?
Codecov Report
Patch coverage: 33.33%; project coverage change: -0.08% :warning:
Comparison is base (9995d0a) 90.13% compared to head (120f3e4) 90.05%. Report is 2 commits behind head on master.
Additional details and impacted files
@@ Coverage Diff @@
## master #1317 +/- ##
==========================================
- Coverage 90.13% 90.05% -0.08%
==========================================
Files 30 30
Lines 4712 4716 +4
==========================================
Hits 4247 4247
- Misses 465 469 +4
Flag | Coverage Δ |
---|---|
cpu | 90.05% <33.33%> (-0.08%) :arrow_down: |
pytest | 90.05% <33.33%> (-0.08%) :arrow_down: |
Flags with carried forward coverage won't be shown.
Files Changed | Coverage Δ |
---|---|
pytorch_forecasting/data/encoders.py | 87.25% <ø> (ø) |
pytorch_forecasting/models/base_model.py | 87.77% <33.33%> (-0.42%) :arrow_down: |