spatialdata-plot icon indicating copy to clipboard operation
spatialdata-plot copied to clipboard

Plotting multiple elements in the same `ax` seems to work only when `show()` is not called.

Open LucaMarconato opened this issue 2 years ago • 2 comments

I refer to the code mentioned in this other issue: https://github.com/scverse/spatialdata-plot/issues/68

This code here:

    ax = plt.gca()
    sdata.pl.render_shapes(element='s', na_color=(0.5, 0.5, 0.5, 0.5)).pl.render_points().pl.show(ax=ax)
    sdata.pl.render_shapes(element='c', na_color=(0.7, 0.7, 0.7, 0.5)).pl.show(ax=ax)
    plt.show()

doesn't work if I run the code as a script, but it works in interactive mode (where because of a bug the plots are not shown until I call plt.show()). I suggest to do like scanpy and having a parameter show: bool. I suggest also that if the parameter ax is not None, then show is set to False. I don't remember if this one is also a behavior of scanpy, but I think it's reasonable.

LucaMarconato avatar May 14 '23 13:05 LucaMarconato

Hello devs,

I have a really cool function on my hands, and I saving a summary plot is proving to be quite difficult. So I am kinda restarting this issue.

My function would take an image as an input, perform segmentation of the image using Cellpose via SOPA, and produce a PNG file with a hyperparameter search, to decide what is the best segmentation.

Currently I am running this code for plotting each ax object, in a fig that has many axes.

sdata.pl.render_images(
    element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
).pl.render_shapes(
    element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png"))

When this line is reached in the CLI a matplotlib popup comes up with the entire figure, but a single filled ax object. I have to manually close this first figure, and then the other axes are plotted, and then the entire figure saved (I think overwriting itself).

I have looked into matplotlib docs but I found no clear answer.

Any tips, ideas, or comments, very welcome. For the plotting or the function in general.

Best, Jose

Entire script Function

#system
from loguru import logger
import argparse
import sys
import os
import time

import spatialdata
import spatialdata_plot

#imports
import skimage.segmentation as segmentation
import skimage.io as io
import numpy as np

#yaml 
import yaml
import math
import matplotlib.pyplot as plt
import re
import os
import matplotlib.gridspec as gridspec

#sopa
import sopa.segmentation
import sopa.io

def get_args():
    """ Get arguments from command line """
    description = """Expand labeled masks by a certain number of pixels."""
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawDescriptionHelpFormatter)
    inputs = parser.add_argument_group(title="Required Input", description="Path to required input file")
    inputs.add_argument("-i", "--input",    dest="input",   action="store", required=True, help="File path to input mask or folders with many masks")
    inputs.add_argument("-c", "--config",   dest="config",  action="store", required=True, help="Path to config.yaml for cellpose parameters")
    inputs.add_argument("-o", "--output",   dest="output",  action="store", required=True, help="Path to output mask, or folder where to save the output masks")
    inputs.add_argument("-l", "--log-level",dest="loglevel", default='INFO', choices=["DEBUG", "INFO"], help='Set the log level (default: INFO)')
    arg = parser.parse_args()
    arg.input = os.path.abspath(arg.input)
    arg.config = os.path.abspath(arg.config)
    arg.output = os.path.abspath(arg.output)
    return arg

def check_input_outputs(args):
    """ Check if input and output files exist """
    #input
    assert os.path.isfile(args.input), "Input must be a file"
    assert args.input.endswith((".tif", ".tiff")), "Input file must be a .tif or .tiff file"
    #config
    assert os.path.isfile(args.config), "Config must exist"
    assert args.config.endswith(".yaml"), "Config file must be a .yaml file"
    #output
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    assert os.path.isdir(args.output), "Output must be a folder"
    #create output folders
    os.makedirs(os.path.join(args.output, "pngs"), exist_ok=True)
    ### os.makedirs(os.path.join(args.output, "zarrs"), exist_ok=True)
    args.filename = os.path.basename(args.input).split(".")[0]
    args.zarr_path = os.path.join(args.output, f"{args.filename}.zarr")
    
    logger.info(f"Input, output and config files exist and checked.")

def create_sdata(args):
    """ Create sdata object """
    logger.info(f"Creating spatialdata object.")
    time_start = time.time()

    sdata = sopa.io.ome_tif(args.input)
    args.image_key = list(sdata.images.keys())[0]

    time_end = time.time()
    logger.info(f"Creating spatialdata object took {time_end - time_start} seconds.")
    return sdata

def prepare_for_segmentation_search(sdata, args):
    """ Search for segments in sopa data """
    logger.info(f"Preparing for segmentation search.")
    time_start = time.time()

    patches = sopa.segmentation.Patches2D(sdata, element_name=args.image_key, patch_width=1000, patch_overlap=100)
    patches.write()

    #reset channel names to their indexes, metadata to inconsistent
    new_c = list(range(len(sdata.images[args.image_key]['scale0'].coords['c'].values)))
    sdata.images[args.image_key] = sdata.images[args.image_key].assign_coords(c=new_c)

    time_end = time.time()
    logger.info(f"Preparation for segmentation took {time_end - time_start} seconds.")
    return sdata

