ehrapy icon indicating copy to clipboard operation
ehrapy copied to clipboard

revamp existing plots

Open Zethson opened this issue 7 months ago • 1 comment

Description of feature

  • [ ] Investigate the complexity of holoviz & discuss with @eroell and @Zethson
  • [ ] Agree on a final plotting library with @eroell and @Zethson
  • [ ] Investigate what the general plotting API should look like. Which object should be returned? Do we need to add an ep.settings.plot_backend setting? How are users supposed to show and save plots?
  • [ ] Reimplement existing plots in ehrapy that do NOT come from scanpy with the plotting library of choice. Starting points are the survival analysis plots
  • [ ] Ensure that all of these plots are tested (check this for an example) and have a reproducibility notebook
  • [ ] Generate a couple of example ehrdata_blobs datasets and investigate how they could be visualized. Discuss the proposals on zulip
  • [ ] Implement the plots one by one

See also https://github.com/theislab/ehrapy/issues/663 & https://github.com/theislab/ehrapy/issues/666 & https://github.com/theislab/ehrapy/issues/232

Zethson avatar May 20 '25 19:05 Zethson

Below is the KMF (Kaplan-Meier fitter) plot implementation with holoviz. It's a pain to autogenerate, as LLMs are really stupid with holoviz.

import holoviews as hv
import numpy as np
import pandas as pd

# Load the Bokeh plotting extension for HoloViews at import time.
hv.extension('bokeh')

