
Increase the line size in STREAM plot

Open smk5g5 opened this issue 3 years ago • 3 comments

I am trying to use a STREAM figure for publication, but I did not see any option in the plot_stream function to increase the line size so that the lines are more visible. Is there a way in STREAM to increase the line size for this plot? Thanks!

image

smk5g5, Dec 04 '20 19:12

Hi, I am glad that STREAM analysis helped in your research. Unfortunately, the lines (or edges) in the stream plot are disabled by default, and this part is currently hard-coded.

However, adjusting two parameters of plot_stream() helps fine-tune the appearance of the stream plot and may make the thin parts more visible: dist_scale (controls the width of the stream plot branches) and log_scale (shows the stream plot in log scale to zoom in on thin branches).
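To compare these settings side by side, the calls can be wrapped in a small helper. This is only a sketch: `sweep_stream_settings` is a hypothetical name, the candidate values are arbitrary starting points rather than recommendations, and the function passed in is assumed to accept the `dist_scale`, `log_scale` and `factor_zoomin` keywords of plot_stream().

```python
def sweep_stream_settings(plot_stream, adata, root="S0", color=None):
    """Call plot_stream once per candidate setting so the figures can be compared.

    `plot_stream` is passed in as an argument; it is assumed to accept the
    keyword arguments used below. Returns the list of settings that were tried.
    """
    candidates = [
        {"dist_scale": 0.9, "log_scale": False},  # the documented defaults
        {"dist_scale": 1.5, "log_scale": False},  # larger dist_scale, wider branches
        {"dist_scale": 0.9, "log_scale": True, "factor_zoomin": 100.0},  # log-scaled widths
    ]
    for kwargs in candidates:
        plot_stream(adata, root=root, color=color, **kwargs)
    return candidates
```

Assuming STREAM is imported as `st` and `adata` comes from an earlier analysis, this would be called as `sweep_stream_settings(st.plot_stream, adata, root='S0', color=['label'])`.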

huidongchen, Dec 04 '20 19:12

Is this something that can be changed through the seaborn configuration settings in STREAM? If you can point me to where it needs to be changed, I can try modifying the code to achieve what I want. Changing dist_scale did not achieve the desired result for me!

smk5g5, Dec 04 '20 20:12

You need to add the parameter lw to the call clip_path = Polygon(verts_cell, facecolor='none', closed=True).
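A minimal matplotlib-only sketch of what that parameter does. The vertices and image below are hypothetical stand-ins, not STREAM output, and the `edgecolor` is set explicitly here only so the outline is visible regardless of the active style. It also bears on the seaborn question: a style setting such as `patch.linewidth` only supplies the default width, so a patch created with an explicit `lw` is not affected by it.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Polygon

# Hypothetical outline standing in for one stream branch (not STREAM output).
verts_cell = np.array([[0.0, 0.4], [1.0, 0.1], [2.0, 0.3],
                       [2.0, 0.7], [1.0, 0.9], [0.0, 0.6]])

# A style setting only provides the default; an explicit lw takes precedence.
with plt.rc_context({"patch.linewidth": 5.0}):
    explicit = Polygon(verts_cell, closed=True, lw=0)  # stays at 0
    default = Polygon(verts_cell, closed=True)         # picks up 5.0

fig, ax = plt.subplots(figsize=(7, 4.5))
im = ax.imshow(np.random.rand(20, 40), extent=[0, 2, 0, 1], aspect="auto")

# The same patch clips the image and draws the outline; lw sets the outline width.
clip_path = Polygon(verts_cell, facecolor="none", edgecolor="black",
                    closed=True, lw=2.0)
ax.add_patch(clip_path)
im.set_clip_path(clip_path)
fig.canvas.draw()
```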

If you import extra.py and copy and paste the following function, it should do the trick:

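import os
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Polygon
from pandas.api.types import is_string_dtype, is_numeric_dtype
from slugify import slugify  # assumed to be the python-slugify package
# cal_stream_polygon_string, cal_stream_polygon_numeric and arrowed_spines are STREAM
# helpers; they are assumed to come from the extra.py import mentioned above.

def plot_stream(adata,root='S0',color = None,preference=None,dist_scale=0.9,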
                factor_num_win=10,factor_min_win=2.0,factor_width=2.5,factor_nrow=200,factor_ncol=400,
                log_scale = False,factor_zoomin=100.0,
                fig_size=(7,4.5),fig_legend_order=None,fig_legend_ncol=1,
                vmin=None,vmax=None,
                pad=1.08,w_pad=None,h_pad=None,
                save_fig=False,fig_path=None,fig_format='pdf'):  
    """Generate stream plot at density level
    
    Parameters
    ----------
    adata: AnnData
        Annotated data matrix.
    root: `str`, optional (default: 'S0')
        The starting node.
    color: `list`, optional (default: None)
        A list of names to be plotted: column names of observations (adata.obs.columns) or variable names (adata.var_names).
    preference: `list`, optional (default: None)
        The preference of nodes. Branches with the specified nodes are preferred and placed in the top part of the stream plot.
        The higher a node ranks, the closer to the top its branch is.
    dist_scale: `float`, optional (default: 0.9)
        Scaling factor. It controls the width of STREAM plot branches. The smaller, the thinner the branch will be.
    factor_num_win: `int`, optional (default: 10)
        Number of sliding windows used for making stream plot. It controls the smoothness of STREAM plot.
    factor_min_win: `float`, optional (default: 2.0)
        The minimum number of sliding windows. It controls the resolution of STREAM plot. The window size is calculated based on shortest branch. (suggested range: 1.5~3.0)
    factor_width: `float`, optional (default: 2.5)
        The ratio between length and width of stream plot. 
    factor_nrow: `int`, optional (default: 200)
        The number of rows in the array used to plot continuous values 
    factor_ncol: `int`, optional (default: 400)
        The number of columns in the array used to plot continuous values
    log_scale: `bool`, optional (default: False)
        If True, the number of cells (the width) is logarithmized when drawing the stream plot.
    factor_zoomin: `float`, optional (default: 100.0)
        If log_scale is True, the factor used to zoom in on the thin branches.
    fig_size: `tuple`, optional (default: (7,4.5))
        figure size.
    fig_legend_order: `dict`, optional (default: None)
        Specified order for the appearance of the annotation keys. Only valid for categorical variables,
        e.g. fig_legend_order = {'ann1':['a','b','c'],'ann2':['aa','bb','cc']}
    fig_legend_ncol: `int`, optional (default: 1)
        The number of columns that the legend has.
    vmin,vmax: `float`, optional (default: None)
        The min and max values are used to normalize continuous values. If None, the respective min and max of continuous values is used.
    pad: `float`, optional (default: 1.08)
        Padding between the figure edge and the edges of subplots, as a fraction of the font size.
    h_pad, w_pad: `float`, optional (default: None)
        Padding (height/width) between edges of adjacent subplots, as a fraction of the font size. Defaults to pad.
    save_fig: `bool`, optional (default: False)
        If True, save the figure.
    fig_path: `str`, optional (default: None)
        If save_fig is True, specify the figure path. If None, adata.uns['workdir'] will be used.
    fig_format: `str`, optional (default: 'pdf')
        If save_fig is True, specify the figure format.
    
    Returns
    -------
    None
    """

    if(fig_path is None):
        fig_path = adata.uns['workdir']
    fig_size = mpl.rcParams['figure.figsize'] if fig_size is None else fig_size

    if(color is None):
        color = ['label']
    ###remove duplicate keys
    color = list(dict.fromkeys(color))     

    dict_ann = dict()
    for ann in color:
        if(ann in adata.obs.columns):
            dict_ann[ann] = adata.obs[ann]
        elif(ann in adata.var_names):
            dict_ann[ann] = adata.obs_vector(ann)
        else:
            raise ValueError("could not find '%s' in `adata.obs.columns` and `adata.var_names`"  % (ann))
    
    flat_tree = adata.uns['flat_tree']
    ft_node_label = nx.get_node_attributes(flat_tree,'label')
    label_to_node = {value: key for key,value in nx.get_node_attributes(flat_tree,'label').items()}    
    if(root not in label_to_node.keys()):
        raise ValueError("There is no root '%s'" % root)  

    if(preference is not None):
        preference_nodes = [label_to_node[x] for x in preference]
    else:
        preference_nodes = None

    legend_order = {ann:np.unique(dict_ann[ann]) for ann in color if is_string_dtype(dict_ann[ann])}
    if(fig_legend_order is not None):
        if(not isinstance(fig_legend_order, dict)):
            raise TypeError("`fig_legend_order` must be a dictionary")
        for ann in fig_legend_order.keys():
            if(ann in legend_order.keys()):
                legend_order[ann] = fig_legend_order[ann]
            else:
                print("'%s' is ignored for ordering legend labels due to incorrect name or data type" % ann)

    dict_plot = dict()
    
    list_string_type = [k for k,v in dict_ann.items() if is_string_dtype(v)]
    if(len(list_string_type)>0):
        dict_verts,dict_extent = \
        cal_stream_polygon_string(adata,dict_ann,root=root,preference=preference,dist_scale=dist_scale,
                                  factor_num_win=factor_num_win,factor_min_win=factor_min_win,factor_width=factor_width,
                                  log_scale=log_scale,factor_zoomin=factor_zoomin)  
        dict_plot['string'] = [dict_verts,dict_extent]

    list_numeric_type = [k for k,v in dict_ann.items() if is_numeric_dtype(v)]
    if(len(list_numeric_type)>0):
        verts,extent,ann_order,dict_ann_df,dict_im_array = \
        cal_stream_polygon_numeric(adata,dict_ann,root=root,preference=preference,dist_scale=dist_scale,
                                   factor_num_win=factor_num_win,factor_min_win=factor_min_win,factor_width=factor_width,
                                   factor_nrow=factor_nrow,factor_ncol=factor_ncol,
                                   log_scale=log_scale,factor_zoomin=factor_zoomin)     
        dict_plot['numeric'] = [verts,extent,ann_order,dict_ann_df,dict_im_array]
        
    for ann in color:  
        if(is_string_dtype(dict_ann[ann])):
            if(not ((ann+'_color' in adata.uns_keys()) and (set(adata.uns[ann+'_color'].keys()) >= set(np.unique(dict_ann[ann]))))):
                ### a hacky way to generate colors from seaborn
                tmp = pd.DataFrame(index=adata.obs_names,
                                   data=np.random.rand(adata.shape[0], 2))
                tmp[ann] = dict_ann[ann]
                fig = plt.figure(figsize=fig_size)
                sc_i=sns.scatterplot(x=0,y=1,hue=ann,data=tmp,linewidth=0)
                colors_sns = sc_i.get_children()[0].get_facecolors()
                plt.close(fig)
                colors_sns_scaled = (255*colors_sns).astype(int)
                adata.uns[ann+'_color'] = {tmp[ann][i]:'#%02x%02x%02x' % (colors_sns_scaled[i][0], colors_sns_scaled[i][1], colors_sns_scaled[i][2])
                                           for i in np.unique(tmp[ann],return_index=True)[1]}
            dict_palette = adata.uns[ann+'_color']

            verts = dict_plot['string'][0][ann]
            extent = dict_plot['string'][1][ann]
            xmin = extent['xmin']
            xmax = extent['xmax']
            ymin = extent['ymin'] - (extent['ymax'] - extent['ymin'])*0.1
            ymax = extent['ymax'] + (extent['ymax'] - extent['ymin'])*0.1            
            
            fig = plt.figure(figsize=fig_size)
            ax = fig.add_subplot(1,1,1)
            legend_labels = []
            for ann_i in legend_order[ann]:
                legend_labels.append(ann_i)
                verts_cell = verts[ann_i]
                # lw=0 draws the categorical branches without an outline
                polygon = Polygon(verts_cell,closed=True,color=dict_palette[ann_i],alpha=0.8,lw=0)
                ax.add_patch(polygon)
            ax.legend(legend_labels,bbox_to_anchor=(1.03, 0.5), loc='center left', ncol=fig_legend_ncol,frameon=False,  
                      columnspacing=0.4,
                      borderaxespad=0.2,
                      handletextpad=0.3,)        
        else:
            verts = dict_plot['numeric'][0] 
            extent = dict_plot['numeric'][1]
            ann_order = dict_plot['numeric'][2]
            dict_ann_df = dict_plot['numeric'][3]  
            dict_im_array = dict_plot['numeric'][4]
            xmin = extent['xmin']
            xmax = extent['xmax']
            ymin = extent['ymin'] - (extent['ymax'] - extent['ymin'])*0.1
            ymax = extent['ymax'] + (extent['ymax'] - extent['ymin'])*0.1

            #clip parts according to determined polygon
            fig = plt.figure(figsize=fig_size)
            ax = fig.add_subplot(1,1,1)
            for ann_i in ann_order:
                vmin_i = dict_ann_df[ann].loc[ann_i,:].min() if vmin is None else vmin
                vmax_i = dict_ann_df[ann].loc[ann_i,:].max() if vmax is None else vmax
                im = ax.imshow(dict_im_array[ann][ann_i],interpolation='bicubic',
                               extent=[xmin,xmax,ymin,ymax],vmin=vmin_i,vmax=vmax_i,aspect='auto') 
                verts_cell = verts[ann_i]
                # lw sets the outline width of the branch; raise it for thicker lines
                clip_path = Polygon(verts_cell, facecolor='none', closed=True, lw=1)
                ax.add_patch(clip_path)
                im.set_clip_path(clip_path)
                cbar = plt.colorbar(im, ax=ax, pad=0.04, fraction=0.02, aspect='auto')
                cbar.ax.locator_params(nbins=5)  
        ax.set_xlim(xmin,xmax)
        ax.set_ylim(ymin,ymax)
        ax.set_xlabel("pseudotime",labelpad=2)
        ax.spines['left'].set_visible(False) 
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False) 
        ax.get_yaxis().set_visible(False)
        ax.locator_params(axis='x',nbins=8)
        ax.tick_params(axis="x",pad=-1)
        annots = arrowed_spines(ax, locations=('bottom right',),
                                lw=ax.spines['bottom'].get_linewidth()*1e-5)
        ax.set_title(ann)
        plt.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad)             
        if(save_fig):
            file_path_S = os.path.join(fig_path,root)
            if(not os.path.exists(file_path_S)):
                os.makedirs(file_path_S) 
            plt.savefig(os.path.join(file_path_S,'stream_' + slugify(ann) + '.' + fig_format),pad_inches=1,bbox_inches='tight')
            plt.close(fig)

huidongchen, Dec 04 '20 20:12
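As a possible alternative to copying the whole function, the patches can be restyled after plotting. This is a sketch under stated assumptions, not something confirmed in this thread: it assumes the figure is still open (save_fig=False, since the function closes the figure after saving) and that the branches are the Polygon patches on the main axes. `outline_polygons` is a hypothetical helper, and the figure below is a plain matplotlib stand-in for a stream plot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon

def outline_polygons(ax, lw=1.5, edgecolor="black"):
    """Give every Polygon patch on `ax` a visible outline; return how many changed."""
    changed = 0
    for patch in ax.patches:
        if isinstance(patch, Polygon):
            patch.set_linewidth(lw)
            patch.set_edgecolor(edgecolor)
            changed += 1
    return changed

# Stand-in figure: two filled polygons created with lw=0, as in the categorical branch.
fig, ax = plt.subplots(figsize=(7, 4.5))
ax.add_patch(Polygon([[0, 0.4], [2, 0.2], [2, 0.5], [0, 0.6]],
                     closed=True, color="#1f77b4", alpha=0.8, lw=0))
ax.add_patch(Polygon([[0, 0.6], [2, 0.5], [2, 0.8], [0, 0.7]],
                     closed=True, color="#ff7f0e", alpha=0.8, lw=0))
ax.set_xlim(0, 2)
ax.set_ylim(0, 1)

n_changed = outline_polygons(ax, lw=1.5)
fig.canvas.draw()
```

With STREAM, the equivalent would be to call `outline_polygons(plt.gcf().axes[0])` right after plot_stream() and then save with `plt.savefig(...)`. Note that plot_stream() opens one figure per entry in `color`, so `plt.gcf()` only reaches the last one.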