def read_yaml(file_path):
    """ Read yaml file """
    logger.info(f"Reading yaml file.")
    with open(file_path, 'r') as file:
        data = yaml.safe_load(file)
    return data

def segmentation_loop(sdata, args, config):
    """ Loop through different cellpose parameters """
    logger.info(f"Starting segmentation loop.")

    for ft in config['flow_thresholds']:
        for cpt in config['cellprob_thresholds']:

            logger.info(f"Segmenting with FT: {ft} and CT: {cpt}")
            FT_str = str(ft).replace(".", "")
            #create method for segmenting
            method = sopa.segmentation.methods.cellpose_patch(
                diameter=config['cell_pixel_diameter'], 
                channels=config['channels'], 
                flow_threshold=ft, 
                cellprob_threshold=cpt, 
                model_type=config['model_type']
            )
            segmentation = sopa.segmentation.StainingSegmentation(sdata, method, channels=config['channels'], min_area=config['min_area'])
            #create temp dir to store segmentation of each tile
            cellpose_temp_dir = os.path.join(args.output, ".sopa_cache", "cellpose", f"run_FT{FT_str}_CPT{cpt}")
            #segment
            segmentation.write_patches_cells(cellpose_temp_dir)
            #read and solve conflicts
            cells = sopa.segmentation.StainingSegmentation.read_patches_cells(cellpose_temp_dir)
            cells = sopa.segmentation.shapes.solve_conflicts(cells)
            #save segmentation of entire image as shapes
            sopa.segmentation.StainingSegmentation.add_shapes(
                sdata, cells, image_key=args.image_key, shapes_key=f"cellpose_boundaries_FT{FT_str}_CT{cpt}")
    
    logger.info(f"Saving zarr to {args.zarr_path}")
    sdata.write(args.zarr_path, overwrite=True)
    logger.info(f"Segmentation loop finished.")

def extract_ft_values(shape_titles):
    """Extract all unique ft values from a list of shape titles."""
    ft_values = set()
    cpt_values = set()
    for title in shape_titles:
        match = re.search(r'_FT(\d+)_CT(\d+)', title)
        if match:
            ft_values.add(match.group(1))
            cpt_values.add(match.group(2))
        else:
            print(f"Warning: {title} does not match the expected pattern.")
    return sorted(ft_values), sorted(cpt_values)

def plot(sdata, args, config):

    shape_titles = list(sdata.shapes.keys())
    shape_titles.remove("sopa_patches")
    logger.info(f"Plotting {shape_titles} segmentations")

    logger.info

    unique_ft_values, unique_cpt_values = extract_ft_values(shape_titles)
    num_cols = len(unique_ft_values)
    num_rows = len(unique_cpt_values)
    logger.info(f"Unique FT values: {unique_ft_values} and Unique CT values: {unique_cpt_values}")
    ft_to_index = {ft: i for i, ft in enumerate(unique_ft_values)}
    cpt_to_index = {cpt: i for i, cpt in enumerate(unique_cpt_values)}
    
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols*6, num_rows*6), facecolor='black')
    gs = gridspec.GridSpec(num_rows, num_cols, wspace=0.1, hspace=0.1)

    for i, title in enumerate(shape_titles):
        #print number of title out of all titles
        logger.info(f"Rendering {i+1}/{len(shape_titles)} ||| {title}")
        ft, cpt = re.search(r'FT(\d+)_CT(\d+)', title).groups()
        row = cpt_to_index[cpt]
        col = ft_to_index[ft]

        ax = fig.add_subplot(gs[row, col])
        ax.set_facecolor('black')
        ax.title.set_color('white')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(False)

        try:
            logger.info(f"  Rendering image")
            sdata.pl.render_images(
                element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
            ).pl.render_shapes(
                element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
            ).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png"))
            logger.info(f"Saving plot to {os.path.join(args.output, 'pngs', 'segment_search.png')}")
            # plt.savefig(os.path.join(args.output, "pngs", "segment_search.png"))
        except:
            print(f"Could not render shapes of {title}")

def main():
    args = get_args()
    logger.remove()
    logger.add(sys.stdout, format="<green>{time:HH:mm:ss.SS}</green> | <level>{level}</level> | {message}", level=args.loglevel.upper())
    check_input_outputs(args)
    sdata = create_sdata(args)
    sdata = prepare_for_segmentation_search(sdata, args)
    segmentation_loop(sdata, args, config=read_yaml(args.config))
    plot(sdata, args, config=read_yaml(args.config))

if __name__ == "__main__":
    main()

"""
Example:

python ./scripts/segment_search.py \
--input ./data/input/Exemplar001.ome.tif \
--config ./data/configs/config.yaml \
--output ./data/output/
"""


josenimo avatar Sep 04 '24 12:09 josenimo

Note for us. This bug and the newly reported bug https://github.com/scverse/spatialdata-plot/issues/362 are related.

Thanks @josenimo for the bug report. We will try to address this bug soon. Meanwhile, I would suggest checking if setting the matplotlib backend to Agg could work for you. Or maybe using plt.ion()/plt.ioff() as described in this other issue https://github.com/scverse/spatialdata-plot/issues/68.

LucaMarconato avatar Sep 30 '24 18:09 LucaMarconato