Passing both `reference_values` and `labeller` to `plot_pair` is not consistent
Describe the bug
When reference_values is passed to plot_pair, its keys are compared against the labelled variable names, i.e. the names after any MapLabeller passed via labeller has been applied. Because plot_pair also flattens var_names and reformats them via labeller.make_label_vert, the user has no way of knowing that the keys of reference_values must match the formatted, flattened names produced inside plot_pair.
To Reproduce
import arviz as az
import numpy as np
from arviz.labels import MapLabeller
datadict = {
    "x": np.random.randn(1, 100, 2),
    "y": np.random.randn(1, 100, 2),
}
trace = az.convert_to_inference_data(datadict)
reference_values = {
    "x": np.array([-0.5, 0.5]),
    "y": np.array([-0.5, 0.5]),
}
var_name_map = {
    "x": "x (cm)",
    "y": "y (cm)",
}
labeller = MapLabeller(var_name_map)
az.plot_pair(trace, reference_values=reference_values, labeller=labeller)
The resulting plot does not include the reference value markers and further raises the warning
UserWarning: Argument reference_values does not include reference value for: x (cm) 1, y (cm) 1, x (cm) 0, y (cm) 0
Expected behavior
The plot should contain the reference values without the user needing to know how plot_pair alters the label formatting (i.e., changing "x" to "x (cm)\n0" and "x (cm)\n1").
Additional context
arviz 0.22.0
PR incoming...
Furthermore, plot_pair builds reference_values_copy with a hard-coded search-and-replace that swaps " " for "\n" in each key, which causes further problems. For example, the snippet above can be modified into something that works:
import arviz as az
import numpy as np
from arviz.labels import MapLabeller
datadict = {
    "x": np.random.randn(1, 100, 2),
    "y": np.random.randn(1, 100, 2),
}
trace = az.convert_to_inference_data(datadict)
reference_values = {
    "x 0": -0.5,
    "x 1": 0.5,
    "y 0": -0.5,
    "y 1": 0.5,
}
az.plot_pair(trace, reference_values=reference_values)
This works as intended, but as soon as we add a labeller:
import arviz as az
import numpy as np
from arviz.labels import MapLabeller
datadict = {
    "x": np.random.randn(1, 100, 2),
    "y": np.random.randn(1, 100, 2),
}
trace = az.convert_to_inference_data(datadict)
reference_values = {
    "x 0": -0.5,
    "x 1": 0.5,
    "y 0": -0.5,
    "y 1": 0.5,
}
var_name_map = {
    "x": "x (cm)",
    "y": "y (cm)",
}
labeller = MapLabeller(var_name_map)
az.plot_pair(trace, reference_values=reference_values, labeller=labeller)
This fails to produce any reference value markers, and raises the warning
UserWarning: Argument reference_values does not include reference value for: x (cm) 1, y (cm) 0, x (cm) 0, y (cm) 1
Based on the warning message, the user might then be tempted to update the reference_values dictionary like so:
import arviz as az
import numpy as np
from arviz.labels import MapLabeller
datadict = {
    "x": np.random.randn(1, 100, 2),
    "y": np.random.randn(1, 100, 2),
}
trace = az.convert_to_inference_data(datadict)
reference_values = {
    "x (cm) 0": -0.5,
    "x (cm) 1": 0.5,
    "y (cm) 0": -0.5,
    "y (cm) 1": 0.5,
}
var_name_map = {
    "x": "x (cm)",
    "y": "y (cm)",
}
labeller = MapLabeller(var_name_map)
az.plot_pair(trace, reference_values=reference_values, labeller=labeller)
But this again raises the same warning message:
UserWarning: Argument reference_values does not include reference value for: x (cm) 1, y (cm) 0, x (cm) 0, y (cm) 1
The warning appears because plot_pair applies the hard-coded replacement of " " with "\n" to the keys of reference_values and then checks the result against the flattened variable names, which produces values like these:
flat_var_names
['x (cm)\n0', 'x (cm)\n1', 'y (cm)\n0', 'y (cm)\n1']
reference_values_copy
{'x\n(cm) 0': -0.5, 'x\n(cm) 1': 0.5, 'y\n(cm) 0': -0.5, 'y\n(cm) 1': 0.5}
The search-and-replace has mangled the keys by turning the first space, which is part of the label itself, into a newline. This unpredictability is why the functionality is broken.
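The mismatch can be reproduced without ArviZ. The following is a minimal sketch, assuming the replacement swaps only the first space of each key for a newline (an assumption that is consistent with the reference_values_copy shown above). The variable names mirror those in the output, but the code is illustrative and not a copy of the library source.

```python
# Keys as a user would write them after reading the warning text.
reference_values = {
    "x (cm) 0": -0.5,
    "x (cm) 1": 0.5,
    "y (cm) 0": -0.5,
    "y (cm) 1": 0.5,
}

# Labels as reported above for the flattened variables.
flat_var_names = ["x (cm)\n0", "x (cm)\n1", "y (cm)\n0", "y (cm)\n1"]

# Assumption: only the first space of each key is swapped for a newline.
reference_values_copy = {
    key.replace(" ", "\n", 1): value for key, value in reference_values.items()
}

# Labels with no matching key; these are the ones the warning would list.
missing = sorted(set(flat_var_names) - set(reference_values_copy))

print(list(reference_values_copy))
print(missing)
```

Every label ends up in missing, because the newline lands between "x" and "(cm)" instead of between the label and the index.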
The associated pull request makes reference_values and labeller work consistently, and in the way I originally expected. With this PR, all of the following interactions between reference_values, labeller, and combine_dims work; none of them work in the current version.
import arviz as az
import numpy as np
from arviz.labels import MapLabeller
print(az.__version__)
datadict = {
    "a": np.random.randn(1, 100),
    "b": np.random.randn(1, 100),
    "c": np.random.randn(1, 100, 2),
    "d": np.random.randn(1, 100, 2),
    "x": np.random.randn(1, 100, 2, 1),
    "y": np.random.randn(1, 100, 2, 1),
}
trace = az.convert_to_inference_data(datadict)
reference_values = {
    "a": 0.0,
    "b": 0.0,
    "c": [-0.5, 0.5],
    "d": [-0.5, 0.5],
    "x": np.array([[-0.5], [0.5]]),
    "y": np.array([[-0.5], [0.5]]),
}
var_name_map = {
    "a": r"$\alpha$ ($\mu$m)",
    "b": r"$\beta$ ($\mu$m)",
    "x": "x (cm)",
    "y": "y (cm)",
}
labeller = MapLabeller(var_name_map)
az.plot_pair(
    trace,
    var_names=["a", "b"],
    reference_values=reference_values,
    labeller=labeller,
    show=True,
)
az.plot_pair(
    trace,
    var_names=["c", "d"],
    reference_values=reference_values,
    labeller=labeller,
    show=True,
)
az.plot_pair(
    trace,
    var_names=["x", "y"],
    reference_values=reference_values,
    labeller=labeller,
    combine_dims={"x_dim_0", "y_dim_0"},
    show=True,
)
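For illustration only, here is a minimal sketch of one way to make the lookup independent of label formatting: key the reference values by the original variable name plus a positional index rather than by the rendered label string. flatten_reference_values is a hypothetical helper, not part of ArviZ, and not necessarily what the pull request implements.

```python
def flatten_reference_values(reference_values):
    """Map (var_name, index_tuple) -> float for scalar or nested values.

    Hypothetical helper for illustration; not part of ArviZ.
    """
    flat = {}

    def visit(name, value, index):
        # NumPy arrays and scalars expose tolist(); plain Python values pass through.
        if hasattr(value, "tolist"):
            value = value.tolist()
        if isinstance(value, (list, tuple)):
            for i, item in enumerate(value):
                visit(name, item, index + (i,))
        else:
            flat[(name, index)] = float(value)

    for name, value in reference_values.items():
        visit(name, value, ())
    return flat


flat = flatten_reference_values(
    {"a": 0.0, "c": [-0.5, 0.5], "x": [[-0.5], [0.5]]}
)
for key, value in flat.items():
    print(key, value)
```

Because the keys never pass through the labeller, changing var_name_map would not affect which reference value belongs to which panel under this scheme.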