spatialdata-plot

KeyError when plotting one of multiple labels images with palette

Open aeisenbarth opened this issue 1 year ago • 0 comments

Given a SpatialData object with multiple labels images, rendering a single labels image with a palette that matches the table values assigned to it raises a KeyError.

Example

import numpy as np
import pandas as pd
from anndata import AnnData
from spatialdata import SpatialData
from spatialdata.models import Labels2DModel, TableModel
import spatialdata_plot  # noqa


def test_cat_with_palette():
    sdata = SpatialData(
        labels={
            "labels1": Labels2DModel.parse(
                np.array([[0, 1, 2]]), dims=("y", "x")
            ),
            "labels2": Labels2DModel.parse(
                np.array([[0, 1, 2]]), dims=("y", "x")
            ),
        },
        table=TableModel.parse(
            AnnData(
                obs=pd.DataFrame(
                    {
                        "region": ["labels1", "labels1", "labels2", "labels2"],
                        "instance_id": [1, 2, 1, 2],
                        "value_to_plot": pd.Series(
                            ["cat1", "cat1", "cat2", "cat2"], dtype="category"
                        ),
                    }
                )
            ),
            region=["labels1", "labels2"],
            region_key="region",
            instance_key="instance_id",
        ),
    )
    # Only "cat1" occurs in labels1's value_to_plot, so a one-color palette should suffice.
    sdata.pl.render_labels(
        elements="labels1", color="value_to_plot", palette=["red"]
    ).pl.show()

Cause

Because SpatialData uses a single table for all regions, a categorical column contains the categories of every region. When spatialdata-plot subsets the table to one region for rendering, the categorical still lists categories that belong only to other regions. The color vector is determined from the label values actually present in the rendered labels image. Legend generation, however, iterates over all categories, including those not associated with the rendered labels image.
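The pandas behavior behind this can be reproduced without SpatialData. The following is a minimal sketch using only pandas and the column names from the example above; the variable names (`subset`, `values_present`, `all_categories`, `missing`) are illustrative and not taken from spatialdata-plot.

```python
import pandas as pd

# Same "region" and "value_to_plot" columns as in the example table.
obs = pd.DataFrame(
    {
        "region": ["labels1", "labels1", "labels2", "labels2"],
        "value_to_plot": pd.Series(
            ["cat1", "cat1", "cat2", "cat2"], dtype="category"
        ),
    }
)

# Subsetting rows keeps the full category list of the categorical.
subset = obs.loc[obs["region"] == "labels1", "value_to_plot"]
values_present = list(subset.unique())
all_categories = list(subset.cat.categories)

# A palette that covers the values present does not cover every category.
palette = {"cat1": "red"}
missing = [c for c in all_categories if c not in palette]

print(values_present, all_categories, missing)
```

Any code that looks up `palette[label]` for each entry of `subset.cat.categories` will fail on the entries in `missing`.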

Providing a palette that covers all category values is not a solution (the same exception occurs). The categorical column could also be the "region" column itself, which has one distinct value per labels image. Extending the palette would not be practical, since a SpatialData object can contain an arbitrary number of independent labels images.

The error is raised in Scanpy, but the underlying problem is that the arguments spatialdata-plot passes to it are incorrect.

../../src/spatialdata_plot/pl/basic.py:781: in show
    _render_labels(
../../src/spatialdata_plot/pl/render.py:687: in _render_labels
    _ = _decorate_axs(
../../src/spatialdata_plot/pl/utils.py:811: in _decorate_axs
    _add_categorical_legend(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

ax = <Axes: >
color_source_vector = ['cat1', 'cat1']
Categories (2, object): ['cat1', 'cat2']
palette = {'cat1': '#ff00004c'}, legend_loc = 'right margin'
legend_fontweight = 'bold', legend_fontsize = None, legend_fontoutline = []
multi_panel = False, na_color = [(0.0, 0.0, 0.0, 0.0)], na_in_legend = True
scatter_array = None

    def _add_categorical_legend(
        ax,
        color_source_vector,
        palette: dict,
        legend_loc: str,
        legend_fontweight,
        legend_fontsize,
        legend_fontoutline,
        multi_panel,
        na_color,
        na_in_legend: bool,
        scatter_array=None,
    ):
        """Add a legend to the passed Axes."""
        if na_in_legend and pd.isnull(color_source_vector).any():
            if "NA" in color_source_vector:
                raise NotImplementedError(
                    "No fallback for null labels has been defined if NA already in categories."
                )
            color_source_vector = color_source_vector.add_categories("NA").fillna("NA")
            palette = palette.copy()
            palette["NA"] = na_color
        cats = color_source_vector.categories
    
        if multi_panel is True:
            # Shrink current axis by 10% to fit legend and match
            # size of plots that are not categorical
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.91, box.height])
    
        if legend_loc == 'right margin':
            for label in cats:
>               ax.scatter([], [], c=palette[label], label=label)
E               KeyError: 'cat2'

…/lib/python3.9/site-packages/scanpy/plotting/_tools/scatterplots.py:1101: KeyError
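In the traceback, `color_source_vector` carries the categories `['cat1', 'cat2']` while `palette` only has an entry for `'cat1'`. One possible direction, shown here only as a sketch and not as the fix adopted by the maintainers, would be to drop unused categories from the vector before it reaches the legend code. The helper name `drop_unused_categories` is hypothetical; `remove_unused_categories` is the existing pandas method.

```python
import pandas as pd


def drop_unused_categories(color_source_vector: pd.Categorical) -> pd.Categorical:
    """Return a copy whose categories are limited to values that occur."""
    # Hypothetical helper, not part of spatialdata-plot or Scanpy.
    return color_source_vector.remove_unused_categories()


# Values copied from the traceback above.
color_source_vector = pd.Categorical(
    ["cat1", "cat1"], categories=["cat1", "cat2"]
)
palette = {"cat1": "#ff00004c"}

trimmed = drop_unused_categories(color_source_vector)

# The per-category palette lookup done by the legend code now succeeds.
legend_colors = [palette[label] for label in trimmed.categories]
print(list(trimmed.categories), legend_colors)
```

With the original, untrimmed vector the same list comprehension raises `KeyError: 'cat2'`, matching the reported failure.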

aeisenbarth, Jan 10 '24 20:01