`spatial_segment` doesn't always work with custom palettes
Description
I'm trying to run `sq.pl.spatial_segment` with a custom palette of length 32; however, I get the error shown below. I believe I've narrowed the issue down to line 471 of `_spatial_utils.py`: `color_vector = color_source_vector.map(color_map, na_action=None)`
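`color_vector` needs to be categorical for the downstream plotting code to work; however, `.map` doesn't always return a categorical result. When it returns plain colour strings instead, `_map_color_seg` passes them to `cmap_params.norm(...)`, which raises the `AttributeError` shown below rather than the `TypeError` that its fallback branch catches. It appears that palettes with the following sizes work: [16, 17, 18, 19, 20, 32, 40, 46, 53, 56, 57, 64, 71, 79, 80, 88, 92, 94, 95, 106, 107, 117, 124, 131, 136, 138, 140, 142, 143, 158, 159, 160, 166, 169, 174, 177, 184, 194, 195, 196]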
I suggest adding the following line after line 471 (or something similar) to ensure that `.map` returns a categorical:
color_vector = color_vector.astype(pd.CategoricalDtype())
Minimal reproducible example
# Minimal reproduction: for each palette size i, build the colour map that
# squidpy's `_get_palette` produces and record whether `Series.map` with that
# colour map returns a categorical Series.
v = sc.pl.palettes.vega_20 * 10
# Index 0 is a placeholder so that cats[i] corresponds to a palette of size i.
cats = [None]
for i in range(1, len(v)):
    color_map = _get_palette(
        anndata,
        cluster_col,
        anndata.obs[cluster_col].cat.categories,
        palette=matplotlib.colors.ListedColormap(v[:i]),
    )
    color_vector = anndata.obs[cluster_col].map(color_map, na_action=None)
    cats.append(isinstance(color_vector.dtype, pd.CategoricalDtype))
# Palette sizes for which the mapped result stayed categorical
# (the None placeholder converts to False).
np.where(np.array(cats, dtype=bool))
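This returns the palette sizes for which the result of `.map` is categorical. The behaviour appears to depend on the number of unique colours in the palette. This is consistent with the documented behaviour of pandas' `Categorical.map`: the result stays categorical only when the mapping is one-to-one, and otherwise a non-categorical result is returned.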
Traceback
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[20], line 16
13 raise Exception(f"Number of colors {len(C)} is not enough for {n} categories")
15 fig,ax = plt.subplots(1,1,figsize = (16,8))
---> 16 sq.pl.spatial_segment(
17 adata_x,
18 library_key = lib_key, spatial_key = spatial_key,
19 seg_cell_id = "seg_mask_ids", img = False,
20 color = col, #f"banksy_ct_{i}",
21 frameon = False,
22 fig = fig, ax = ax, #na_color = "white",
23 scalebar_dx = 0.2125*scale_factor,
24 scalebar_units = "µm", palette = mpl.colors.ListedColormap(C[:n+1]),
25 title = f"Banksy Refined CT" #{i}"
26 )
28 handles, labels = ax.get_legend_handles_labels()
29 ax.legend(
30 handles, labels,
31 loc=f'center right',
(...)
35 ncol=1
36 )
File /vast/projects/phipson_combes_kidney_xenium/hm_analysis/envs/sa_py/lib/python3.9/site-packages/squidpy/pl/_spatial.py:464, in spatial_segment(adata, seg_cell_id, seg, seg_key, seg_contourpx, seg_outline, **kwargs)
423 @d.dedent # type: ignore[arg-type]
424 @_wrap_signature
425 def spatial_segment(
(...)
432 **kwargs: Any,
433 ) -> Axes | Sequence[Axes] | None:
434 """
435 Plot spatial omics data with segmentation masks on top.
436
(...)
462 %(spatial_plot.returns)s
463 """
--> 464 return _spatial_plot(
465 adata,
466 seg=seg,
467 seg_key=seg_key,
468 seg_cell_id=seg_cell_id,
469 seg_contourpx=seg_contourpx,
470 seg_outline=seg_outline,
471 **kwargs,
472 )
File /lib/python3.9/site-packages/squidpy/pl/_spatial.py:296, in _spatial_plot(adata, shape, color, groups, library_id, library_key, spatial_key, img, img_res_key, img_alpha, img_cmap, img_channel, seg, seg_key, seg_cell_id, seg_contourpx, seg_outline, use_raw, layer, alt_var, size, size_key, scale_factor, crop_coord, cmap, palette, alpha, norm, na_color, connectivity_key, edges_width, edges_color, library_first, frameon, wspace, hspace, ncols, outline, outline_color, outline_width, legend_loc, legend_fontsize, legend_fontweight, legend_fontoutline, legend_na, colorbar, scalebar_dx, scalebar_units, title, axis_label, fig, ax, return_ax, figsize, dpi, save, scalebar_kwargs, edges_kwargs, **kwargs)
284 ax, cax = _plot_scatter(
285 coords=coords_sub,
286 ax=ax,
(...)
293 **kwargs,
294 )
295 elif _seg is not None and _cell_id is not None:
--> 296 ax, cax = _plot_segment(
297 seg=_seg,
298 cell_id=_cell_id,
299 color_vector=color_vector,
300 color_source_vector=color_source_vector,
301 seg_contourpx=seg_contourpx,
302 seg_outline=seg_outline,
303 na_color=na_color,
304 ax=ax,
305 cmap_params=cmap_params,
306 color_params=color_params,
307 categorical=categorical,
308 **kwargs,
309 )
311 _ = _decorate_axs(
312 ax=ax,
313 cax=cax,
(...)
334 scalebar_kwargs=scalebar_kwargs,
335 )
337 if fig_params.fig is not None and save is not None:
File /lib/python3.9/site-packages/squidpy/pl/_spatial_utils.py:976, in _plot_segment(seg, cell_id, color_vector, color_source_vector, ax, cmap_params, color_params, categorical, seg_contourpx, seg_outline, na_color, **kwargs)
962 def _plot_segment(
963 seg: NDArrayA,
964 cell_id: NDArrayA,
(...)
974 **kwargs: Any,
975 ) -> tuple[Axes, Collection]:
--> 976 img = _map_color_seg(
977 seg=seg,
978 cell_id=cell_id,
979 color_vector=color_vector,
980 color_source_vector=color_source_vector,
981 cmap_params=cmap_params,
982 seg_erosionpx=seg_contourpx,
983 seg_boundaries=seg_outline,
984 na_color=na_color,
985 )
987 _cax = ax.imshow(
988 img,
989 rasterized=True,
(...)
995 **kwargs,
996 )
997 cax = ax.add_image(_cax)
File /lib/python3.9/site-packages/squidpy/pl/_spatial_utils.py:690, in _map_color_seg(seg, cell_id, color_vector, color_source_vector, cmap_params, seg_erosionpx, seg_boundaries, na_color)
688 val_im = map_array(seg, cell_id, cell_id) # replace with same seg id to remove missing segs
689 try:
--> 690 cols = cmap_params.cmap(cmap_params.norm(color_vector))
691 except TypeError:
692 assert all(colors.is_color_like(c) for c in color_vector), "Not all values are color-like."
File ~/.local/lib/python3.9/site-packages/matplotlib/colors.py:1338, in Normalize.__call__(self, value, clip)
1335 result, is_scalar = self.process_value(value)
1337 if self.vmin is None or self.vmax is None:
-> 1338 self.autoscale_None(result)
1339 # Convert at least to float, without losing precision.
1340 (vmin,), _ = self.process_value(self.vmin)
File ~/.local/lib/python3.9/site-packages/matplotlib/colors.py:1382, in Normalize.autoscale_None(self, A)
1380 A = np.asanyarray(A)
1381 if self.vmin is None and A.size:
-> 1382 self.vmin = A.min()
1383 if self.vmax is None and A.size:
1384 self.vmax = A.max()
File ~/.local/lib/python3.9/site-packages/numpy/ma/core.py:5833, in MaskedArray.min(self, axis, out, fill_value, keepdims)
5831 # No explicit output
5832 if out is None:
-> 5833 result = self.filled(fill_value).min(
5834 axis=axis, out=out, **kwargs).view(type(self))
5835 if result.ndim:
5836 # Set the mask
5837 result.__setmask__(newmask)
AttributeError: 'str' object has no attribute 'view'
Version
1.5.0
Hey @harrymueller, thank you for the report and the suggested fix! We're currently migrating Squidpy's internal plotting functions to rely on spatialdata-plot instead, so we will no longer be fixing these functions on the Squidpy side. Could you try your visualisation with that package?