
Multiple Target Prediction Plotting Bug

Open: terbed opened this pull request 1 year ago · 2 comments

Description

When plot_prediction() in PyTorch Forecasting is called for a model with multiple targets, the method reuses the same axes for every target. As a result, the targets' plots are drawn on top of each other instead of on separate figures. This pull request fixes that.

Issue

#1314

Faulty code (excerpt from pytorch_forecasting/models/base_model.py)

def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:

    #...

    # for each target, plot
    figs = []
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):
        # ...

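        # create figure
        # BUG: ax is reassigned on the first iteration, so for every later
        # target ax is no longer None and the first figure is reused
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()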
        
        # ...

        figs.append(fig)
    
    return figs

Expected behavior: Each target should be plotted on a separate figure.

Actual behavior: All targets are drawn on the same figure, so their plots overlap.
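The bug can be reproduced without PyTorch Forecasting. The sketch below uses a hypothetical helper, plot_targets_buggy, that keeps only the figure-creation logic of the loop above and plots one plain list per target. It is a simplification, not the library code:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt


def plot_targets_buggy(series_per_target, ax=None):
    """Hypothetical stand-in for the per-target loop before the fix."""
    figs = []
    for series in series_per_target:
        if ax is None:
            fig, ax = plt.subplots()
        else:
            # from the second target on, this branch always runs
            fig = ax.get_figure()
        ax.plot(series)
        figs.append(fig)
    return figs


figs = plot_targets_buggy([[1, 2, 3], [3, 2, 1]])
# Two entries are returned, but both are the same figure,
# and both lines end up on the same axes.
print(len(figs), figs[0] is figs[1], len(figs[0].axes[0].lines))
```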

Solution

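In the snippet above, ax is bound to the axes returned by plt.subplots() on the first iteration. From the second target on, ax is no longer None, so the else branch runs and the same axes (and figure) are reused. The proposed fix records whether the caller passed ax and, if not, creates a new figure for every target: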

    def plot_prediction(
        self,
        x: Dict[str, torch.Tensor],
        out: Dict[str, torch.Tensor],
        idx: int = 0,
        add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
        show_future_observed: bool = True,
        ax=None,
        quantiles_kwargs: Dict[str, Any] = {},
        prediction_kwargs: Dict[str, Any] = {},
    ) -> plt.Figure:

        # ...
        # for each target, plot
        figs = []
        ax_provided = ax is not None  # remember whether the caller passed ax
        for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
            y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
        ):

            # ...
            # create figure: a new one for each target unless the caller supplied ax
            if (ax is None) or (not ax_provided):
                fig, ax = plt.subplots()
            else:
                fig = ax.get_figure()
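The same simplified loop is sketched below with the fix applied. plot_targets_fixed is again a hypothetical helper, not library code. It writes the PR's condition `(ax is None) or (not ax_provided)` as `not ax_provided`. The two forms are equivalent here, because ax can only be None when the caller did not supply it. The sketch checks both cases: without ax, each target gets its own figure; with a caller-supplied ax, all targets are still drawn onto it, as before.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt


def plot_targets_fixed(series_per_target, ax=None):
    """Hypothetical stand-in for the per-target loop after the fix."""
    figs = []
    ax_provided = ax is not None  # remember whether the caller passed ax
    for series in series_per_target:
        if not ax_provided:
            # no caller axes: every target gets a fresh figure
            fig, ax = plt.subplots()
        else:
            # caller axes: keep drawing onto it, as before the fix
            fig = ax.get_figure()
        ax.plot(series)
        figs.append(fig)
    return figs


separate = plot_targets_fixed([[1, 2, 3], [3, 2, 1]])
print(separate[0] is separate[1])  # False: one figure per target

_, shared_ax = plt.subplots()
shared = plot_targets_fixed([[1, 2, 3], [3, 2, 1]], ax=shared_ax)
print(shared[0] is shared[1], len(shared_ax.lines))  # True 2
```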

Bonus

This PR also fixes a documentation error: the encoder's log1p transformation was misspelled as logp1 (#1247).
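For reference, log1p denotes log(1 + x), and its inverse is expm1(x) = exp(x) - 1. The sketch below assumes, as the name suggests, that the encoder's log1p transformation follows this standard definition. It uses Python's math module rather than the encoder itself:

```python
import math

values = [0.0, 1.0, 9.0, 99.0]
transformed = [math.log1p(v) for v in values]   # log(1 + v), accurate for small v
restored = [math.expm1(t) for t in transformed]  # exp(t) - 1 undoes log1p

print(transformed)
print(restored)
```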

terbed, May 27 '23 10:05

Is this repo not maintained?

terbed, Jun 14 '23 08:06

Codecov Report

Patch coverage: 33.33%; project coverage change: -0.08%.

Comparison is base (9995d0a) 90.13% compared to head (120f3e4) 90.05%. Report is 2 commits behind head on master.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1317      +/-   ##
==========================================
- Coverage   90.13%   90.05%   -0.08%     
==========================================
  Files          30       30              
  Lines        4712     4716       +4     
==========================================
  Hits         4247     4247              
- Misses        465      469       +4     
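The percentages in the diff follow from Hits / Lines. The sketch below assumes that Codecov truncates coverage to two decimals rather than rounding it. As far as I know, its default rounding mode is "down", configurable through coverage.round in codecov.yml. Under that assumption, the calculation reproduces the reported 90.13%, 90.05% and -0.08%. Ordinary rounding would instead give 90.06% for the head commit.

```python
def coverage_bp(hits: int, lines: int) -> int:
    """Coverage in basis points (1/100 of a percent), truncated, not rounded."""
    return hits * 10_000 // lines


base = coverage_bp(4247, 4712)  # master (9995d0a)
head = coverage_bp(4247, 4716)  # PR head (120f3e4): 4 more lines, same hits

print(f"base {base / 100:.2f}%  head {head / 100:.2f}%  change {(head - base) / 100:+.2f}%")
```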
Flag      Coverage Δ
cpu       90.05% <33.33%> (-0.08%) ↓
pytest    90.05% <33.33%> (-0.08%) ↓


Files changed                               Coverage Δ
pytorch_forecasting/data/encoders.py        87.25% <ø> (ø)
pytorch_forecasting/models/base_model.py    87.77% <33.33%> (-0.42%) ↓


codecov-commenter, Sep 10 '23 21:09