sbi

Refactor pairplot function

Open gmoss13 opened this issue 9 months ago • 11 comments

100% of survey respondents stated that they use the pairplot function; however, right now the API is messy and opaque. It would be great to keep the plot aesthetic the same but simplify the interface.

gmoss13 • Mar 12 '25 18:03

I could work on this during the hackathon.

danielmk • Mar 14 '25 08:03

I've looked at the pairplot function in detail. I went through all the notebooks and tried to break the interface. The interface is actually pretty good and flexible. The only interface issue I have found so far is that prepare_for_plot accepts an np.ndarray for limits, but passing an np.ndarray currently causes an error, which is relatively easy to fix. That's probably why pairplot itself does not accept an np.ndarray for limits. Once I fix the error, pairplot could accept an ndarray for limits as well.
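A minimal sketch of limits handling that treats a nested list and an np.ndarray identically, assuming limits are either None (inferred from the samples), a single [min, max] pair reused for every dimension, or one [min, max] pair per dimension. normalize_limits is a hypothetical name; it is not sbi's prepare_for_plot and does not reproduce the actual error mentioned above.

```python
from typing import Optional, Sequence, Union

import numpy as np

LimitsLike = Union[Sequence[Sequence[float]], np.ndarray]


def normalize_limits(limits: Optional[LimitsLike], samples: np.ndarray) -> np.ndarray:
    """Return limits as a float array of shape (dim, 2).

    Hypothetical helper: accepts None, a single [min, max] pair, or one
    [min, max] pair per dimension, as a nested list or an ndarray.
    """
    dim = samples.shape[1]
    if limits is None:
        # Infer limits from the data: per-dimension min and max.
        return np.stack([samples.min(axis=0), samples.max(axis=0)], axis=1)

    # np.asarray treats nested lists and arrays the same way, so nothing
    # below needs to branch on the input type.
    arr = np.asarray(limits, dtype=float)
    if arr.shape in ((2,), (1, 2)):
        # A single [min, max] pair: reuse it for every dimension.
        arr = np.tile(arr.reshape(1, 2), (dim, 1))
    if arr.shape != (dim, 2):
        raise ValueError(f"Expected limits of shape ({dim}, 2), got {arr.shape}.")
    return arr
```

Converting once at the boundary means the rest of the plotting code only ever sees a (dim, 2) array.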

While the API is good, the internals are a bit over-engineered. A lot of code is dedicated to converting the samples to a list of ndarrays. I don't yet understand why there is this emphasis on a list instead of an np.ndarray. I think using an ndarray internally would be better, and I would even go so far as to not accept lists of tensors or arrays as input for samples, mainly because it is easier to confuse rows and columns in a list.

If samples: Union[np.ndarray, torch.Tensor], we could get rid of a lot of code by calling np.asarray(samples) at the top of the function. If samples are always an np.ndarray, much of the code that covers the list-of-arrays case becomes obsolete.
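Following that proposal, a sketch of a single conversion step at the top of the function, assuming samples is a 2D ndarray or torch.Tensor of shape (num_samples, dim). as_sample_array is a hypothetical name; torch tensors are detected by duck typing (a .detach() method) so the sketch runs without torch installed.

```python
import numpy as np


def as_sample_array(samples) -> np.ndarray:
    """Convert samples to a float ndarray of shape (num_samples, dim).

    Hypothetical helper; not part of sbi.
    """
    if hasattr(samples, "detach"):
        # torch.Tensor: drop the autograd graph and move to CPU first,
        # because np.asarray fails on tensors that require grad or live on GPU.
        samples = samples.detach().cpu().numpy()
    arr = np.asarray(samples, dtype=float)
    if arr.ndim != 2:
        raise ValueError(f"samples must have shape (num_samples, dim), got {arr.shape}.")
    return arr
```

With this in place, every later step can index samples[:, i] for dimension i without worrying about rows versus columns.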

I'll be working on a PR along those lines, but let me know if there is a good reason to use a list of arrays that I am missing.

danielmk • Mar 18 '25 05:03

Thanks @danielmk! I agree that it is cumbersome to convert the samples to a list of ndarrays - this is not necessary. A couple of additional changes that might be nice for a pairplot PR:

  1. Refactor from "stringly"-typed to "strongly"-typed (this came up in discussion after the talk from @janosg today). A lot of the options passed to pairplot are passed through dictionaries; replacing these with dataclasses would be more convenient for the user, e.g. via autocomplete.
  2. The limits argument should be able to take (xmin, xmax), where xmin is a tensor/array of the minimum values for all dimensions and xmax likewise holds the maximum values. This is the PyTorch convention for defining min and max, as opposed to a list of [min, max] pairs, one per dimension.

Any changes to the API should be backwards compatible, i.e. if we pass lists as in the current API, we should raise a warning but still allow it. Code that is already deprecated I would keep for now, but add a note that it will no longer be supported from v0.26.0 (which would be 2 releases from now).
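One way to support both limits forms while staying backwards compatible, sketched under an assumed convention: a tuple (xmin, xmax) is the new form, and anything else is treated as the current list of [min, max] pairs, which still works but emits a FutureWarning. limits_to_pairs and the warning text are hypothetical, not sbi's API.

```python
import warnings

import numpy as np


def limits_to_pairs(limits) -> np.ndarray:
    """Return limits as a float array of shape (dim, 2).

    Hypothetical convention: a tuple (xmin, xmax) of two 1D arrays is the
    new form; any other input is the legacy list of [min, max] pairs per
    dimension, still accepted but with a FutureWarning.
    """
    if isinstance(limits, tuple) and len(limits) == 2:
        xmin = np.asarray(limits[0], dtype=float)
        xmax = np.asarray(limits[1], dtype=float)
        if xmin.ndim != 1 or xmin.shape != xmax.shape:
            raise ValueError("xmin and xmax must be 1D and of equal length.")
        if np.any(xmin > xmax):
            raise ValueError("Every entry of xmin must be <= the matching xmax.")
        # Column 0 holds the minima, column 1 the maxima.
        return np.stack([xmin, xmax], axis=1)

    warnings.warn(
        "Passing limits as a list of [min, max] pairs per dimension is "
        "deprecated; pass a tuple (xmin, xmax) instead.",
        FutureWarning,
        stacklevel=2,
    )
    return np.asarray(limits, dtype=float)
```

Telling the forms apart by tuple versus list is just one possible choice; it avoids guessing the layout of a 2x2 array when dim == 2.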

gmoss13 • Mar 18 '25 09:03

Thanks @danielmk !

Two comments from my side:

  • I noticed in my own usage that since the pairplot function was updated during the last hackathon, it's quite difficult to use the old kwargs: passing them sometimes has no effect at all, and I end up explicitly using pairplot_dep instead. Maybe there is a way to have a more explicit interface to the old kwargs, or to find a unified interface for old and new kwargs?
  • @anastasiakrouglova found a nice lightweight way to display keyword-argument options to users. Can you share it here, please, @anastasiakrouglova? Thanks 🙏

janfb • Mar 18 '25 10:03

+1 on having a good way to display kwargs to the user!

I think in principle using the kwargs from pairplot_dep on pairplot will work the same as calling pairplot_dep directly. However, a pretty frustrating issue for me is the use of several kwarg groups, namely fig_kwargs, diag_kwargs, upper_kwargs, and lower_kwargs. For example, if I want to change the number of bins used to plot the diagonal, I need to specify:

diag_kwargs={"mpl_kwargs": {"bins": 10}}

If I try something that seems reasonable, like fig_kwargs={"bins": 50}, I get no warning or error, but nothing changes (as a side note, the same setting for upper/lower_kwargs has to be specified separately). Since the kwarg options are not covered in the docs, I have to dive into the codebase to figure out why this is not working.

As another common example, to change the colors of the plotted points, we need to specify fig_kwargs={"points_colors": "red"}. Any plausible typo such as point_colors or point_color again raises no error or warning but changes nothing. This is not new to the recent pairplot function; it was also true for pairplot_dep. However, with several kwarg groups there is more room for error.

If there is a lightweight way to display the kwarg options to the user, I think that would be a great solution. Another low-cost partial solution would be a pairplot_help function that the user can call to see which kwargs they can change and where. In addition, when the user specifies a kwarg that is not in the default_kwargs (and will thus probably be silently ignored), a warning should be raised.
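A sketch of both ideas, assuming a plain-dict defaults function: pairplot_help prints the accepted keys, and merging user kwargs warns about unknown keys, using the standard library's difflib to suggest the closest valid key. default_fig_kwargs here is a hypothetical stand-in with only three illustrative keys and placeholder values, not sbi's real defaults.

```python
import difflib
import warnings
from typing import Any, Dict, Optional


def default_fig_kwargs() -> Dict[str, Any]:
    """Hypothetical stand-in for the real defaults (placeholder values)."""
    return {"points_colors": ["C1"], "samples_colors": ["C0"], "title": None}


def pairplot_help() -> None:
    """Print every accepted fig_kwargs key with its default (hypothetical)."""
    for key, value in default_fig_kwargs().items():
        print(f"fig_kwargs[{key!r}] = {value!r}")


def merge_fig_kwargs(user: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge user kwargs into the defaults, warning about unknown keys."""
    merged = default_fig_kwargs()
    for key, value in (user or {}).items():
        if key not in merged:
            # Suggest the closest valid key, e.g. "point_color" -> "points_colors".
            close = difflib.get_close_matches(key, list(merged), n=1)
            hint = f" Did you mean {close[0]!r}?" if close else ""
            warnings.warn(
                f"Unknown fig_kwargs key {key!r} will be ignored.{hint}",
                UserWarning,
                stacklevel=2,
            )
            continue
        merged[key] = value
    return merged
```

The same check could be applied to each kwarg group, so a misplaced bins entry would at least produce a warning instead of silently doing nothing.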

Happy for comments/thoughts on this.

gmoss13 • Mar 18 '25 20:03

Yes, I can see how having multiple kwarg-style arguments introduces problems. It's hard to say how to improve on this while maintaining backwards compatibility. One direction I am investigating: I could convert the functions that return dictionaries of default kwargs into dataclasses to alleviate the stringly-typed issue, and define the __getitem__ magic method so that they can still behave like dictionaries where backwards compatibility requires it. I'm not sure matplotlib will like that; I'm looking into it.

danielmk • Mar 19 '25 03:03

Here is the strategy: I define a FigKwargs dataclass with exactly the same default values as def _get_default_fig_kwargs() -> Dict. That makes FigKwargs().__dict__ functionally identical to _get_default_fig_kwargs(). The only reason they do not compare equal with == is that each creates its own mpl.ticker.FormatStrFormatter("%g") at initialization/call time, and the two objects live at different memory addresses, but they do the same thing. So we can use FigKwargs().__dict__ internally, and users can still pass a dictionary for fig_kwargs like so:

_ = pairplot(samples,
             limits=[[-2, 2], [-2, 2], [-2, 2]],
             labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
             fig_kwargs={'title': "Custom Title", 'title_format': {"fontsize": 30}}
            )

But the inclined user that wants to avoid stringly typing and wants to take advantage of type hints and autocomplete features can use the FigKwargs dataclass like so:

_ = pairplot(samples,
             limits=[[-2, 2], [-2, 2], [-2, 2]],
             labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
             fig_kwargs=FigKwargs(title="Custom Title", title_format={"fontsize": 30})
            )
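Internally, pairplot then needs one step that accepts either form. A minimal sketch, assuming fig_kwargs may be None, a plain dict, or a dataclass instance; FigKwargsSketch is a cut-down hypothetical stand-in for FigKwargs, and vars() returns the same mapping as the __dict__ attribute used above.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union


@dataclass
class FigKwargsSketch:
    """Cut-down, hypothetical stand-in for the FigKwargs dataclass."""

    title: Optional[str] = None
    title_format: Dict[str, Any] = field(default_factory=lambda: {"fontsize": 16})


def resolve_fig_kwargs(
    fig_kwargs: Union[None, Dict[str, Any], FigKwargsSketch],
) -> Dict[str, Any]:
    """Turn user input into a plain dict with all defaults filled in."""
    if fig_kwargs is None:
        return vars(FigKwargsSketch()).copy()
    if dataclasses.is_dataclass(fig_kwargs):
        # Copy so later edits inside pairplot never mutate the user's object.
        return vars(fig_kwargs).copy()
    # Plain dict (backwards-compatible path): defaults first, then overrides.
    merged = vars(FigKwargsSketch()).copy()
    merged.update(fig_kwargs)
    return merged
```

Both call styles shown above then end up as the same dict before any plotting happens.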

One advantage is that FigKwargs(title="Custom Title", title_format={"fontsize": 30}, bins=10) raises an error (a TypeError for the unexpected keyword argument). But a user who insists on passing bins to matplotlib for whatever reason can still do:

fig_kwargs = FigKwargs(title="Custom Title", title_format={"fontsize": 30})
fig_kwargs["bins"] = 10

I have this running and will next figure out how to do this for the other kwargs. This is how the dataclass looks, btw. It's not as Pythonic as I had hoped, because dataclasses don't allow mutable defaults, so every list and dict default needs field(default_factory=...). But it looks much better than _get_default_fig_kwargs() imo:

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker  # makes mpl.ticker available as an attribute


@dataclass
class FigKwargs:
    legend: Optional[List[str]] = None
    legend_kwargs: Dict = field(default_factory=dict)
    points_labels: list = field(default_factory=lambda: [f"points_{idx}" for idx in range(10)])
    samples_labels: list = field(default_factory=lambda: [f"samples_{idx}" for idx in range(10)])
    samples_colors: list = field(default_factory=lambda: plt.rcParams["axes.prop_cycle"].by_key()["color"][0::2])
    points_colors: list = field(default_factory=lambda: plt.rcParams["axes.prop_cycle"].by_key()["color"][1::2])
    # default_factory gives every instance its own formatter instead of one shared object
    tickformatter: mpl.ticker.FormatStrFormatter = field(default_factory=lambda: mpl.ticker.FormatStrFormatter("%g"))
    tick_labels: Optional[Dict] = None
    points_diag: dict = field(default_factory=dict)
    points_offdiag: dict = field(default_factory=lambda: {"marker": ".", "markersize": 10})
    fig_bg_colors: dict = field(default_factory=lambda: {"offdiag": None, "diag": None, "lower": None})
    fig_subplots_adjust: dict = field(default_factory=lambda: {"top": 0.9})
    subplots: dict = field(default_factory=dict)
    despine: dict = field(default_factory=lambda: {"offset": 5})
    title: Optional[str] = None
    title_format: dict = field(default_factory=lambda: {"fontsize": 16})
    x_lim_add_eps: float = 1e-5
    square_subplots: bool = True

    def __getitem__(self, key):
        # Dict-style read access; getattr avoids running eval on a user-supplied string.
        try:
            return getattr(self, key)
        except AttributeError as err:
            raise KeyError(key) from err

    def __setitem__(self, key, item):
        # Dict-style write access; also accepts keys that are not declared fields.
        setattr(self, key, item)
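The mutable-default restriction mentioned above is easy to check with the standard library alone: a plain dict default is rejected with a ValueError when the class is created, while default_factory builds a fresh object for every instance. The class names here are illustrative only.

```python
from dataclasses import dataclass, field
from typing import Dict


def mutable_default_rejected() -> bool:
    """Return True if a dict default raises ValueError at class creation."""
    try:
        @dataclass
        class Bad:
            despine: Dict = {"offset": 5}  # mutable default: not allowed
    except ValueError:
        return True
    return False


@dataclass
class Good:
    # default_factory is called once per instance, so nothing is shared.
    despine: Dict = field(default_factory=lambda: {"offset": 5})


a, b = Good(), Good()
a.despine["offset"] = 0  # changes only a's dict
```

Here b.despine still holds {"offset": 5}, which is exactly the per-instance behavior the kwargs defaults need.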

danielmk • Mar 19 '25 07:03

this is great! 🚀

janfb • Mar 19 '25 10:03

For the diagonal plots, if diag=="kde", diag_kwargs={'bins': 50} changes the bins as intended. But if diag=="hist", diag_kwargs={'bins': 50} silently has no effect on the bins; instead, diag_kwargs={'mpl_kwargs': {'bins': 50}} is required. That is confusing and a likely source of API complaints. I am trying to understand whether 'mpl_kwargs' is a common matplotlib pattern we should respect, and to what extent we should work around it to make the API more intuitive.

EDIT: I think the main issue is that diag_kwargs, upper_kwargs and lower_kwargs are not actually passed to matplotlib. Only the entries of the {'mpl_kwargs': {}} dict are passed to matplotlib; entries outside of 'mpl_kwargs' are not. I am tempted to split them up, but I am afraid that adding even more parameters to pairplot would make the interface more confusing rather than easier. So I might just add something that warns the user if there are kwarg entries that are ignored.

EDIT2: I think it was a good idea to separate the kwargs for diag, upper and lower, but mixing matplotlib kwargs with non-matplotlib kwargs introduces some confusion. As far as I can tell, the kwargs system was introduced here: https://github.com/sbi-dev/sbi/pull/1084

danielmk • Mar 20 '25 01:03

Another issue: diag, upper and lower all accept lists as input, which suggests that they support a distinct plot type for each parameter. However, this feature currently does not work: all plots use the type of the first list entry. That is, if diag=["kde", "kde", "hist"], all plots are of type "kde", and if diag=["hist", "kde", "kde"], all plots are of type "hist". For now, I will raise a UserWarning if a list is passed.
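A sketch of that interim warning, assuming the current behavior described above (only the first list entry is used). resolve_plot_type is a hypothetical helper name, not sbi code.

```python
import warnings
from typing import List, Union


def resolve_plot_type(plot_type: Union[str, List[str]], name: str = "diag") -> str:
    """Return a single plot type, warning if a list was passed."""
    if isinstance(plot_type, str):
        return plot_type
    if len(plot_type) == 0:
        raise ValueError(f"{name} must not be an empty list.")
    # Matches the current behavior: the first entry is used for every subplot.
    warnings.warn(
        f"Passing a list for {name} is not supported yet: all subplots will "
        f"use {plot_type[0]!r}, the first entry of {plot_type!r}.",
        UserWarning,
        stacklevel=2,
    )
    return plot_type[0]
```

If per-dimension plot types are supported later, only this function needs to change.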

danielmk • Mar 20 '25 02:03

Thanks for the update @danielmk!

I think the main issue is that diag_kwargs, upper_kwargs and lower_kwargs are not actually passed to matplotlib. Only the {'mpl_kwargs': {}} dict entries are passed to matplotlib.

You're right. Specifically for setting bins, the complication is that pyplot.hist() takes bins as a kwarg, whereas for the KDE plot we use bins to define the grid on which the KDE is evaluated, so there it is not a matplotlib kwarg. In effect, we have to collapse the matplotlib kwargs and our own kwargs into one dictionary to pass to these functions, which leads to confusing groupings. My suggested solution is to make "our" arguments for the individual plot functions (e.g. plot_2d_hist, plot_1d_kde, etc.) explicit, and to pass only the matplotlib kwargs on to these functions. This way we can pass one set of mpl_kwargs to pairplot (or a list, if we want them to differ per subplot, although I think that is a very rare use case). The subplot plotting functions can then add to these kwargs before calling a matplotlib function if needed (e.g. to add bins). Let me know if this makes sense; I know it's a bit convoluted!
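A minimal sketch of that split, assuming matplotlib and NumPy: each subplot function takes sbi's own options (here only bins) as an explicit argument and forwards a plain mpl_kwargs dict to matplotlib unchanged. The names follow plot_1d_kde from the comment above, but these signatures are hypothetical, and the KDE is a plain-NumPy Gaussian KDE with Scott's bandwidth rule evaluated over the data range, not sbi's implementation.

```python
from typing import Any, Dict, Optional

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def plot_1d_hist(ax, x: np.ndarray, bins: int = 50,
                 mpl_kwargs: Optional[Dict[str, Any]] = None):
    """Histogram: bins is our argument and is added to the matplotlib kwargs here."""
    kwargs = dict(mpl_kwargs or {})
    kwargs["bins"] = bins  # the explicit argument wins over any nested value
    return ax.hist(x, **kwargs)


def plot_1d_kde(ax, x: np.ndarray, bins: int = 50,
                mpl_kwargs: Optional[Dict[str, Any]] = None):
    """KDE: bins only sets the evaluation grid and never reaches matplotlib."""
    grid = np.linspace(x.min(), x.max(), bins)
    # Gaussian KDE with Scott's rule bandwidth: std * n^(-1/5) in 1D.
    bw = x.std(ddof=1) * len(x) ** (-1 / 5)
    z = (grid[:, None] - x[None, :]) / bw
    density = np.exp(-0.5 * z**2).sum(axis=1) / (len(x) * bw * np.sqrt(2 * np.pi))
    return ax.plot(grid, density, **(mpl_kwargs or {}))
```

Under this design, diag_kwargs={'bins': 50} would mean the same thing for both plot types, and only genuine matplotlib options would need to go into mpl_kwargs.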

I will raise a UserWarning about this for now, if a list is passed.

Good catch! I think a UserWarning would be good for now, and later we can think whether we actually want to support different plot types for the different dims.

gmoss13 • Mar 20 '25 08:03