auton-survival
Perform some validation of input to models
Some basic validation should be performed on the input, e.g. checking for NaN values and for the proper datatype.
For example, a NaN value in the time (event duration) array produces an obscure error (I discovered this because of a bug in my data preprocessing).
To replicate:
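import numpy as np
from auton_survival import datasets
from auton_survival.models.dsm import DeepSurvivalMachines
x, t, e = datasets.load_dataset('PBC')
model = DeepSurvivalMachines()
# Inject a NaN into the event durations
t[-1] = np.nan
model.fit(x, t, e)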
Generates the following error:
  File "test.py", line 17, in <module>
    model.fit(x, t, e)
  File "auton_survival/models/dsm/__init__.py", line 257, in fit
    model, _ = train_dsm(model,
  File "auton_survival/models/dsm/utilities.py", line 132, in train_dsm
    premodel = pretrain_dsm(model,
  File "auton_survival/models/dsm/utilities.py", line 73, in pretrain_dsm
    loss += unconditional_loss(premodel, t_train, e_train, str(r+1))
  File "auton_survival/models/dsm/losses.py", line 121, in unconditional_loss
    return _weibull_loss(model, t, e, risk)
  File "auton_survival/models/dsm/losses.py", line 113, in _weibull_loss
    ll += f[uncens].sum() + s[cens].sum()
IndexError: index 1653 is out of bounds for dimension 0 with size 1653
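To make the request concrete, here is a minimal sketch of the kind of check I have in mind. The function name `validate_survival_inputs` and the exact error messages are my own invention, not part of auton-survival, and I am assuming the inputs can be converted to NumPy arrays. Something like this could run at the start of `fit` so the user sees a clear message instead of the IndexError above.

```python
import numpy as np


def validate_survival_inputs(x, t, e):
    """Hypothetical helper (not part of auton-survival).

    Checks that features x, durations t and event indicators e have
    matching lengths, numeric dtypes, and contain no NaN values.
    Returns the inputs as NumPy arrays if all checks pass.
    """
    arrays = {"x": np.asarray(x), "t": np.asarray(t), "e": np.asarray(e)}

    # All three inputs must describe the same number of samples.
    lengths = {name: len(arr) for name, arr in arrays.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"Inputs have mismatched lengths: {lengths}")

    for name, arr in arrays.items():
        # Reject object/string arrays early; accept numeric and boolean.
        is_numeric = np.issubdtype(arr.dtype, np.number)
        is_bool = np.issubdtype(arr.dtype, np.bool_)
        if not (is_numeric or is_bool):
            raise TypeError(
                f"'{name}' must be numeric, got dtype {arr.dtype}"
            )

        # Cast to float so the NaN check also works for int/bool arrays.
        n_nan = int(np.isnan(arr.astype(float)).sum())
        if n_nan > 0:
            raise ValueError(f"'{name}' contains {n_nan} NaN value(s)")

    return arrays["x"], arrays["t"], arrays["e"]
```

With the reproduction above, calling `validate_survival_inputs(x, t, e)` before `model.fit(x, t, e)` would raise a `ValueError` naming `t` as the array that contains the NaN.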