ehrapy
ehrapy copied to clipboard
revamp existing plots
Description of feature
- [ ] Investigate the complexity of holoviz & discuss with @eroell and @Zethson
- [ ] Agree on a final plotting library with @eroell and @Zethson
- [ ] Investigate what the general plotting API should look like. Which object should be returned? Do we need to add an `ep.settings.plot_backend` setting? How are users supposed to show and save plots?
- [ ] Reimplement existing plots in ehrapy that do NOT come from scanpy with the plotting library of choice. Starting points are the survival analysis plots
- [ ] Ensure that all of these plots are tested (check this for an example) and have a reproducibility notebook
- [ ] Generate a couple of example ehrdata_blobs datasets and investigate how they could be visualized. Discuss the proposals on zulip
- [ ] Implement the plots one by one
See also https://github.com/theislab/ehrapy/issues/663 & https://github.com/theislab/ehrapy/issues/666 & https://github.com/theislab/ehrapy/issues/232
Below is the Kaplan-Meier fitter (KMF) plot implemented with HoloViz (HoloViews). It is painful to auto-generate, because LLMs handle the HoloViz API poorly.
import holoviews as hv
import numpy as np
import pandas as pd
# Select Bokeh as the HoloViews rendering backend. This runs as a module-level
# side effect, i.e. as soon as this file is imported/executed.
hv.extension('bokeh')
def _ci_bound_curve(timeline, bound_values, *, bound, group_label, color, alpha, in_legend):
    """Build the dashed curve for one confidence-interval bound of one group.

    ``bound`` is ``"Lower"`` or ``"Upper"``; it selects the value-dimension name
    (``CI_Lower`` / ``CI_Upper``) and the wording used in the hover and legend.
    """
    series_name = f"{group_label} CI {bound}"
    value_dim = f"CI_{bound}"
    points = [(t, v, series_name) for t, v in zip(timeline, bound_values)]
    return hv.Curve(
        points,
        kdims=['Time'],
        vdims=[value_dim, 'Group'],
        # An empty label keeps the bound out of the legend.
        label=series_name if in_legend else "",
    ).opts(
        color=color,
        line_dash='dashed',
        alpha=alpha,
        tools=['hover'],
        hover_tooltips=[
            ('Group', '@Group'),
            ('Time', '@Time{0.00}'),
            (f"CI {bound}", f"@{value_dim}{{0.000}}"),
        ],
    )


def _survival_statistics_panel(kmfs, labels, *, width, xlabel):
    """Build the text table of survival probabilities shown below the curves.

    Survival is evaluated for every group at six evenly spaced time points that
    span the combined timelines of all groups.
    """
    if not kmfs:
        # np.concatenate would raise a ValueError with an obscure message here.
        raise ValueError("display_survival_statistics requires at least one fitted estimator in `kmfs`.")

    n_groups = len(kmfs)
    all_times = np.concatenate([kmf.timeline for kmf in kmfs])
    time_points = np.linspace(all_times.min(), all_times.max(), 6)

    rows = []
    for i, (kmf, label) in enumerate(zip(kmfs, labels)):
        survival_probs = kmf.survival_function_at_times(time_points).values
        for j, (time, prob) in enumerate(zip(time_points, survival_probs)):
            rows.append({
                'Group': label,
                'Time': f"{time:.1f}",
                'Survival': f"{prob:.3f}",
                'x': j,
                # The first group is placed on the top row of the table.
                'y': n_groups - i - 1,
            })
    table_df = pd.DataFrame(rows)

    # Zero-size points only carry the hover information and the axis ticks;
    # the visible numbers come from the Labels element overlaid on top.
    cells = hv.Points(
        table_df,
        kdims=['x', 'y'],
        vdims=['Survival', 'Group', 'Time'],
    ).opts(
        size=0,
        width=width,
        height=150,
        xaxis='top',
        xticks=[(i, f"{t:.1f}") for i, t in enumerate(time_points)],
        # Row y belongs to group n_groups - 1 - y, i.e. the labels in reverse order.
        yticks=list(enumerate(reversed(labels))),
        xlabel=xlabel or 'Time',
        ylabel='',
        tools=['hover'],
        hover_tooltips=[('Group', '@Group'), ('Time', '@Time'), ('Survival', '@Survival')],
    )
    text_labels = hv.Labels(
        table_df,
        kdims=['x', 'y'],
        vdims=['Survival'],
    ).opts(
        text_font_size='10pt',
        text_align='center',
    )
    return cells * text_labels


