gbnet

Add Survival Support

Open · mthorrell opened this issue 1 year ago · 1 comment

Raw code (to be optimized):

import numpy as np
import pandas as pd
import torch

from gboost_module import xgbmodule

def meas_df_to_long_df(
    meas_df: pd.DataFrame,
    resolution: int
):
    """Expand per-unit measurements onto a shared time grid (long format).

    Each unit that has no row at time 0 first receives one. It is a copy of
    that unit's earliest measurement with ``time=0`` and ``event=False``.
    The grid is ``resolution`` evenly spaced points between the global
    minimum and maximum time, merged with every observed time. For every
    grid point ``t``, each unit still under observation
    (``max_obs_time >= t``) contributes its most recent measurement with
    ``time <= t``, tagged with ``measured_time = t``.

    Parameters
    ----------
    meas_df : pd.DataFrame
        Measurements with at least 'id', 'time' and 'event' columns. It must
        not already contain a 'max_obs_time' column. The input is not
        modified.
    resolution : int
        Number of evenly spaced grid points passed to ``np.linspace``.

    Returns
    -------
    pd.DataFrame
        The original columns plus 'max_obs_time' and 'measured_time', sorted
        by ('id', 'measured_time') with a fresh RangeIndex. Column dtypes of
        the input are preserved.

    Raises
    ------
    AssertionError
        If a required column is missing or 'max_obs_time' is already present.
    """
    assert 'time' in meas_df
    assert 'id' in meas_df
    assert 'event' in meas_df  # conceptually this is somewhat optional
    assert 'max_obs_time' not in meas_df  # don't have columns that would conflict with working columns

    meas_df = meas_df.sort_values(['id', 'time'])

    # Give every unit lacking a time-0 row a synthetic one. Building it
    # column-wise from the sorted frame (first row per id == earliest
    # measurement) keeps the original dtypes. Per-row Series.to_frame().T
    # would turn every column into object dtype. This is also a single pass
    # instead of one full-frame scan per id.
    zero_time_ids = meas_df.loc[meas_df['time'] == 0, 'id'].unique()
    origin_rows = (
        meas_df[~meas_df['id'].isin(zero_time_ids)]
        .drop_duplicates('id', keep='first')
        .copy()
    )
    if not origin_rows.empty:
        origin_rows['time'] = 0
        origin_rows['event'] = False
        meas_df = pd.concat([origin_rows, meas_df]).sort_values(['id', 'time'])

    # Attach each unit's last observed time so the grid can stop per unit.
    max_obs = meas_df.groupby('id')['time'].max().rename('max_obs_time').reset_index()
    meas_df = meas_df.merge(
        max_obs,
        on='id',
        how='inner',
        validate='many_to_one'
    ).sort_values(['id', 'time'])

    grid = pd.concat([
        pd.Series(np.linspace(
            meas_df['time'].min(),
            meas_df['time'].max(),
            resolution
        )),
        meas_df['time'],
    ]).drop_duplicates().sort_values().to_list()

    snapshots = []
    for t in grid:
        # Latest measurement at or before t for every unit still observed at t.
        snapshot = (
            meas_df[(meas_df['time'] <= t) & (meas_df['max_obs_time'] >= t)]
            .sort_values('time', ascending=False)
            .drop_duplicates('id')
            .copy()
        )
        snapshot['measured_time'] = t
        snapshots.append(snapshot)

    long_df = pd.concat(snapshots)
    return long_df.sort_values(['id', 'measured_time']).reset_index(drop=True)


