spatialdata-plot
spatialdata-plot copied to clipboard
Plotting multiple elements in the same `ax` seems to work only when `show()` is not called.
I refer to the code mentioned in this other issue: https://github.com/scverse/spatialdata-plot/issues/68
This code here:
ax = plt.gca()
sdata.pl.render_shapes(element='s', na_color=(0.5, 0.5, 0.5, 0.5)).pl.render_points().pl.show(ax=ax)
sdata.pl.render_shapes(element='c', na_color=(0.7, 0.7, 0.7, 0.5)).pl.show(ax=ax)
plt.show()
doesn't work if I run the code as a script, but it works in interactive mode (where because of a bug the plots are not shown until I call plt.show()). I suggest to do like scanpy and having a parameter show: bool. I suggest also that if the parameter ax is not None, then show is set to False. I don't remember if this one is also a behavior of scanpy, but I think it's reasonable.
Hello devs,
I have a really cool function on my hands, and I saving a summary plot is proving to be quite difficult. So I am kinda restarting this issue.
My function would take an image as an input, perform segmentation of the image using Cellpose via SOPA, and produce a PNG file with a hyperparameter search, to decide what is the best segmentation.
Currently I am running this code for plotting each ax object, in a fig that has many axes.
sdata.pl.render_images(
element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
).pl.render_shapes(
element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png"))
When this line is reached in the CLI a matplotlib popup comes up with the entire figure, but a single filled ax object. I have to manually close this first figure, and then the other axes are plotted, and then the entire figure saved (I think overwriting itself).
I have looked into matplotlib docs but I found no clear answer.
Any tips, ideas, or comments, very welcome. For the plotting or the function in general.
Best, Jose
Entire script Function
#system
from loguru import logger
import argparse
import sys
import os
import time
import spatialdata
import spatialdata_plot
#imports
import skimage.segmentation as segmentation
import skimage.io as io
import numpy as np
#yaml
import yaml
import math
import matplotlib.pyplot as plt
import re
import os
import matplotlib.gridspec as gridspec
#sopa
import sopa.segmentation
import sopa.io
def get_args():
""" Get arguments from command line """
description = """Expand labeled masks by a certain number of pixels."""
parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
inputs = parser.add_argument_group(title="Required Input", description="Path to required input file")
inputs.add_argument("-i", "--input", dest="input", action="store", required=True, help="File path to input mask or folders with many masks")
inputs.add_argument("-c", "--config", dest="config", action="store", required=True, help="Path to config.yaml for cellpose parameters")
inputs.add_argument("-o", "--output", dest="output", action="store", required=True, help="Path to output mask, or folder where to save the output masks")
inputs.add_argument("-l", "--log-level",dest="loglevel", default='INFO', choices=["DEBUG", "INFO"], help='Set the log level (default: INFO)')
arg = parser.parse_args()
arg.input = os.path.abspath(arg.input)
arg.config = os.path.abspath(arg.config)
arg.output = os.path.abspath(arg.output)
return arg
def check_input_outputs(args):
""" Check if input and output files exist """
#input
assert os.path.isfile(args.input), "Input must be a file"
assert args.input.endswith((".tif", ".tiff")), "Input file must be a .tif or .tiff file"
#config
assert os.path.isfile(args.config), "Config must exist"
assert args.config.endswith(".yaml"), "Config file must be a .yaml file"
#output
if not os.path.exists(args.output):
os.makedirs(args.output)
assert os.path.isdir(args.output), "Output must be a folder"
#create output folders
os.makedirs(os.path.join(args.output, "pngs"), exist_ok=True)
### os.makedirs(os.path.join(args.output, "zarrs"), exist_ok=True)
args.filename = os.path.basename(args.input).split(".")[0]
args.zarr_path = os.path.join(args.output, f"{args.filename}.zarr")
logger.info(f"Input, output and config files exist and checked.")
def create_sdata(args):
""" Create sdata object """
logger.info(f"Creating spatialdata object.")
time_start = time.time()
sdata = sopa.io.ome_tif(args.input)
args.image_key = list(sdata.images.keys())[0]
time_end = time.time()
logger.info(f"Creating spatialdata object took {time_end - time_start} seconds.")
return sdata
def prepare_for_segmentation_search(sdata, args):
""" Search for segments in sopa data """
logger.info(f"Preparing for segmentation search.")
time_start = time.time()
patches = sopa.segmentation.Patches2D(sdata, element_name=args.image_key, patch_width=1000, patch_overlap=100)
patches.write()
#reset channel names to their indexes, metadata to inconsistent
new_c = list(range(len(sdata.images[args.image_key]['scale0'].coords['c'].values)))
sdata.images[args.image_key] = sdata.images[args.image_key].assign_coords(c=new_c)
time_end = time.time()
logger.info(f"Preparation for segmentation took {time_end - time_start} seconds.")
return sdata
def read_yaml(file_path):
""" Read yaml file """
logger.info(f"Reading yaml file.")
with open(file_path, 'r') as file:
data = yaml.safe_load(file)
return data
def segmentation_loop(sdata, args, config):
""" Loop through different cellpose parameters """
logger.info(f"Starting segmentation loop.")
for ft in config['flow_thresholds']:
for cpt in config['cellprob_thresholds']:
logger.info(f"Segmenting with FT: {ft} and CT: {cpt}")
FT_str = str(ft).replace(".", "")
#create method for segmenting
method = sopa.segmentation.methods.cellpose_patch(
diameter=config['cell_pixel_diameter'],
channels=config['channels'],
flow_threshold=ft,
cellprob_threshold=cpt,
model_type=config['model_type']
)
segmentation = sopa.segmentation.StainingSegmentation(sdata, method, channels=config['channels'], min_area=config['min_area'])
#create temp dir to store segmentation of each tile
cellpose_temp_dir = os.path.join(args.output, ".sopa_cache", "cellpose", f"run_FT{FT_str}_CPT{cpt}")
#segment
segmentation.write_patches_cells(cellpose_temp_dir)
#read and solve conflicts
cells = sopa.segmentation.StainingSegmentation.read_patches_cells(cellpose_temp_dir)
cells = sopa.segmentation.shapes.solve_conflicts(cells)
#save segmentation of entire image as shapes
sopa.segmentation.StainingSegmentation.add_shapes(
sdata, cells, image_key=args.image_key, shapes_key=f"cellpose_boundaries_FT{FT_str}_CT{cpt}")
logger.info(f"Saving zarr to {args.zarr_path}")
sdata.write(args.zarr_path, overwrite=True)
logger.info(f"Segmentation loop finished.")
def extract_ft_values(shape_titles):
"""Extract all unique ft values from a list of shape titles."""
ft_values = set()
cpt_values = set()
for title in shape_titles:
match = re.search(r'_FT(\d+)_CT(\d+)', title)
if match:
ft_values.add(match.group(1))
cpt_values.add(match.group(2))
else:
print(f"Warning: {title} does not match the expected pattern.")
return sorted(ft_values), sorted(cpt_values)
def plot(sdata, args, config):
shape_titles = list(sdata.shapes.keys())
shape_titles.remove("sopa_patches")
logger.info(f"Plotting {shape_titles} segmentations")
logger.info
unique_ft_values, unique_cpt_values = extract_ft_values(shape_titles)
num_cols = len(unique_ft_values)
num_rows = len(unique_cpt_values)
logger.info(f"Unique FT values: {unique_ft_values} and Unique CT values: {unique_cpt_values}")
ft_to_index = {ft: i for i, ft in enumerate(unique_ft_values)}
cpt_to_index = {cpt: i for i, cpt in enumerate(unique_cpt_values)}
fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols*6, num_rows*6), facecolor='black')
gs = gridspec.GridSpec(num_rows, num_cols, wspace=0.1, hspace=0.1)
for i, title in enumerate(shape_titles):
#print number of title out of all titles
logger.info(f"Rendering {i+1}/{len(shape_titles)} ||| {title}")
ft, cpt = re.search(r'FT(\d+)_CT(\d+)', title).groups()
row = cpt_to_index[cpt]
col = ft_to_index[ft]
ax = fig.add_subplot(gs[row, col])
ax.set_facecolor('black')
ax.title.set_color('white')
ax.set_xticks([])
ax.set_yticks([])
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
try:
logger.info(f" Rendering image")
sdata.pl.render_images(
element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
).pl.render_shapes(
element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png"))
logger.info(f"Saving plot to {os.path.join(args.output, 'pngs', 'segment_search.png')}")
# plt.savefig(os.path.join(args.output, "pngs", "segment_search.png"))
except:
print(f"Could not render shapes of {title}")
def main():
args = get_args()
logger.remove()
logger.add(sys.stdout, format="<green>{time:HH:mm:ss.SS}</green> | <level>{level}</level> | {message}", level=args.loglevel.upper())
check_input_outputs(args)
sdata = create_sdata(args)
sdata = prepare_for_segmentation_search(sdata, args)
segmentation_loop(sdata, args, config=read_yaml(args.config))
plot(sdata, args, config=read_yaml(args.config))
if __name__ == "__main__":
main()
"""
Example:
python ./scripts/segment_search.py \
--input ./data/input/Exemplar001.ome.tif \
--config ./data/configs/config.yaml \
--output ./data/output/
"""
Note for us. This bug and the newly reported bug https://github.com/scverse/spatialdata-plot/issues/362 are related.
Thanks @josenimo for the bug report. We will try to address this bug soon. Meanwhile, I would suggest checking if setting the matplotlib backend to Agg could work for you. Or maybe using plt.ion()/plt.ioff() as described in this other issue https://github.com/scverse/spatialdata-plot/issues/68.