
pairplot: documentation, more options and tutorial

Open janfb opened this issue 11 months ago • 2 comments

The pairplot function is our main function for visualizing posteriors.

https://github.com/sbi-dev/sbi/blob/c3c2b6e142fb4c57d5599effc1f01f8222a37c57/sbi/analysis/plot.py#L280

It would be great to improve it a bit:

  • adding more documentation
  • sensible defaults that produce a decent figure, even when plotting several distributions, e.g., prior vs. posterior
  • more options, see https://github.com/sbi-dev/sbi/blob/c3c2b6e142fb4c57d5599effc1f01f8222a37c57/sbi/analysis/plot.py#L328-L330 (use bins="auto" as default)
  • a tutorial on how to use all the options 🥇

janfb • Feb 28 '24 10:02
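A minimal sketch of the proposed bins="auto" default. It assumes that overlaid distributions (e.g., prior vs. posterior) should share one set of bin edges so their histograms stay comparable. `shared_bin_edges` is a hypothetical helper, not part of sbi. It relies only on NumPy's `np.histogram_bin_edges`, whose "auto" strategy uses whichever of the Sturges and Freedman-Diaconis estimators gives narrower bins.

```python
import numpy as np


def shared_bin_edges(*sample_sets, bins="auto"):
    """Hypothetical helper: one set of histogram bin edges for several 1D sample arrays.

    Using the same edges for, e.g., prior and posterior samples keeps their
    histograms directly comparable when they are overlaid in one panel.
    """
    combined = np.concatenate([np.asarray(s).ravel() for s in sample_sets])
    return np.histogram_bin_edges(combined, bins=bins)


rng = np.random.default_rng(0)
prior = rng.uniform(-3.0, 3.0, size=5_000)    # broad prior samples (synthetic)
posterior = rng.normal(0.5, 0.2, size=5_000)  # narrow posterior samples (synthetic)

edges = shared_bin_edges(prior, posterior)    # "auto" picks the number of bins
prior_counts, _ = np.histogram(prior, bins=edges, density=True)
post_counts, _ = np.histogram(posterior, bins=edges, density=True)
print(len(edges) - 1, "bins")
```

Matplotlib's `Axes.hist` accepts the same strategy strings for `bins`, so passing `bins=edges` there works as well. One trade-off to keep in mind: when a very narrow posterior shares edges with a broad prior, only a few bins may cover the posterior.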

I plan to have a look at it during the hackathon

famura • Mar 13 '24 09:03

Also happy to help out here during the hackathon!

Matthijspals • Mar 13 '24 09:03

It would be nice to have a joint_plot function for an arbitrary pair of parameters, e.g.:

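import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import gaussian_kde


def plot_joint(x, limits, ax, cmap="hot", label=None, points=None,
               xlabel="", ylabel="", add_corr=True):
    """
    Plot the joint distribution of two parameters from samples.

    Parameters
    ----------
    x : 2d-array, shape (n_samples, 2)
        Samples from the distribution. Column 1 is drawn on the x-axis,
        column 0 on the y-axis.
    limits : list of tuples
        (min, max) plotting limits for parameter 0 and parameter 1.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    cmap : str
        Colormap for the density image.
    label : str, optional
        Currently unused.
    points : sequence of length 2, optional
        Reference point (e.g., ground-truth parameters), marked with a star.
    xlabel, ylabel : str
        Axis labels.
    add_corr : bool
        If True, annotate the panel with the Pearson correlation of the samples.
    """
    # Fit a Gaussian KDE with the columns swapped to (param 1, param 0),
    # so that param 1 maps to the x-axis and param 0 to the y-axis.
    density = gaussian_kde(x[:, [1, 0]].T, bw_method="scott")
    col, row = 1, 0
    X, Y = np.meshgrid(
        np.linspace(limits[col][0], limits[col][1], 50),
        np.linspace(limits[row][0], limits[row][1], 50))
    positions = np.vstack([X.ravel(), Y.ravel()])
    Z = np.reshape(density(positions), X.shape)
    # Normalize so the peak density is 1 (the colorbar runs from 0 to 1).
    Z = Z / np.max(Z)
    im = ax.imshow(Z, cmap=cmap,
                   extent=(limits[col][0], limits[col][1],
                           limits[row][0], limits[row][1]),
                   origin="lower", aspect="auto")
    # Attach a narrow colorbar axis to the right of `ax`.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.04)
    ax.figure.colorbar(im, cax=cax, ticks=[0, 1])

    if points is not None and len(points) > 0:
        ax.scatter([points[1]], [points[0]], s=150, color="#6cf086", marker="*")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Place ticks at the plotting limits of each parameter.
    ax.set_xticks(limits[col])
    ax.set_yticks(limits[row])
    if add_corr:
        corr = np.corrcoef(x[:, 0], x[:, 1])[0, 1]
        ax.text(0.6, 0.88, r"$\rho=$" + f"{corr:.1f}", fontsize=25,
                transform=ax.transAxes, color="white")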

Ziaeemehr • Mar 19 '24 16:03

what exactly should joint_plot do?

michaeldeistler • Mar 19 '24 16:03

It extracts the joint plot implemented in pairplot and gives the user the option to select an arbitrary pair of parameters. I had trouble preparing figures for publications that need arbitrary panel arrangements.

Ziaeemehr • Mar 19 '24 16:03
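Until something like joint_plot exists, a user could build arbitrary panel arrangements directly with Matplotlib. The sketch below uses a hypothetical `joint_panel` helper (not an sbi function) that draws a 2D histogram of parameters `i` and `j` into any existing Axes. It then uses `plt.subplot_mosaic` to lay the panels out freely. The synthetic Gaussian samples, limits, and layout are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def joint_panel(ax, samples, i, j, limits, bins=50, cmap="viridis"):
    """Hypothetical helper: 2D histogram of parameters (i, j) in an existing Axes.

    Parameter j goes on the x-axis and parameter i on the y-axis.
    Returns the 2D array of bin counts.
    """
    counts, _, _, _ = ax.hist2d(
        samples[:, j], samples[:, i], bins=bins,
        range=[limits[j], limits[i]], cmap=cmap)
    ax.set_xlim(limits[j])
    ax.set_ylim(limits[i])
    return counts


rng = np.random.default_rng(1)
cov = np.array([[1.0, 0.8, 0.0],
                [0.8, 1.0, 0.3],
                [0.0, 0.3, 1.0]])
samples = rng.multivariate_normal(np.zeros(3), cov, size=10_000)
limits = [(-4.0, 4.0)] * 3

# One large panel on the left, two smaller panels stacked on the right.
fig, axd = plt.subplot_mosaic([["big", "top"], ["big", "bottom"]], figsize=(8, 4))
c_big = joint_panel(axd["big"], samples, 0, 1, limits)
c_top = joint_panel(axd["top"], samples, 0, 2, limits)
c_bottom = joint_panel(axd["bottom"], samples, 1, 2, limits)
axd["big"].set_xlabel(r"$\theta_1$")
axd["big"].set_ylabel(r"$\theta_0$")
```

Swapping `ax.hist2d` for the KDE used in plot_joint gives smoother panels at a higher cost; the layout logic stays the same. Call `fig.savefig(...)` to export the figure.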

we have pairplot(..., subset=[0, 2, 3]), does this work?

michaeldeistler • Mar 19 '24 16:03

I am not sure; I think it still gives a triangular plot, right? That includes the diagonal (marginals) and the off-diagonal (joint plots). What if the user only needs one panel of this triangle plot (the joint plot of parameters i and j)?

Ziaeemehr • Mar 19 '24 16:03
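To make the point concrete, suppose pairplot(..., subset=...) draws one diagonal panel per selected parameter plus a joint panel for every pair, as described in this thread. Under that assumption, even a two-element subset yields three panels rather than the single joint panel requested. The hypothetical `triangle_panels` function below only enumerates those panels; it does not call sbi.

```python
from itertools import combinations


def triangle_panels(subset):
    """Hypothetical sketch: list the panels of a triangular pair plot for `subset`.

    Diagonal panels (k, k) are the 1D marginals; off-diagonal panels (i, j)
    are the 2D joints for every pair with i listed before j in `subset`.
    """
    diag = [(k, k) for k in subset]
    upper = list(combinations(subset, 2))
    return diag, upper


diag, upper = triangle_panels([0, 2, 3])
print(diag)   # [(0, 0), (2, 2), (3, 3)]
print(upper)  # [(0, 2), (0, 3), (2, 3)]

# Asking for a single pair still produces two marginals plus one joint panel.
print(triangle_panels([1, 4]))  # ([(1, 1), (4, 4)], [(1, 4)])
```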

Ah, I see your point. I am not 100% convinced we need this, though; maybe we just leave it to users to implement themselves if they really need it?

michaeldeistler • Mar 19 '24 16:03

no worries, that's just a suggestion. 👍

Ziaeemehr • Mar 19 '24 16:03