def kaplan_meier(
    kmfs,
    *,
    display_survival_statistics: bool = False,
    ci_alpha: list[float] | None = None,
    ci_force_lines: list[bool] | None = None,
    ci_show: list[bool] | None = None,
    ci_legend: list[bool] | None = None,
    at_risk_counts: list[bool] | None = None,
    color: list[str] | None = None,
    grid: bool | None = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    width: int = 800,
    height: int = 400,
    show: bool | None = None,
    title: str | None = None,
) -> hv.Layout | hv.Overlay | None:
    """Plot Kaplan-Meier survival curves with HoloViews.

    Args:
        kmfs: A single fitted estimator or an iterable of them. Each object must
            expose ``timeline``, ``survival_function_``, ``label`` and
            ``survival_function_at_times``; ``confidence_interval_`` is optional.
            Presumably these are lifelines ``KaplanMeierFitter`` instances
            (TODO confirm).
        display_survival_statistics: If true, add a table of survival
            probabilities at six evenly spaced time points below the curves.
        ci_alpha: Per-group opacity of the confidence interval. Defaults to 0.3.
        ci_force_lines: Per-group flag; draw the confidence interval as two
            dashed lines instead of a filled band. Defaults to ``False``.
        ci_show: Per-group flag; draw the confidence interval at all.
            Defaults to ``True``.
        ci_legend: Per-group flag; give the dashed confidence-interval lines a
            legend entry. Only used together with ``ci_force_lines``.
        at_risk_counts: Accepted but not used by this implementation yet.
        color: Per-group colours. Cycled if there are fewer colours than groups;
            ``None`` or an empty list selects a built-in ten-colour palette.
        grid: Whether to show grid lines.
        xlim: Limits of the x axis.
        ylim: Limits of the y axis.
        xlabel: Label of the x axis. Defaults to ``'Time'``.
        ylabel: Label of the y axis. Defaults to ``'Survival Probability'``.
        width: Plot width in pixels.
        height: Plot height in pixels (the statistics table adds another 150).
        show: Currently has no effect; the plot object is always returned so the
            caller decides how to display or save it.
        title: Optional plot title.

    Returns:
        An ``hv.Overlay`` of all groups, or, when ``display_survival_statistics``
        is true, an ``hv.Layout`` stacking that overlay above the table.

    Raises:
        ValueError: If ``display_survival_statistics`` is true and ``kmfs`` is empty.
    """
    # Materialise the input: it is measured with len() and traversed several
    # times, which a generator or other one-shot iterable cannot support.
    if isinstance(kmfs, str) or not hasattr(kmfs, '__iter__'):
        kmfs = [kmfs]
    else:
        kmfs = list(kmfs)
    n_groups = len(kmfs)

    if ci_alpha is None:
        ci_alpha = [0.3] * n_groups
    if ci_force_lines is None:
        ci_force_lines = [False] * n_groups
    if ci_show is None:
        ci_show = [True] * n_groups
    if ci_legend is None:
        ci_legend = [False] * n_groups
    if at_risk_counts is None:
        at_risk_counts = [False] * n_groups  # not used further down yet

    if color is None or len(color) == 0:
        palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                   '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    else:
        palette = color

    # One display name per group, shared by curves, hover text and table rows.
    labels = [kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs)]

    curves = []
    for i, (kmf, label) in enumerate(zip(kmfs, labels)):
        group_color = palette[i % len(palette)]  # cycle when groups outnumber colours
        timeline = kmf.timeline
        survival = kmf.survival_function_.iloc[:, 0]

        # NOTE(review): hv.Curve interpolates linearly by default, while a
        # Kaplan-Meier estimate is a step function; consider
        # interpolation='steps-post' here (and a stepped band) - confirm.
        curve = hv.Curve(
            [(t, s, label) for t, s in zip(timeline, survival)],
            kdims=['Time'],
            vdims=['Survival', 'Group'],
            label=label,
        ).opts(
            color=group_color,
            line_width=2,
            tools=['hover'],
            hover_tooltips=[('Group', '@Group'), ('Time', '@Time{0.00}'), ('Survival', '@Survival{0.000}')],
        )

        if ci_show[i] and hasattr(kmf, 'confidence_interval_'):
            ci_lower = kmf.confidence_interval_.iloc[:, 0]
            ci_upper = kmf.confidence_interval_.iloc[:, 1]
            if ci_force_lines[i]:
                ci_lower_curve = _ci_bound_curve(
                    timeline, ci_lower, bound="Lower", group_label=label,
                    color=group_color, alpha=ci_alpha[i], in_legend=ci_legend[i],
                )
                ci_upper_curve = _ci_bound_curve(
                    timeline, ci_upper, bound="Upper", group_label=label,
                    color=group_color, alpha=ci_alpha[i], in_legend=ci_legend[i],
                )
                curve = curve * ci_lower_curve * ci_upper_curve
            else:
                # Closed outline of the band: forward along the lower bound,
                # then backward along the upper bound.
                band_outline = list(zip(timeline, ci_lower))
                band_outline.extend(zip(timeline[::-1], ci_upper[::-1]))
                if band_outline:
                    ci_area = hv.Polygons([band_outline], label=f"{label} CI").opts(
                        fill_color=group_color,
                        fill_alpha=ci_alpha[i],
                        line_alpha=0,
                        tools=['hover'],
                        hover_tooltips=[('Group', f'{label}'), ('Confidence Interval', 'Yes')],
                    )
                    # Band first so the survival curve is drawn on top of it.
                    curve = ci_area * curve
        curves.append(curve)

    opts_dict = {
        'width': width,
        'height': height,
        'xlabel': xlabel or 'Time',
        'ylabel': ylabel or 'Survival Probability',
        'show_grid': grid,
        'legend_position': 'top_right',
        'tools': ['pan', 'wheel_zoom', 'box_zoom', 'reset', 'save'],
    }
    if title:
        opts_dict['title'] = title
    if xlim:
        opts_dict['xlim'] = xlim
    if ylim:
        opts_dict['ylim'] = ylim
    plot = hv.Overlay(curves).opts(**opts_dict)

    if not display_survival_statistics:
        return plot

    table_plot = _survival_statistics_panel(kmfs, labels, width=width, xlabel=xlabel)
    return (plot + table_plot).cols(1)