auton-survival
auton-survival copied to clipboard
Error in the example notebook: 'RandomSurvivalForest' object has no attribute 'event_times_'
There is an arror in the RSF model in this notebook Survival Regression with Auton-Survival.ipynb
Here's the reproducible code
import pandas as pd
import sys
sys.path.append('../')
from auton_survival.datasets import load_dataset
Load data and features
# Load the SUPPORT dataset
outcomes, features = load_dataset(dataset='SUPPORT')
# Identify categorical (cat_feats) and continuous (num_feats) features
cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
'glucose', 'bun', 'urine', 'adlp', 'adls']
Preprocess
import numpy as np
from sklearn.model_selection import train_test_split
# Split the SUPPORT data into training, validation, and test data
x_tr, x_te, y_tr, y_te = train_test_split(features, outcomes, test_size=0.2, random_state=1)
x_tr, x_val, y_tr, y_val = train_test_split(x_tr, y_tr, test_size=0.25, random_state=1)
print(f'Number of training data points: {len(x_tr)}')
print(f'Number of validation data points: {len(x_val)}')
print(f'Number of test data points: {len(x_te)}')
from auton_survival.preprocessing import Preprocessor
# Fit the imputer and scaler to the training data and transform the training, validation and test data
preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat= 'mean')
transformer = preprocessor.fit(features, cat_feats=cat_feats, num_feats=num_feats,
one_hot=True, fill_value=-1)
x_tr = transformer.transform(x_tr)
x_val = transformer.transform(x_val)
x_te = transformer.transform(x_te)
Fit RSF
from auton_survival.estimators import SurvivalModel
from auton_survival.metrics import survival_regression_metric
from sklearn.model_selection import ParameterGrid
# Define parameters for tuning the model
param_grid = {'n_estimators' : [100, 300],
'max_depth' : [3, 5],
'max_features' : ['sqrt', 'log2']
}
params = ParameterGrid(param_grid)
# Define the times for tuning the model hyperparameters and for evaluating the model
times = np.quantile(y_tr['time'][y_tr['event']==1], np.linspace(0.1, 1, 10)).tolist()
# Perform hyperparameter tuning
models = []
for param in params:
model = SurvivalModel('rsf', random_seed=8, n_estimators=param['n_estimators'], max_depth=param['max_depth'], max_features=param['max_features'])
# The fit method is called to train the model
model.fit(x_tr, y_tr)
# Obtain survival probabilities for validation set and compute the Integrated Brier Score
predictions_val = model.predict_survival(x_val, times)
metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_tr)
models.append([metric_val, model])
# Select the best model based on the mean metric value computed for the validation set
metric_vals = [i[0] for i in models]
first_min_idx = metric_vals.index(min(metric_vals))
model = models[first_min_idx][1]
AttributeError
AttributeError Traceback (most recent call last)
File <command-1063050024682857>:25
22 model.fit(x_tr, y_tr)
24 # Obtain survival probabilities for validation set and compute the Integrated Brier Score
---> 25 predictions_val = model.predict_survival(x_val, times)
26 metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_tr)
27 models.append([metric_val, model])
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-677c5cec-55bc-4bf2-85ea-946a24e7ad0b/lib/python3.9/site-packages/auton_survival/estimators.py:699, in SurvivalModel.predict_survival(self, features, times)
697 return _predict_cph(self._model, features, times)
698 elif self.model == 'rsf':
--> 699 return _predict_rsf(self._model, features, times)
700 elif self.model == 'dsm':
701 return _predict_dsm(self._model, features, times)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-677c5cec-55bc-4bf2-85ea-946a24e7ad0b/lib/python3.9/site-packages/auton_survival/estimators.py:477, in _predict_rsf(model, features, times)
472 times = [float(times)]
474 survival_predictions = model.predict_survival_function(features.values,
475 return_array=True)
476 survival_predictions = pd.DataFrame(survival_predictions,
--> 477 columns=model.event_times_).T
479 return __interpolate_missing_times(survival_predictions, times)
AttributeError: 'RandomSurvivalForest' object has no attribute 'event_times_'
I think I've found the solution.
On line 477 in estimators.py
change from
columns=model.event_times_).T
to
columns=model.unique_times_).T
Is that correct?