spatialdata-plot
KeyError when plotting one of multiple labels images with palette
Given a SpatialData object with multiple labels images, rendering a single labels image with a palette that matches the table values assigned to it raises a KeyError.
Example
import numpy as np
import pandas as pd
from anndata import AnnData
from spatialdata import SpatialData
from spatialdata.models import Labels2DModel, TableModel
import spatialdata_plot  # noqa


def test_cat_with_palette():
    sdata = SpatialData(
        labels={
            "labels1": Labels2DModel.parse(
                np.array([[0, 1, 2]]), dims=("y", "x")
            ),
            "labels2": Labels2DModel.parse(
                np.array([[0, 1, 2]]), dims=("y", "x")
            ),
        },
        table=TableModel.parse(
            AnnData(
                obs=pd.DataFrame(
                    {
                        "region": ["labels1", "labels1", "labels2", "labels2"],
                        "instance_id": [1, 2, 1, 2],
                        "value_to_plot": pd.Series(
                            ["cat1", "cat1", "cat2", "cat2"], dtype="category"
                        ),
                    }
                )
            ),
            region=["labels1", "labels2"],
            region_key="region",
            instance_key="instance_id",
        ),
    )
    # labels1's value_to_plot has 1 distinct value.
    sdata.pl.render_labels(
        elements="labels1", color="value_to_plot", palette=["red"]
    ).pl.show()
Cause
Because SpatialData uses a single table for all regions, a categorical column holds the categories of every region. When spatialdata-plot subsets the table to one region for rendering, the categorical still declares the categories of the other regions. The color vector is computed from the label values actually present in the rendered labels image, but the legend code iterates over all declared categories, including those that do not occur in the rendered labels image.
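The retained categories can be seen with pandas alone. The following is a minimal sketch that mirrors the `obs` columns from the example above. The boolean mask stands in for the region subsetting and is an assumption for illustration, not spatialdata-plot's actual code path.

```python
import pandas as pd

obs = pd.DataFrame(
    {
        "region": ["labels1", "labels1", "labels2", "labels2"],
        "value_to_plot": pd.Series(
            ["cat1", "cat1", "cat2", "cat2"], dtype="category"
        ),
    }
)

# Keep only the rows annotating "labels1" (stand-in for the region subset).
subset = obs.loc[obs["region"] == "labels1", "value_to_plot"]

# Values that actually occur in the subset.
values_present = list(subset.unique())
# Categories the categorical dtype still declares after subsetting.
categories_declared = list(subset.cat.categories)
# Categories left once unused ones are dropped explicitly.
categories_trimmed = list(subset.cat.remove_unused_categories().cat.categories)

print(values_present, categories_declared, categories_trimmed)
```

Only `cat1` occurs in the subset, yet `cat2` remains a declared category until `remove_unused_categories()` is called.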
Providing a palette that covers all category values is not a solution: the same exception occurs. Moreover, the categorical column could be the "region" column itself, which has one distinct value per labels image. Extending the palette would be impractical, since a SpatialData object can contain an arbitrary number of independent labels images.
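The error is raised inside Scanpy, but the underlying problem is that the arguments passed to it are inconsistent: as the traceback shows, color_source_vector still declares the category 'cat2' while the palette only has an entry for 'cat1'.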
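The mismatch can be reproduced without Scanpy or matplotlib, using the `color_source_vector` and `palette` values reported in the traceback. The sketch below also shows one possible direction for a fix: dropping unused categories before the legend is built. The helper `legend_entries` is hypothetical and is not part of spatialdata-plot or Scanpy.

```python
import pandas as pd


def legend_entries(color_source_vector, palette):
    """Hypothetical helper: pair each category that actually occurs with its color."""
    present = color_source_vector.remove_unused_categories().categories
    return [(label, palette[label]) for label in present]


# Same values as the locals shown in the traceback.
color_source_vector = pd.Categorical(["cat1", "cat1"], categories=["cat1", "cat2"])
palette = {"cat1": "#ff00004c"}

# Looking up every declared category fails in the same way as the legend loop.
try:
    _ = [palette[label] for label in color_source_vector.categories]
    lookup_failed = False
    missing_key = None
except KeyError as err:
    lookup_failed = True
    missing_key = err.args[0]

# Restricting the lookup to categories that occur avoids the missing key.
entries = legend_entries(color_source_vector, palette)
print(lookup_failed, missing_key, entries)
```

Whether the trimming belongs in spatialdata-plot before the call into Scanpy, or elsewhere, is a design decision this sketch does not settle.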
../../src/spatialdata_plot/pl/basic.py:781: in show
    _render_labels(
../../src/spatialdata_plot/pl/render.py:687: in _render_labels
    _ = _decorate_axs(
../../src/spatialdata_plot/pl/utils.py:811: in _decorate_axs
    _add_categorical_legend(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ax = <Axes: >
color_source_vector = ['cat1', 'cat1']
Categories (2, object): ['cat1', 'cat2']
palette = {'cat1': '#ff00004c'}, legend_loc = 'right margin'
legend_fontweight = 'bold', legend_fontsize = None, legend_fontoutline = []
multi_panel = False, na_color = [(0.0, 0.0, 0.0, 0.0)], na_in_legend = True
scatter_array = None
    def _add_categorical_legend(
        ax,
        color_source_vector,
        palette: dict,
        legend_loc: str,
        legend_fontweight,
        legend_fontsize,
        legend_fontoutline,
        multi_panel,
        na_color,
        na_in_legend: bool,
        scatter_array=None,
    ):
        """Add a legend to the passed Axes."""
        if na_in_legend and pd.isnull(color_source_vector).any():
            if "NA" in color_source_vector:
                raise NotImplementedError(
                    "No fallback for null labels has been defined if NA already in categories."
                )
            color_source_vector = color_source_vector.add_categories("NA").fillna("NA")
            palette = palette.copy()
            palette["NA"] = na_color
        cats = color_source_vector.categories

        if multi_panel is True:
            # Shrink current axis by 10% to fit legend and match
            # size of plots that are not categorical
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.91, box.height])

        if legend_loc == 'right margin':
            for label in cats:
>               ax.scatter([], [], c=palette[label], label=label)
E               KeyError: 'cat2'
…/lib/python3.9/site-packages/scanpy/plotting/_tools/scatterplots.py:1101: KeyError