Correctly compute limits in `prepare_for_plot` in case of samples with inf value
In issue #1037, the sampler produced samples containing inf values, and the pair_plot function failed when computing the x-axis max limit in the prepare_for_plot function.
TODO:
Change prepare_for_plot function to handle inf values.
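For illustration, one way the limit computation could skip inf values is sketched below. This is only a sketch, not the actual `prepare_for_plot` code: `finite_limits` is a hypothetical helper name, and it assumes samples arrive as a 2D array of shape `(num_samples, num_dims)`.

```python
import numpy as np


def finite_limits(samples):
    """Per-dimension [min, max] computed over finite entries only.

    Hypothetical helper for illustration; not part of sbi.
    """
    samples = np.asarray(samples, dtype=float)
    # Replace nan, +inf and -inf with nan so that nanmin/nanmax skip them.
    masked = np.where(np.isfinite(samples), samples, np.nan)
    lower = np.nanmin(masked, axis=0)
    upper = np.nanmax(masked, axis=0)
    # Shape (num_dims, 2): one [min, max] pair per dimension.
    return np.stack([lower, upper], axis=1)


samples = np.array([[0.0, 1.0], [2.0, np.inf], [-1.0, 3.0]])
limits = finite_limits(samples)
```

Note that a dimension with no finite entries at all would still yield nan limits here, so that case would need separate handling.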
@Matthijspals are you working on this, or should we un-assign you?
This could be a quick fix - but what is expected behaviour here? Do we actually want to ignore nans everywhere in the plotting function? And what if a datapoint is nan for one dimension, but not the others - should it be included in marginal plots for the non nan dimensions?
It would probably be easier to just exclude the entire data point.
I think it makes sense to issue a warning about NaNs in the prepare_for_plot function and to add a flag `exclude_invalid_data: bool = True` to exclude the values in the downstream plotting functions. Similar to how we do it during inference:
https://github.com/sbi-dev/sbi/blob/0b5f9313f12c9e06b8051e83e5efa58dc7d7f4a7/sbi/utils/sbiutils.py#L360-L372
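A minimal sketch of what the warning plus flag could look like, assuming that a data point is dropped entirely if any of its dimensions is nan or inf. The function name `drop_invalid_samples` is hypothetical, and only the `exclude_invalid_data` flag comes from the proposal above.

```python
import warnings

import numpy as np


def drop_invalid_samples(samples, exclude_invalid_data=True):
    """Warn about nan/inf samples and optionally drop them.

    Hypothetical helper for illustration; not the sbi implementation.
    """
    samples = np.asarray(samples, dtype=float)
    # A row is valid only if every dimension is finite.
    valid = np.isfinite(samples).all(axis=1)
    num_invalid = int((~valid).sum())
    if num_invalid > 0:
        warnings.warn(
            f"Found {num_invalid} samples with nan or inf values.",
            stacklevel=2,
        )
    if exclude_invalid_data:
        return samples[valid]
    return samples
```

With `exclude_invalid_data=False` the warning is still issued, but the samples are returned unchanged, so the caller decides how to handle them.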
@Matthijspals has this been fixed in the pairplot PR back then, or not?
If not, would you be up for providing the fix?
Was not yet fixed - up for doing it later this week!
Actually, another improvement would be to cast inputs to numpy explicitly. The plotting does currently fail if you input a JAX array, which is a bit annoying.
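Roughly what I have in mind for the explicit cast, as a sketch only. `to_numpy` is a hypothetical helper name. It assumes that array-likes such as JAX arrays can be converted with `np.asarray`, and that torch tensors need `detach().cpu()` first.

```python
import numpy as np


def to_numpy(x):
    """Convert an array-like (list, NumPy, torch, JAX) to a NumPy array.

    Hypothetical helper for illustration; not part of sbi.
    """
    # torch tensors expose .detach(); drop the autograd graph and move
    # to CPU before converting. Other array types skip this branch.
    if hasattr(x, "detach"):
        x = x.detach().cpu()
    return np.asarray(x)


samples = to_numpy([[0.0, 1.0], [2.0, 3.0]])
```

Doing this once at the top of the plotting functions would mean that the rest of the code only ever sees NumPy arrays.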
This should be closed with #1185