gbnet
Add Survival Support
Raw code to be optimized
import numpy as np
import pandas as pd
import torch
from gboost_module import xgbmodule


def meas_df_to_long_df(
    meas_df: pd.DataFrame,
    resolution: int
):
    assert 'time' in meas_df
    assert 'id' in meas_df
    assert 'event' in meas_df  # conceptually this is somewhat optional
    assert 'max_obs_time' not in meas_df  # don't have columns that conflict with working columns added below

    meas_df = meas_df.sort_values(['id', 'time']).copy()

    # Make sure every id has a row at time 0
    non_zero_ids = [
        uid for uid in meas_df['id'].drop_duplicates()
        if not (meas_df[meas_df['id'] == uid]['time'] == 0).any()
    ]
    new_rows = []
    for uid in non_zero_ids:
        new_row = meas_df[meas_df['id'] == uid].iloc[0].copy()
        new_row['time'] = 0
        new_row['event'] = False
        new_rows.append(new_row.to_frame().T)
    meas_df = pd.concat(new_rows + [meas_df]).sort_values(['id', 'time']).copy()

    # Attach each id's last observation time
    meas_df = meas_df.merge(
        meas_df.groupby('id')['time'].max().rename('max_obs_time').reset_index(),
        on='id',
        how='inner',
        validate='many_to_one'
    )
    meas_df = meas_df.sort_values(['id', 'time'])

    # Evaluation grid: an evenly spaced grid plus every observed time
    grid = pd.concat([
        pd.Series(np.linspace(
            meas_df['time'].min(),
            meas_df['time'].max(),
            resolution
        )),
        meas_df['time'].copy()
    ]).drop_duplicates().sort_values().to_list()

    # For each grid time, keep each id's most recent measurement at or before it
    long_df = []
    for t in grid:
        sub_df = meas_df[
            (meas_df['time'] <= t) & (meas_df['max_obs_time'] >= t)
        ].sort_values('time', ascending=False).drop_duplicates('id').copy()
        sub_df['measured_time'] = t
        long_df.append(sub_df)
    long_df = pd.concat(long_df).reset_index(drop=True)
    return long_df.sort_values(['id', 'measured_time']).reset_index(drop=True).copy()
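
# Illustration (hypothetical numbers): for an id measured at times [0, 4, 6] and a
# grid point t = 5, the row kept is the latest measurement at or before t (the one
# at time 4), stamped with measured_time = 5. Each id therefore contributes one row
# per grid point up to its last observed time, with covariate values carried forward.
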
class IntegrateSurv(torch.nn.Module):
    def __init__(self, meas_df, resolution, covariates=[], params={}):
        super(IntegrateSurv, self).__init__()
        self.meas_df = meas_df
        self.resolution = resolution
        self.model_cols = covariates + ['measured_time']
        self.long_df, self.id_stats = self.get_input_info_from_meas_df(meas_df)
        self.gb = xgbmodule.XGBModule(
            self.long_df.shape[0], len(self.model_cols), 1, params=params
        )
        self.base_lambda = torch.nn.Parameter(torch.Tensor([0.0]))
        self.test_long_df = None
        self.test_id_stats = None

    def get_input_info_from_meas_df(self, meas_df):
        long_df = meas_df_to_long_df(meas_df, self.resolution).sort_values(
            ['id', 'measured_time']
        ).reset_index(drop=True).copy()
        id_stats = long_df.groupby('id').apply(
            lambda x: pd.Series({
                'n_preds': x.shape[0],
                'event': x['event'].any()
            })
        ).reset_index().copy()
        id_stats['preds_idx'] = id_stats['n_preds'].cumsum()
        return long_df, id_stats

    def forward(self, long_df=None, use_test_cache=True):
        assert long_df is None or not self.training
        if long_df is not None:
            if not use_test_cache or self.test_long_df is None or self.test_id_stats is None:
                self.test_long_df, self.test_id_stats = self.get_input_info_from_meas_df(
                    long_df
                )
        gb_input = (
            self.long_df[self.model_cols].astype(float) if self.training
            else self.test_long_df[self.model_cols].astype(float)
        )
        hazard_flat = torch.exp(
            self.gb(gb_input)
            + self.base_lambda
        )
        # I wonder if we could strategically insert zeros here so we can vectorize
        # this loop better
        hazard_by_id = {}
        prev_idx = 0
        all_loss = torch.Tensor([0.0])
        id_info = (
            self.id_stats if self.training
            else self.test_id_stats
        )
        for i, row in id_info.iterrows():
            low_idx = prev_idx
            high_idx = row['preds_idx']
            row_data = {
                'raw_hazard': hazard_flat[low_idx:high_idx, :].flatten(),
                'integrate_hazard': torch.trapezoid(
                    hazard_flat[low_idx:high_idx, :].flatten(),
                    torch.Tensor(
                        np.array(gb_input.iloc[low_idx:high_idx, :]['measured_time'])
                    ).flatten()
                )
            }
            prev_idx = row['preds_idx']
            row_data['prob_hit'] = row_data['raw_hazard'][-1]
            # Negative log-likelihood contribution: integrated hazard, minus
            # log(hazard at the last observed time) when the event occurred
            row_data['loss'] = row_data['integrate_hazard']
            if row['event']:
                row_data['loss'] = row_data['loss'] - torch.log(row_data['prob_hit'])
            all_loss = all_loss + row_data['loss']
            hazard_by_id[row['id']] = row_data
        all_loss = all_loss / id_info.shape[0]
        return hazard_by_id, all_loss

    def gb_step(self):
        self.gb.gb_step(self.long_df[self.model_cols].astype(float))
Example dataset
import pandas as pd

input_df = pd.DataFrame([
    {'id': 1, 'time': 0, 'event': False, 'covariate_1': 1},
    {'id': 1, 'time': 4, 'event': False, 'covariate_1': 3},
    {'id': 1, 'time': 6, 'event': True, 'covariate_1': None},
    {'id': 2, 'time': 0, 'event': False, 'covariate_1': 1},
    {'id': 2, 'time': 3, 'event': False, 'covariate_1': None},
    {'id': 2, 'time': 5, 'event': False, 'covariate_1': 2},
])
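A minimal training sketch against this example frame (hypothetical settings; assumes the IntegrateSurv class and its imports above, and skips any optimizer step for base_lambda):

model = IntegrateSurv(input_df, resolution=20, covariates=['covariate_1'])
for _ in range(100):
    model.zero_grad()
    hazard_by_id, loss = model()
    # create_graph=True so the boosting step can use second-order information
    loss.backward(create_graph=True)
    model.gb_step()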
gboost_module vs. Kaplan–Meier when using no covariates on fake data:
UPDATE -- going to take a more general integration route so we can support Poisson processes as well: the same per-row hazard and per-unit integrated hazard are exactly the pieces needed for an inhomogeneous Poisson log-likelihood (sum of log-intensities at the event times minus the integrated intensity).
from __future__ import annotations

import warnings
from typing import List, Dict, Any

import pandas as pd
import torch
from torch import nn

from gbnet.xgbmodule import XGBModule


class HazardIntegrator(nn.Module):
    def __init__(self, covariate_cols: List[str] = [], params={}, min_hess=0):
        """
        Parameters
        ----------
        covariate_cols
            Columns to feed into the model. "time" is always included.
        params
            Passed through to XGBModule.
        min_hess
            Passed through to XGBModule.
        """
        super().__init__()
        self.params = params.copy()
        self.min_hess = min_hess
        self.covariate_cols = covariate_cols
        self.gb_module = None

    def forward(self, df: pd.DataFrame) -> Dict[str, Any]:
        # TODO - freeze the training set after it gets read once
        # to avoid significant overhead in generating it for each
        # round of training
        if {"unit_id", "time"} - set(df.columns):
            raise ValueError("DataFrame must contain 'unit_id' and 'time'.")

        # 1. Sort (unit_id, time) – stable so original relative order inside ties is kept
        df_sorted = df.sort_values(["unit_id", "time"], kind="mergesort").reset_index()
        orig_idx = torch.as_tensor(df_sorted.pop("index").values, dtype=torch.long)

        # 2. Build tensors -------------------------------------------------
        unit_codes, uniques = pd.factorize(df_sorted["unit_id"])
        unit_ids = torch.as_tensor(unit_codes, dtype=torch.long)
        times = torch.as_tensor(df_sorted["time"].values, dtype=torch.float32)

        # Select covariate columns
        covar_cols = ['time'] + self.covariate_cols
        X = df_sorted[covar_cols].values
        if self.gb_module is None:
            self.gb_module = XGBModule(
                X.shape[0], len(covar_cols), 1, params=self.params, min_hess=self.min_hess
            )

        # 3. Model inference ----------------------------------------------
        log_hazard = self.gb_module(X).flatten()  # [N]
        hazard = torch.exp(log_hazard)  # λ(t)

        # 4. Trapezoidal slice per row ------------------------------------
        dt = torch.diff(times, prepend=times.new_zeros(1))
        same_unit = torch.diff(unit_ids, prepend=unit_ids.new_tensor([-1])).eq(0)
        dt = dt * same_unit.float()
        trapz_slice = 0.5 * (hazard + torch.roll(hazard, 1)) * dt
        trapz_slice[~same_unit] = 0.0
        unit_Lambda = torch.zeros(
            [df['unit_id'].drop_duplicates().shape[0]]
        ).scatter_reduce(0, unit_ids, trapz_slice, reduce='sum')

        # this mask can be frozen (i.e. not regenerated each training round)
        mask = torch.cat((
            unit_ids[1:] != unit_ids[:-1],
            torch.tensor([True], device=unit_ids.device)
        ))
        last_hazard = hazard[mask]

        # Undo the sort so per-row hazards line up with the input frame
        unsort = torch.argsort(orig_idx)
        hazard_unsort = hazard[unsort]

        return {
            "hazard": hazard_unsort,
            "unit_last_hazard": last_hazard,
            "unit_integrated_hazard": unit_Lambda,
        }

    def gb_step(self):
        self.gb_module.gb_step()
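The per-unit integration in forward() avoids a Python loop over units: time deltas that cross a unit boundary are zeroed out, each row's trapezoid slice is computed against its predecessor, and scatter_reduce sums the slices per unit. A standalone toy check of that trick (made-up tensors, not part of the module):

import torch

times = torch.tensor([0., 1., 2., 0., 3.])   # two units: rows 0-2 and rows 3-4
unit_ids = torch.tensor([0, 0, 0, 1, 1])
hazard = torch.tensor([1., 1., 1., 2., 2.])

dt = torch.diff(times, prepend=times.new_zeros(1))
same_unit = torch.diff(unit_ids, prepend=unit_ids.new_tensor([-1])).eq(0)
dt = dt * same_unit.float()

trapz = 0.5 * (hazard + torch.roll(hazard, 1)) * dt
trapz[~same_unit] = 0.0

Lambda = torch.zeros(2).scatter_reduce(0, unit_ids, trapz, reduce='sum')
print(Lambda)  # tensor([2., 6.]): unit 0 integrates hazard 1 over [0, 2], unit 1 integrates hazard 2 over [0, 3]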
HELPER:
import pandas as pd
import numpy as np


def expand_overlapping_units(
    df: pd.DataFrame,
    unit_col: str = "unit_id",
    time_col: str = "time",
    fill_value=np.nan,  # use None if you prefer plain Python Nones
):
    """
    Ensure that every unit appearing in *df* contains **all** time stamps that lie
    (a) between that unit's own min & max timestamps *and*
    (b) also occur in *any* other unit.

    Missing rows are added and every non-key column is filled with *fill_value*.

    Parameters
    ----------
    df : pd.DataFrame
        Original long-format data.
    unit_col : str
        Column name that identifies the unit / subject.
    time_col : str
        Column name containing the time stamps.
    fill_value : scalar
        Value used to fill non-key columns for rows we synthesize
        (defaults to NaN).

    Returns
    -------
    pd.DataFrame
        Expanded, chronologically sorted long-format data.
    """
    # Unique times observed anywhere in the data, sorted for stability
    all_times = np.sort(df[time_col].unique())

    # Min & max time for each unit, so we know that unit's active span
    t_min = df.groupby(unit_col)[time_col].min()
    t_max = df.groupby(unit_col)[time_col].max()

    # Build the full "skeleton" of unit-time combinations we need
    pieces = []
    for unit in t_min.index:
        mask = (all_times >= t_min[unit]) & (all_times <= t_max[unit])
        pieces.append(
            pd.DataFrame(
                {unit_col: unit, time_col: all_times[mask]}
            )
        )
    skeleton = pd.concat(pieces, ignore_index=True)

    # Left-join the user's data onto the skeleton; missing values -> fill_value
    out = (
        skeleton
        .merge(df, on=[unit_col, time_col], how="left")
        .sort_values([unit_col, time_col], kind="mergesort")
        .reset_index(drop=True)
        .fillna(fill_value)
    )
    return out
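A tiny illustration of what the helper does (hypothetical frame; relies on expand_overlapping_units defined above):

import pandas as pd

toy = pd.DataFrame({
    "unit_id": [1, 1, 2, 2],
    "time":    [0, 5, 0, 3],
    "x":       [0.1, 0.2, 0.3, 0.4],
})
print(expand_overlapping_units(toy))
# Unit 1 gains a synthesized row at t=3 (a time seen in unit 2) with x = NaN;
# unit 2 gains nothing, since t=5 falls outside its own [0, 3] span.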
KM curve example:
import numpy as np
import pandas as pd

np.random.seed(42)
n = 2000

# Simulate "sine-modulated" event times via inverse-transform sampling
u = np.random.uniform(0, 1, size=n)

# Oscillating hazard: higher chance of event in some regions
hazard_wave = lambda t: 0.02 + 0.015 * np.sin(t * 3)  # time-varying hazard

# Dense time grid; treat the hazard values as an unnormalized density and build its CDF
time_grid = np.linspace(0.1, 20, 10000)
hazard_vals = hazard_wave(time_grid)
cdf = np.cumsum(hazard_vals)
cdf /= cdf[-1]  # normalize

# Sample event times via inverse CDF
event_times = np.interp(u, cdf, time_grid)

# Censoring times (always after the event times here, so no observation is actually censored)
censoring_times = event_times + np.random.exponential(scale=5, size=n)
observed_times = np.minimum(event_times, censoring_times)
events = (event_times <= censoring_times).astype(int)

df_wave = pd.DataFrame({
    'time': observed_times,
    'event': events
})

from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt

kmf = KaplanMeierFitter()
kmf.fit(df_wave["time"], event_observed=df_wave["event"])
kmf.plot_survival_function()
plt.title("Wiggly Kaplan–Meier Curve (Quasi-Continuous)")
plt.xlabel("Time")
plt.ylabel("Survival Probability")
plt.grid(True)
plt.show()
dataset conversions
def to_integration_df(df):
    """Return (long-format integration frame, per-unit metadata frame)."""
    df['unit_id'] = range(len(df))  # note: mutates the input frame
    unit_metadata = df[['unit_id', 'event']]
    long_df = pd.concat([
        df,
        pd.DataFrame([{'unit_id': i, 'time': 0} for i in range(df.shape[0])])
    ])[['unit_id', 'time']].merge(
        unit_metadata,
        on='unit_id', how='inner',
        validate='many_to_one'
    ).sort_values(['unit_id', 'time']).reset_index(drop=True)
    return long_df, df
edf, mdf = to_integration_df(df_wave)
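A quick sanity check on the conversion (assumes the cell above has run): edf should contain two rows per unit, a synthesized t=0 anchor plus the observed time, with the unit's event flag merged on, while mdf is the original frame with a unit_id column added.

print(edf.head(4))
print(mdf.head(2))
# Holds here because no observed time is exactly 0 in the simulated data
assert edf.shape[0] == 2 * mdf.shape[0]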
TRAINING CODE
integrator = HazardIntegrator()
exp_df = expand_overlapping_units(edf)

losses = []
for i in range(500):
    integrator.zero_grad()
    out = integrator(exp_df)
    # Negative log-likelihood, averaged over units:
    # sum_i [ Lambda_i(t_i) - 1{event_i} * log(lambda_i(t_i)) ]
    loss = (
        out['unit_integrated_hazard'].sum()
        - (
            torch.log(out['unit_last_hazard'])
            * torch.tensor((mdf['event'] == 1).to_numpy(), dtype=torch.float32)
        ).sum()
    ) / mdf.shape[0]
    losses.append(loss.item())
    # create_graph=True so the boosting step can use second-order information
    loss.backward(create_graph=True)
    integrator.gb_step()
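One hedged way to reproduce the gboost_module vs. Kaplan–Meier comparison mentioned earlier: with no covariates the fitted hazard depends on time only, so the model-implied survival curve can be rebuilt from the last round's out and overlaid on the kmf fit from the KM example (all names reused from the cells above):

import numpy as np
import matplotlib.pyplot as plt

# One hazard value per unique time (identical across units when there are no covariates)
curve = (
    exp_df.assign(hazard=out['hazard'].detach().numpy())[['time', 'hazard']]
    .groupby('time')['hazard']
    .mean()
    .sort_index()
)
t = curve.index.to_numpy()
lam = curve.to_numpy()

# S(t) = exp(-Lambda(t)), with Lambda accumulated by the trapezoid rule
Lambda = np.concatenate([[0.0], np.cumsum(0.5 * (lam[1:] + lam[:-1]) * np.diff(t))])

ax = kmf.plot_survival_function()
ax.plot(t, np.exp(-Lambda), label="HazardIntegrator")
ax.legend()
plt.show()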