class IntegrateSurv(torch.nn.Module):
    """Survival model with a gradient-boosted log-hazard on a long-format grid.

    The hazard at every long-format row is ``exp(gb(x) + base_lambda)``,
    where ``x`` holds the covariates plus 'measured_time'. ``forward``
    integrates each unit's hazard over 'measured_time' with the trapezoid
    rule. The per-unit loss is ``integral - event * log(last hazard)``, and
    the loss is averaged over units.
    """

    def __init__(self, meas_df, resolution, covariates=None, params=None):
        """
        Parameters
        ----------
        meas_df : pd.DataFrame
            Training measurements in the format accepted by
            ``meas_df_to_long_df`` ('id', 'time', 'event', covariates).
        resolution : int
            Number of evenly spaced grid points passed to ``meas_df_to_long_df``.
        covariates : list of str, optional
            Covariate columns used as model inputs. 'measured_time' is always
            appended. Defaults to no covariates.
        params : dict, optional
            Forwarded to ``xgbmodule.XGBModule``. Defaults to ``{}``.
        """
        super().__init__()
        # None defaults instead of [] / {} so instances never share one object.
        covariates = [] if covariates is None else list(covariates)
        params = {} if params is None else params

        self.meas_df = meas_df
        self.resolution = resolution
        self.model_cols = covariates + ['measured_time']

        self.long_df, self.id_stats = self.get_input_info_from_meas_df(meas_df)

        # One boosted output per long-format training row.
        self.gb = xgbmodule.XGBModule(
            self.long_df.shape[0], len(self.model_cols), 1, params=params
        )
        self.base_lambda = torch.nn.Parameter(torch.Tensor([0.0]))

        # Evaluation data prepared by forward() in eval mode (cached).
        self.test_long_df = None
        self.test_id_stats = None

    def get_input_info_from_meas_df(self, meas_df):
        """Build the long-format frame and per-unit slice bookkeeping.

        Returns
        -------
        long_df : pd.DataFrame
            Output of ``meas_df_to_long_df``, sorted by ('id', 'measured_time').
        id_stats : pd.DataFrame
            One row per id (ascending) with 'n_preds' (rows in ``long_df``),
            'event' (any event for the id) and 'preds_idx' (exclusive end
            offset of the id's rows in ``long_df``).
        """
        long_df = meas_df_to_long_df(meas_df, self.resolution).sort_values(
            ['id', 'measured_time']
        ).reset_index(drop=True)

        # groupby sorts ids ascending, matching long_df's row order, so the
        # cumulative row counts are valid slice end offsets.
        id_stats = (
            long_df.groupby('id')
            .agg(n_preds=('measured_time', 'size'), event=('event', 'any'))
            .reset_index()
        )
        id_stats['preds_idx'] = id_stats['n_preds'].cumsum()

        return long_df, id_stats

    def forward(self, long_df=None, use_test_cache=True):
        """Compute per-unit hazards and the mean survival loss.

        Parameters
        ----------
        long_df : pd.DataFrame, optional
            Eval-mode only. Despite the name, this is a *measurement* frame
            (it is passed through ``get_input_info_from_meas_df``). When given,
            it replaces the cached evaluation data unless a cache exists and
            ``use_test_cache`` is true.
        use_test_cache : bool
            Reuse previously prepared evaluation data if available.

        Returns
        -------
        hazard_by_id : dict
            Maps id to a dict with 'raw_hazard', 'integrate_hazard',
            'prob_hit' (hazard at the unit's last grid point) and 'loss'.
        all_loss : torch.Tensor
            Loss averaged over units.

        Raises
        ------
        RuntimeError
            In eval mode when no evaluation data has been prepared yet.
        """
        assert long_df is None or not self.training
        if long_df is not None and (
            not use_test_cache or self.test_long_df is None or self.test_id_stats is None
        ):
            self.test_long_df, self.test_id_stats = self.get_input_info_from_meas_df(
                long_df
            )

        if self.training:
            frame, id_info = self.long_df, self.id_stats
        else:
            if self.test_long_df is None or self.test_id_stats is None:
                raise RuntimeError(
                    "No evaluation data prepared: call forward(meas_df) in eval mode first."
                )
            frame, id_info = self.test_long_df, self.test_id_stats

        gb_input = frame[self.model_cols].astype(float)
        hazard_flat = torch.exp(
            self.gb(gb_input)
            + self.base_lambda
        )
        # Convert the integration variable once instead of once per unit.
        times = torch.as_tensor(
            gb_input['measured_time'].to_numpy(), dtype=torch.float32
        )

        # TODO: the per-unit loop could be vectorized (e.g. segment sums over
        # consecutive rows, or strategically inserted zeros).
        hazard_by_id = {}
        all_loss = torch.Tensor([0.0])
        low_idx = 0
        for uid, high_idx, had_event in zip(
            id_info['id'], id_info['preds_idx'], id_info['event']
        ):
            raw_hazard = hazard_flat[low_idx:high_idx, :].flatten()
            row_data = {
                'raw_hazard': raw_hazard,
                'integrate_hazard': torch.trapezoid(
                    raw_hazard, times[low_idx:high_idx]
                ),
            }
            low_idx = high_idx

            row_data['prob_hit'] = raw_hazard[-1]

            # Integral of the hazard, minus log hazard at the end for event units.
            row_data['loss'] = row_data['integrate_hazard']
            if had_event:
                row_data['loss'] = row_data['loss'] - torch.log(row_data['prob_hit'])
            all_loss = all_loss + row_data['loss']
            hazard_by_id[uid] = row_data

        all_loss = all_loss / id_info.shape[0]
        return hazard_by_id, all_loss

    def gb_step(self):
        """Advance the boosted model one step on the training rows."""
        self.gb.gb_step(self.long_df[self.model_cols].astype(float))

