
Incorporate Time-Varying Data in "CoxCCDataset" function

Open Thanksforhelp opened this issue 4 years ago • 3 comments

Hi Havakv, great to have a meeting with you yesterday; you showed us the right direction for incorporating time-varying survival data. I am now trying to incorporate time-varying (counting-process) survival data into your pycox package, using your built-in sample dataset "metabric".

1. I created two new variables representing the start and end of follow-up, extending your sample metabric data to time-varying survival data:

    df_train = metabric.read_df()
    df_train["START"] = 0
    df_train["STOP"] = df_train["duration"]
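For reference, a genuinely time-varying (counting-process) layout would hold one row per subject-interval. The sketch below is only illustrative: the `ID` column and the `at_risk` helper are my own naming, not pycox API.

```python
import pandas as pd

# Illustrative counting-process layout: one row per (subject, interval).
df = pd.DataFrame({
    "ID":    [1, 1, 2],
    "START": [0.0, 5.0, 0.0],
    "STOP":  [5.0, 8.0, 3.0],   # subject 1's covariates change at t = 5
    "event": [0, 1, 1],         # event indicator for the end of each interval
    "x0":    [0.1, 0.4, 0.9],   # a covariate that can differ across intervals
})

def at_risk(df, t):
    """Rows whose interval covers time t (closed on both ends,
    matching the START <= t <= STOP filter used in the class below)."""
    return df[(df["START"] <= t) & (t <= df["STOP"])]
```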

2. Following your suggestion, I am trying to modify the case-control selection to pick controls based on the start and stop times rather than the duration column:

    class CoxCCDataset_V2(torch.utils.data.Dataset):
        def __init__(self, input, durations, events, START, STOP, n_control=1):
            df_train_target = pd.DataFrame(dict(duration=durations, event=events, START=START, STOP=STOP))
            new_df = pd.DataFrame()  # dataframe for matched controls
            case_df = df_train_target.loc[lambda x: x['event'] == 1]['duration']  # durations of the cases
            for i in range(0, len(df_train_target)):
                if df_train_target.iloc[i]['event'] == 1:
                    # risk set: controls whose interval covers the case's event time
                    control_df = df_train_target[df_train_target['event'] == 0]
                    control_df = control_df[control_df['START'] <= df_train_target.iloc[i]['STOP']]
                    control_df = control_df[control_df['STOP'] >= df_train_target.iloc[i]['STOP']]  # interval filter
                    control = control_df.sample()
                    new_df = new_df.append(control)
            return new_df

            self.durations = df_train_target.loc[lambda x: x['event'] == 1]['duration']
            self.input = tt.tuplefy(input)
            assert type(self.durations) is pd.Series
            self.n_control = n_control

        def __getitem__(self, index):
            if (not hasattr(index, '__iter__')) and (type(index) is not slice):
                index = [index]
            fails = self.case_df.loc[lambda x: x['event'] == 1]['duration'].iloc[index]
            x_case = self.input.iloc[fails.index]
            non_fails = self.new_df.loc[lambda x: x['event'] == 0]['duration'].iloc[index]
            x_control = self.input.iloc[non_fails.index]
            return tt.tuplefy(x_case, x_control).to_tensor()

        def __len__(self):
            return len(self.durations)

I am not familiar with PyTorch, so I can't debug which part of what I wrote is wrong. If you have time, could you check the class above? Thanks!

Thanksforhelp avatar Jun 29 '21 17:06 Thanksforhelp

Thanks for the talk! I'm not sure I will be able to have a look at this before August, as I've started my vacation. But I can immediately see that you have a `return new_df` in the `__init__` function, which probably shouldn't be there: the rest of `__init__` is never run. Try removing it.
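Roughly, the intended shape is something like the torch-free sketch below (class and attribute names are illustrative only; the point is that matched controls are stored on `self` instead of being returned):

```python
import pandas as pd

class CoxCCDatasetSketch:
    """Torch-free sketch of the fix: no `return` inside __init__;
    the matched controls are kept as attributes instead."""
    def __init__(self, durations, events, start, stop, n_control=1, seed=0):
        df = pd.DataFrame(dict(duration=durations, event=events,
                               START=start, STOP=stop))
        cases = df[df["event"] == 1]
        controls = df[df["event"] == 0]
        matched = []
        for _, case in cases.iterrows():
            # risk set: controls whose interval covers the case's event time
            risk = controls[(controls["START"] <= case["STOP"])
                            & (controls["STOP"] >= case["STOP"])]
            if len(risk):
                matched.append(risk.sample(n_control, random_state=seed))
        self.cases = cases  # stored, not returned
        self.controls = pd.concat(matched) if matched else controls.iloc[:0]

    def __len__(self):
        return len(self.cases)
```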

havakv avatar Jul 05 '21 09:07 havakv

Hi Havakv, enjoy your vacation.

I am able to get the CoxCCDataset function to work now; it can do the selection for both time-fixed and time-varying survival data.

A new issue from torchtuples:

    ~\anaconda3\lib\site-packages\torchtuples\base.py in fit_dataloader(self, dataloader, epochs, callbacks, verbose, metrics, val_dataloader)
        227             if stop: break
        228             self.optimizer.zero_grad()
    --> 229             self.batch_metrics = self.compute_metrics(data, self.metrics)
        230             self.batch_loss = self.batch_metrics['loss']
        231             self.batch_loss.backward()

    in compute_metrics(self, input, metrics)
         54     control_length = [len(a) for a in control]
         55
    ---> 56     input_all = tt.TupleTree((case,) + control).cat()
         57     g_all = self.net(*input_all)
         58     g_all = tt.tuplefy(g_all).split(batch_size).flatten()

    ~\anaconda3\lib\site-packages\torchtuples\tupletree.py in cat(self, dim)
        420     @docstring(cat)
        421     def cat(self, dim=0):
    --> 422         return cat(self, dim)
        423
        424     def reduce_nrec(self, func, initial=None):

    ~\anaconda3\lib\site-packages\torchtuples\tupletree.py in cat(seq, dim)
        202     """
        203     if not seq.shapes().apply(lambda x: x[1:]).all_equal():
    --> 204         raise ValueError("Shapes of merged arrays need to be the same")
        205
        206     type_ = seq.type()

    ValueError: Shapes of merged arrays need to be the same

I added the validation below inside compute_metrics to check whether the two tuples have the same shape. No error is raised, which means they have the same length, but I still get the same cat() error.

    case, control = input # both are TupleTree
    if len(case) != len(control):
        raise RuntimeError("case length is not equal to control length")
    if [len(a) for a in case] != [len(a) for a in control]:
        raise RuntimeError("case shape is not equal to control shape")
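For what it's worth, equal tuple lengths need not imply that the arrays can be concatenated; here is a plain-numpy illustration of my own (not torchtuples) showing why both checks above can pass while cat() still fails:

```python
import numpy as np

case = (np.zeros((4, 9)),)     # one batch of case covariates
control = (np.zeros((4, 3)),)  # same tuple length, same batch size

assert len(case) == len(control)                            # passes
assert [len(a) for a in case] == [len(a) for a in control]  # passes too
# ...but cat() compares the shapes beyond the first axis, and those differ:
assert case[0].shape[1:] != control[0].shape[1:]            # (9,) vs (3,)
```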

Any thoughts on this error?

Thanksforhelp avatar Jul 06 '21 15:07 Thanksforhelp

So case and control are TupleTree objects (you can think of them as tuples), not tensors, so len is just the length of the tuples, which is not the issue here. The control is a collection of controls (though it typically only includes one control). Try printing case.levels and control.levels. These should be (0) and ((1)) respectively. My guess is that you will get (0) for both. If that is the case, you can change your code in __getitem__ to add a level with add_root(), so from your code in the first message:

    x_control = self.input.iloc[non_fails.index]
    x_control = x_control.add_root()
    return tt.tuplefy(x_case, x_control).to_tensor()
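The plain-tuple intuition (strings standing in for tensors, just for illustration):

```python
# Why the extra level matters: compute_metrics builds (case,) + control
# and expects `control` to be a *collection* of control tuples.
case = ("x_case",)               # stand-in for a tuple of covariate tensors
control_flat = ("x_ctrl",)       # same level as case -> wrong
control_rooted = (("x_ctrl",),)  # with one extra root level -> right

assert (case,) + control_flat == (("x_case",), "x_ctrl")       # levels mix
assert (case,) + control_rooted == (("x_case",), ("x_ctrl",))  # uniform, cat-able
```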

havakv avatar Jul 06 '21 16:07 havakv