pairplot refactoring
This PR is my (delayed) contribution to the 2025 hackathon, where I tried to resolve issues with the user interface of the widely used pairplot function, as in https://github.com/sbi-dev/sbi/issues/1425
The PR addresses the following main issues:
- The default parameters for kwargs were hardcoded in functions such as `_get_default_diag_kwargs`, which returned dictionaries. I replaced those functions with dataclasses that contain the default values. These dataclasses can easily be converted to dictionaries by calling `dict(FigKwargs)`, so few changes are required internally and users can still pass kwargs as dictionaries in the standard way. The dataclasses are considered more Pythonic, they expose their internals more clearly to the user, and in the future they could be passed to `pairplot` instead of dictionaries to specify keyword arguments. I currently don't know how best to make them available to the user, since they need to be explicitly imported right now, but that could be part of a future PR.
- The `samples` passed by the user were converted internally to a list of numpy arrays. Instead I now call `np.ndarray(samples)`, which creates a copy if necessary but changes nothing if samples are already a numpy array. IMO passing samples as an ndarray should be strongly encouraged, and the input should be either an `np.ndarray` or a `torch.Tensor`. For now, lists are also supported for user flexibility.
- `diag_kwargs`, `upper_kwargs`, `lower_kwargs`, `diag`, `upper` and `lower` all accept lists, most likely with the intention that the user could choose a different plot type and different parameters for each plot. However, this was not actually working in the main branch; only the first entry was used. I added user warnings that are raised when a list is passed for any of these arguments. We should consider whether this feature is actually desired. If not, the code could be massively simplified.
- The way kwargs are passed has caused confusion, because they are passed as a nested dictionary `{'mpl_kwargs': {}}`, where only the entries in `mpl_kwargs` are actually passed to matplotlib. So with `{'bins':10, 'mpl_kwargs': {}}`, the `'bins'` entry was silently ignored. Instead, `{'mpl_kwargs': {'bins':10}}` would be required. If any entry in any `kwargs` is known to be ignored downstream, the user now receives a warning about this issue. This is achieved by comparing the user-provided dict with the parameters defined in the default dataclasses.
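To illustrate the dataclass defaults and the ignored-kwargs warning described in the list above, here is a minimal standard-library sketch. It is not the code from this PR: `DiagKwargs`, its `bw_method` field, and `resolve_kwargs` are hypothetical names chosen for illustration, and the sketch uses `dataclasses.asdict` for the dictionary conversion.

```python
import warnings
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Optional


@dataclass
class DiagKwargs:
    """Hypothetical defaults for diagonal plots (field names are illustrative)."""

    bw_method: str = "scott"
    # Only the entries nested under mpl_kwargs are meant to reach matplotlib.
    mpl_kwargs: Dict[str, Any] = field(default_factory=dict)


def resolve_kwargs(user_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge a user dict into the dataclass defaults, warning on unknown keys."""
    user_kwargs = dict(user_kwargs or {})
    known = {f.name for f in fields(DiagKwargs)}
    # Keys that are not dataclass fields would be dropped downstream.
    ignored = sorted(set(user_kwargs) - known)
    if ignored:
        warnings.warn(
            f"Ignoring unknown keys {ignored}; matplotlib options belong "
            "inside 'mpl_kwargs'.",
            UserWarning,
            stacklevel=2,
        )
    accepted = {k: v for k, v in user_kwargs.items() if k in known}
    # Convert back to a plain dict so dict-based internals keep working.
    return asdict(DiagKwargs(**accepted))
```

With this sketch, `resolve_kwargs({'bins': 10, 'mpl_kwargs': {}})` emits a `UserWarning` naming `bins`, while `resolve_kwargs({'mpl_kwargs': {'bins': 10}})` passes through without a warning.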
There are still many issues with `pairplot.py` IMO, and I am open to describing them in separate issues and continuing to work on those.
Codecov Report
:x: Patch coverage is 76.69173% with 62 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 78.80%. Comparing base (1757616) to head (035086a).
:warning: Report is 58 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| sbi/analysis/plot.py | 76.69% | 62 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #1529 +/- ##
==========================================
- Coverage 86.01% 78.80% -7.22%
==========================================
Files 135 135
Lines 10751 10929 +178
==========================================
- Hits 9248 8613 -635
- Misses 1503 2316 +813
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 78.80% <76.69%> (-7.22%) | :arrow_down: |
| Files with missing lines | Coverage Δ | |
|---|---|---|
| sbi/analysis/plot.py | 69.37% <76.69%> (+0.79%) | :arrow_up: |
Thank you for the detailed review of the PR, and sorry for the messy commented-out code. I've gone through all the comments and made the suggested changes. Many of the `# type: ignore` flags are now unnecessary, as you suspected. Most importantly, I also fixed the issues that broke `plot_test.py`.
I worked through https://github.com/sbi-dev/sbi/blob/main/docs/advanced_tutorials/17_plotting_functionality.ipynb and found that some of the functionality that works in the main branch is broken in my PR. I also might have misunderstood the point of passing different plot types in `upper`. I thought the point was to have different plots in different places of the upper diagonal (that's the thing that doesn't work), but I now see that the actual purpose is to have two types of plots overlaid in all places. That feature was never clear to me from any of the other notebooks. It might be a good idea for me to attend an office hour to clarify the intended functionality. Either way, it will take me some time to understand what's going on. I will get back to you.
Hi @danielmk! Is there an update from your end on the status of this PR, or any input from our end that could help at this stage?
Sorry for the delay on this. During the hackathon I underestimated the complexity of the pairplot function. Salvaging this pull request by fixing the features I broke would be too much work. Instead, I suggest making smaller pull requests that target acute issues one by one. The most important would probably be to warn users if their kwargs were ignored. I can open a new issue for this and start working on the pull request.
I want to apologize for not getting to this @gmoss13. After the hackathon, I did not have the bandwidth. I see there is now a much better PR for pairplot refactoring here: https://github.com/sbi-dev/sbi/pull/1631
Since I won't be able to get to this, it might make sense to close this PR and continue there. Sorry again.
Thanks for the update and no worries at all @danielmk!
Yes, as part of the Google Summer of Code, @abelaba was able to start working on this. Initially, we aimed to include the approach you took in this PR as well, but in the end, starting from scratch seemed easier. We will try to consolidate both approaches and might cherry-pick parts of your contribution and the discussions here. We will let you know how it goes.
Thank you for the time you put into this!! 🙏