def kaplan_meier(
    kmfs,
    *,
    display_survival_statistics: bool = False,
    ci_alpha: list[float] | None = None,
    ci_force_lines: list[bool] | None = None,
    ci_show: list[bool] | None = None,
    ci_legend: list[bool] | None = None,
    at_risk_counts: list[bool] | None = None,
    color: list[str] | None = None,
    grid: bool | None = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    width: int = 800,
    height: int = 400,
    show: bool | None = None,
    title: str | None = None,
) -> hv.Layout | hv.Overlay | None:
    """Plot one or more Kaplan-Meier survival curves with HoloViews.

    Args:
        kmfs: A single fitter or an iterable of fitters. Each fitter must
            expose ``timeline``, ``survival_function_`` and ``label``;
            ``confidence_interval_`` is used when present, and
            ``survival_function_at_times`` is required when
            ``display_survival_statistics`` is true.
        display_survival_statistics: If true, append a table-like panel with
            survival probabilities at six evenly spaced time points.
        ci_alpha: Per-fitter opacity of the confidence interval
            (default ``0.3`` each).
        ci_force_lines: Per-fitter flag; draw the confidence interval as two
            dashed curves instead of a filled band (default ``False`` each).
        ci_show: Per-fitter flag; draw the confidence interval at all
            (default ``True`` each).
        ci_legend: Per-fitter flag; give the confidence interval element(s)
            a label (default ``False`` each).
        at_risk_counts: Accepted for API compatibility; not read by this
            implementation.
        color: Per-fitter line/fill colors. Defaults to a repeating
            10-color palette.
        grid: Passed as the ``show_grid`` plot option.
        xlim: Optional x-axis limits.
        ylim: Optional y-axis limits.
        xlabel: X-axis label (default ``'Time'``).
        ylabel: Y-axis label (default ``'Survival Probability'``).
        width: Plot width passed to the plot options.
        height: Height of the survival plot passed to the plot options.
        show: Accepted for API compatibility; has no effect, the plot
            object is returned either way.
        title: Optional plot title.

    Returns:
        An ``hv.Overlay`` of all curves, or an ``hv.Layout`` stacking that
        overlay above the statistics panel in a single column when
        ``display_survival_statistics`` is true.
    """
    if not hasattr(kmfs, '__iter__') or isinstance(kmfs, str):
        kmfs = [kmfs]
    else:
        # Materialise so generators / one-shot iterables support len() and
        # the repeated iteration performed below.
        kmfs = list(kmfs)

    n_groups = len(kmfs)

    if ci_alpha is None:
        ci_alpha = [0.3] * n_groups
    if ci_force_lines is None:
        ci_force_lines = [False] * n_groups
    if ci_show is None:
        ci_show = [True] * n_groups
    if ci_legend is None:
        ci_legend = [False] * n_groups
    if color is None:
        color = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                 '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] * ((n_groups // 10) + 1)

    # Resolve each display label once; reused by curves, table rows and ticks.
    labels = [_km_group_label(kmf, i) for i, kmf in enumerate(kmfs)]

    curves = []

    for i, kmf in enumerate(kmfs):
        timeline = kmf.timeline
        survival = kmf.survival_function_.iloc[:, 0]
        label = labels[i]

        curve_data = [(t, s, label) for t, s in zip(timeline, survival)]

        curve = hv.Curve(
            curve_data,
            kdims=['Time'],
            vdims=['Survival', 'Group'],
            label=label,
        ).opts(
            color=color[i],
            line_width=2,
            tools=['hover'],
            hover_tooltips=[('Group', '@Group'), ('Time', '@Time{0.00}'), ('Survival', '@Survival{0.000}')],
        )

        if ci_show[i] and hasattr(kmf, 'confidence_interval_'):
            ci_lower = kmf.confidence_interval_.iloc[:, 0]
            ci_upper = kmf.confidence_interval_.iloc[:, 1]

            if ci_force_lines[i]:
                ci_lower_data = [(t, lo, f"{label} CI Lower") for t, lo in zip(timeline, ci_lower)]
                ci_upper_data = [(t, up, f"{label} CI Upper") for t, up in zip(timeline, ci_upper)]

                ci_lower_curve = hv.Curve(
                    ci_lower_data,
                    kdims=['Time'],
                    vdims=['CI_Lower', 'Group'],
                    label=f"{label} CI Lower" if ci_legend[i] else "",
                ).opts(
                    color=color[i],
                    line_dash='dashed',
                    alpha=ci_alpha[i],
                    tools=['hover'],
                    hover_tooltips=[('Group', '@Group'), ('Time', '@Time{0.00}'), ('CI Lower', '@CI_Lower{0.000}')],
                )

                ci_upper_curve = hv.Curve(
                    ci_upper_data,
                    kdims=['Time'],
                    vdims=['CI_Upper', 'Group'],
                    label=f"{label} CI Upper" if ci_legend[i] else "",
                ).opts(
                    color=color[i],
                    line_dash='dashed',
                    alpha=ci_alpha[i],
                    tools=['hover'],
                    hover_tooltips=[('Group', '@Group'), ('Time', '@Time{0.00}'), ('CI Upper', '@CI_Upper{0.000}')],
                )

                curve = curve * ci_lower_curve * ci_upper_curve
            else:
                # Closed outline: lower bound left-to-right, then upper bound
                # right-to-left.
                area_data = list(zip(timeline, ci_lower))
                area_data.extend(zip(timeline[::-1], ci_upper[::-1]))

                if area_data:
                    ci_area = hv.Polygons(
                        [area_data],
                        # Label the band only on request, matching the
                        # dashed-line mode above.
                        label=f"{label} CI" if ci_legend[i] else "",
                    ).opts(
                        fill_color=color[i],
                        fill_alpha=ci_alpha[i],
                        line_alpha=0,
                        tools=['hover'],
                        hover_tooltips=[('Group', f'{label}'), ('Confidence Interval', 'Yes')],
                    )
                    curve = ci_area * curve

        curves.append(curve)

    opts_dict = {
        'width': width,
        'height': height,
        'xlabel': xlabel or 'Time',
        'ylabel': ylabel or 'Survival Probability',
        'show_grid': grid,
        'legend_position': 'top_right',
        'tools': ['pan', 'wheel_zoom', 'box_zoom', 'reset', 'save'],
    }

    if title:
        opts_dict['title'] = title
    if xlim:
        opts_dict['xlim'] = xlim
    if ylim:
        opts_dict['ylim'] = ylim

    plot = hv.Overlay(curves).opts(**opts_dict)

    if not display_survival_statistics:
        return plot

    # Six evenly spaced time points spanning every fitter's timeline.
    all_times = np.concatenate([kmf.timeline for kmf in kmfs])
    time_points = np.linspace(all_times.min(), all_times.max(), 6)

    table_data = []
    for i, kmf in enumerate(kmfs):
        survival_probs = kmf.survival_function_at_times(time_points).values

        for j, (t, prob) in enumerate(zip(time_points, survival_probs)):
            table_data.append({
                'Group': labels[i],
                'Time': f"{t:.1f}",
                'Survival': f"{prob:.3f}",
                'x': j,
                # First fitter is placed on the top row.
                'y': n_groups - i - 1,
            })

    table_df = pd.DataFrame(table_data)

    # Zero-size points provide the axes and hover data; the visible cell
    # text comes from the Labels element overlaid on top.
    table_plot = hv.Points(
        table_df,
        kdims=['x', 'y'],
        vdims=['Survival', 'Group', 'Time'],
    ).opts(
        size=0,
        width=width,
        height=150,
        xaxis='top',
        xticks=[(pos, f"{t:.1f}") for pos, t in enumerate(time_points)],
        yticks=[(pos, labels[n_groups - 1 - pos]) for pos in range(n_groups)],
        xlabel=xlabel or 'Time',
        ylabel='',
        tools=['hover'],
        hover_tooltips=[('Group', '@Group'), ('Time', '@Time'), ('Survival', '@Survival')],
    )

    text_labels = hv.Labels(
        table_df,
        kdims=['x', 'y'],
        vdims=['Survival'],
    ).opts(
        text_font_size='10pt',
        text_align='center',
    )

    table_plot = table_plot * text_labels

    return (plot + table_plot).cols(1)


def _km_group_label(kmf, index: int) -> str:
    """Return the fitter's ``label``, or ``"Group <index + 1>"`` if it is falsy."""
    return kmf.label if kmf.label else f"Group {index + 1}"

Zethson avatar May 22 '25 19:05 Zethson