Example dataset

import pandas as pd
# Toy measurement table: two units with irregular observation times. Missing
# covariate readings are given as None.
input_df = pd.DataFrame({
    'id': [1, 1, 1, 2, 2, 2],
    'time': [0, 4, 6, 0, 3, 5],
    'event': [False, False, True, False, False, False],
    'covariate_1': [1, 3, None, 1, None, 2],
})

Comparison of gboost_module vs. Kaplan–Meier with no covariates on fake data: [image]

— mthorrell, Jul 01 '24 03:07

UPDATE: switching to a general hazard-integration approach so that Poisson processes can be supported as well.

from __future__ import annotations

import warnings
from typing import List, Dict, Any

import pandas as pd
import torch
from torch import nn
from gbnet.xgbmodule import XGBModule


class HazardIntegrator(nn.Module):
    """Boosted hazard model that integrates λ(t) per unit with the trapezoid rule.

    ``forward`` takes a long-format DataFrame with one row per
    (unit_id, time) point. An ``XGBModule`` predicts the log-hazard at each
    row. The module is built lazily on the first call and sized to that
    DataFrame. The log-hazard is exponentiated and integrated over each
    unit's consecutive time points.
    """

    def __init__(
        self,
        covariate_cols: List[str] | None = None,
        params: Dict[str, Any] | None = None,
        min_hess=0,
    ):
        """
        Parameters
        ----------
        covariate_cols
            Columns to feed into the model in addition to "time", which is
            always included (as the first feature). Defaults to none.
        params
            Parameters forwarded to ``XGBModule``. A copy is stored.
        min_hess
            Forwarded to ``XGBModule``.
        """
        super().__init__()
        # None defaults (instead of [] / {}) so instances never share state.
        self.params = dict(params) if params is not None else {}
        self.min_hess = min_hess
        self.covariate_cols = list(covariate_cols) if covariate_cols is not None else []
        self.gb_module = None

    @staticmethod
    def _sort_rows(df: pd.DataFrame):
        """Sort *df* by (unit_id, time). Return ``(df_sorted, orig_pos)``.

        ``orig_pos[i]`` is the 0-based *position* in *df* of sorted row ``i``.
        Positions, not index labels, make the inverse permutation correct for
        any index (non-monotonic, non-integer or named).
        """
        positional = df.reset_index(drop=True)
        # Ties keep their original relative order.
        df_sorted = positional.sort_values(["unit_id", "time"], kind="mergesort")
        orig_pos = torch.as_tensor(df_sorted.index.to_numpy(), dtype=torch.long)
        return df_sorted.reset_index(drop=True), orig_pos

    def forward(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Predict hazards and integrate them per unit.

        Returns
        -------
        dict with
            "hazard": λ at every row, in the row order of *df*.
            "unit_last_hazard": λ at each unit's last (latest-time) row.
            "unit_integrated_hazard": trapezoid integral of λ over each unit.
        Both per-unit tensors are ordered by ascending unit_id.

        Raises
        ------
        ValueError
            If 'unit_id' or 'time' is missing, or 'unit_id' has missing values.
        """
        # TODO - freeze the training set after it gets read once
        # to avoid significant overhead in generating it for each
        # round of training
        if {"unit_id", "time"} - set(df.columns):
            raise ValueError("DataFrame must contain 'unit_id' and 'time'.")

        # 1. Sort (unit_id, time) and remember each row's original position.
        df_sorted, orig_pos = self._sort_rows(df)

        # 2. Build tensors -------------------------------------------------
        # On sorted data, factorize codes follow ascending unit_id order.
        unit_codes, uniques = pd.factorize(df_sorted["unit_id"])
        if (unit_codes < 0).any():
            raise ValueError("'unit_id' must not contain missing values.")
        n_units = len(uniques)
        unit_ids = torch.as_tensor(unit_codes, dtype=torch.long)
        times = torch.as_tensor(df_sorted["time"].values, dtype=torch.float32)

        # Select covariate columns
        covar_cols = ["time"] + self.covariate_cols
        X = df_sorted[covar_cols].values

        if self.gb_module is None:
            self.gb_module = XGBModule(
                X.shape[0], len(covar_cols), 1, params=self.params, min_hess=self.min_hess
            )

        # 3. Model inference ----------------------------------------------
        log_hazard = self.gb_module(X).flatten()  # [N]
        hazard = torch.exp(log_hazard)            # λ(t)

        # 4. Trapezoidal slice per row ------------------------------------
        # Row i integrates from row i-1 to row i. The first row of each unit
        # has no predecessor in that unit, so its slice is forced to zero.
        same_unit = torch.diff(unit_ids, prepend=unit_ids.new_tensor([-1])).eq(0)
        dt = torch.diff(times, prepend=times.new_zeros(1)) * same_unit.float()
        trapz_slice = (0.5 * (hazard + torch.roll(hazard, 1)) * dt).masked_fill(
            ~same_unit, 0.0
        )

        # Accumulator matches the slice dtype/device so scatter_reduce cannot
        # fail on a dtype mismatch.
        unit_Lambda = trapz_slice.new_zeros(n_units).scatter_reduce(
            0, unit_ids, trapz_slice, reduce="sum"
        )

        # True on each unit's last row (this mask could be frozen across rounds).
        last_row = torch.cat((
            unit_ids[1:] != unit_ids[:-1],
            torch.tensor([True], device=unit_ids.device),
        ))
        last_hazard = hazard[last_row]

        # Inverse of the sort permutation -> rows back in df's original order.
        unsort = torch.argsort(orig_pos)
        hazard_unsort = hazard[unsort]

        return {
            "hazard": hazard_unsort,
            "unit_last_hazard": last_hazard,
            "unit_integrated_hazard": unit_Lambda,
        }

    def gb_step(self):
        """Advance the boosted model one step. ``forward`` must run first."""
        if self.gb_module is None:
            raise RuntimeError("gb_step() called before forward(); the model is not built yet.")
        self.gb_module.gb_step()

HELPER:

import pandas as pd
import numpy as np

def expand_overlapping_units(
    df: pd.DataFrame,
    unit_col: str = "unit_id",
    time_col: str = "time",
    fill_value=np.nan,          # use None if you prefer plain Python Nones
):
    """
    Ensure that every unit appearing in *df* contains **all** time stamps that
    (a) lie between that unit’s own min and max timestamps *and*
    (b) occur in *any* unit.

    Missing rows are added, and only those synthesized rows have their
    non-key columns set to *fill_value*. Values already present in *df*,
    including its own missing values, are left untouched. Rows whose
    *time_col* is missing do not appear in the output.

    Parameters
    ----------
    df : pd.DataFrame
        Original long‑format data.
    unit_col : str
        Column name that identifies the unit / subject.
    time_col : str
        Column name containing the time stamps.
    fill_value : scalar
        Value placed in non‑key columns of synthesized rows (defaults to
        NaN; ``None`` is accepted).

    Returns
    -------
    pd.DataFrame
        Expanded long‑format data sorted by (unit, time) with a fresh
        RangeIndex.
    """
    if df.empty:
        # Nothing to expand (and pd.concat([]) below would raise).
        return df.sort_values([unit_col, time_col], kind="mergesort").reset_index(drop=True)

    # Unique times observed anywhere in the data, sorted for stability
    all_times = np.sort(df[time_col].unique())

    # Min & max time for each unit, so we know that unit’s active span
    t_min = df.groupby(unit_col)[time_col].min()
    t_max = df.groupby(unit_col)[time_col].max()

    # Build the full “skeleton” of unit–time combinations we need
    pieces = []
    for unit in t_min.index:
        mask = (all_times >= t_min[unit]) & (all_times <= t_max[unit])
        pieces.append(pd.DataFrame({unit_col: unit, time_col: all_times[mask]}))
    skeleton = pd.concat(pieces, ignore_index=True)

    # Left‑join the user’s data onto the skeleton. The indicator tells which
    # rows were synthesized ("left_only"), so only those get fill_value.
    # A blanket fillna would also overwrite the user's own missing values and
    # raises for fill_value=None.
    indicator = "__expand_overlapping_units_source__"
    merged = skeleton.merge(df, on=[unit_col, time_col], how="left", indicator=indicator)
    synthesized = (merged.pop(indicator) == "left_only").to_numpy()

    value_cols = [c for c in merged.columns if c not in (unit_col, time_col)]
    if value_cols and synthesized.any():
        merged.loc[synthesized, value_cols] = fill_value

    return merged.sort_values([unit_col, time_col], kind="mergesort").reset_index(drop=True)

KM curve example:

import numpy as np
import pandas as pd

np.random.seed(42)
n = 2000

# Uniform draws for inverse-transform sampling. They are drawn first so the
# RNG stream (and hence every later draw) is fixed by the seed above.
u = np.random.uniform(0, 1, size=n)


def hazard_wave(t):
    """Oscillating, strictly positive rate: 0.02 + 0.015 * sin(3t)."""
    return 0.02 + 0.015 * np.sin(t * 3)


# Discretize [0.1, 20] finely and turn the rate into a normalized cumulative
# sum, then invert it by interpolation.
# NOTE(review): this treats hazard_wave as an (unnormalized) *density*, not a
# hazard. A true hazard would need S(t) = exp(-∫h). Fine for "wiggly" fake
# data.
time_grid = np.linspace(0.1, 20, 10000)
hazard_vals = hazard_wave(time_grid)
cdf = np.cumsum(hazard_vals)
cdf = cdf / cdf[-1]

event_times = np.interp(u, cdf, time_grid)

# NOTE(review): censoring = event + Exponential(scale=5), which always
# exceeds the event time. No observation is ever censored, so every `event`
# is 1.
censoring_times = event_times + np.random.exponential(scale=5, size=n)
observed_times = np.minimum(event_times, censoring_times)
events = np.where(event_times <= censoring_times, 1, 0)

df_wave = pd.DataFrame({'time': observed_times, 'event': events})

from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt

# Fit the Kaplan–Meier estimator to the simulated durations and event flags.
kmf = KaplanMeierFitter()
kmf.fit(df_wave["time"], event_observed=df_wave["event"])

# Draw the survival estimate, then label the axes it was drawn on.
ax = kmf.plot_survival_function()
ax.set_title("Wiggly Kaplan–Meier Curve (Quasi-Continuous)")
ax.set_xlabel("Time")
ax.set_ylabel("Survival Probability")
ax.grid(True)
plt.show()

Dataset conversions:

def to_integration_df(df):
    """Turn one-row-per-subject survival data into an integration frame.

    Each row of *df* becomes a unit with ``unit_id`` equal to its position.
    Every unit gets a row at time 0 and a row at its observed time. The two
    collapse into one when the observed time is 0. Each row carries the
    unit's ``event`` flag.

    Parameters
    ----------
    df : pd.DataFrame
        One row per subject with 'time' and 'event' columns. Not modified.

    Returns
    -------
    (integration_df, metadata_df)
        ``integration_df`` has columns ['unit_id', 'time', 'event'], sorted by
        (unit_id, time). ``metadata_df`` is a copy of *df* with the added
        'unit_id' column.
    """
    # Work on a copy so the caller's frame does not silently gain 'unit_id'.
    df = df.copy()
    df['unit_id'] = range(len(df))
    unit_metadata = df[['unit_id', 'event']]

    origin_rows = pd.DataFrame({'unit_id': df['unit_id'].to_numpy(), 'time': 0})
    integration_df = (
        pd.concat([df[['unit_id', 'time']], origin_rows], ignore_index=True)
        # A unit observed at time 0 would otherwise get two identical rows.
        .drop_duplicates(['unit_id', 'time'])
        .merge(
            unit_metadata,
            on='unit_id', how='inner',
            validate='many_to_one'
        )
        .sort_values(['unit_id', 'time'])
        .reset_index(drop=True)
    )
    return integration_df, df


# Integration frame (edf: unit_id/time/event rows) and per-subject frame (mdf).
edf, mdf = to_integration_df(df=df_wave)

Training code:

# Fit the hazard model by minimizing the mean negative log-likelihood
#   mean_i( Λ_i - δ_i · log λ_i(T_i) ),
# where Λ_i is unit i's integrated hazard and δ_i is its event flag.
N_ROUNDS = 500

integrator = HazardIntegrator()
exp_df = expand_overlapping_units(edf)

# Loop-invariant: 1.0 for units whose last observation is an event, else 0.0.
# mdf rows are in unit_id order 0..n-1, matching the integrator's per-unit
# outputs, which are ordered by ascending unit_id.
event_indicator = torch.as_tensor((mdf['event'] == 1).to_numpy(), dtype=torch.float32)
n_units = mdf.shape[0]

losses = []
for _ in range(N_ROUNDS):
    integrator.zero_grad()

    out = integrator(exp_df)

    loss = (
        out['unit_integrated_hazard'].sum()
        - (torch.log(out['unit_last_hazard']) * event_indicator).sum()
    ) / n_units
    losses.append(loss.item())

    # create_graph=True keeps the gradient graph differentiable, presumably
    # so the boosting step can use second-order information — TODO confirm.
    loss.backward(create_graph=True)

    integrator.gb_step()

— mthorrell, Jul 26 '25 21:07