wtte-rnn
Numerical instability parameterization tricks
First of all I just want to say that your WTTE is really cool. Great blog post and paper. I've been using an adapted version of it for a time-to-event task and wanted to share a trick I've found useful for numerical instability issues in case you or anyone else is interested.
A couple of things to note in my case:
- I'm not using an RNN, since I have sufficient engineered features for the history at a point in time.
- I rewrote it in PyTorch, so my code here is in PyTorch.
- My case uses the discrete likelihood. I haven't tested anything for the continuous case but I don't see why it wouldn't work there too.
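For reference, the discrete log-likelihood I mean is the standard discrete Weibull one from the WTTE paper; here is a minimal sketch of how I compute it (the function name and the `eps` guard against log(0) and division by zero are my own; `t` is the elapsed time, `u` is 1 for an observed event and 0 for a right-censored one):

```python
import torch

def discrete_weibull_loglik(t, u, alpha, beta, eps=1e-8):
    # cumulative hazard at the start and end of the interval [t, t+1)
    hazard0 = torch.pow((t + eps) / alpha, beta)
    hazard1 = torch.pow((t + 1.0) / alpha, beta)
    # uncensored: log P(T = t); censored: log P(T > t) = -hazard1
    return u * torch.log(torch.exp(hazard1 - hazard0) - 1.0 + eps) - hazard1
```

Maximizing this (i.e. minimizing its negative mean over a batch) is where the instability shows up: extreme `alpha` or `beta` values blow up the `pow` and `exp` terms.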
While testing it I had a lot of issues with NaN loss and numerical instability during the fit of `alpha` and `beta`. I know you've worked a lot on this, from reading the other GitHub issues. I've found that this parameterization for `alpha` and `beta` helps a lot:
```python
import numpy as np
import torch as tt
import torch.nn as nn


class WTTE(nn.Module):
    def __init__(self, nnet_output_dim):
        super(WTTE, self).__init__()
        # this is the neural net whose outputs are then used to find alpha and beta
        self.nnet = InnerNNET()
        self.softplus = nn.Softplus()
        self.tanh = nn.Tanh()
        self.alpha_scaling = nn.Linear(nnet_output_dim, 1)
        self.beta_scaling = nn.Linear(nnet_output_dim, 1)

        # global offset and scale parameters
        alpha_offset_init, beta_offset_init = 1.0, 1.0
        alpha_scale_init, beta_scale_init = 1.0, 1.0
        self.alpha_offset = nn.Parameter(tt.from_numpy(np.array([alpha_offset_init])).float(), requires_grad=True)
        self.beta_offset = nn.Parameter(tt.from_numpy(np.array([beta_offset_init])).float(), requires_grad=True)
        self.alpha_scale = nn.Parameter(tt.from_numpy(np.array([alpha_scale_init])).float(), requires_grad=True)
        self.beta_scale = nn.Parameter(tt.from_numpy(np.array([beta_scale_init])).float(), requires_grad=True)

    def forward(self, x):
        x = self.nnet(x)

        # derive alpha and beta individual scaling factors
        a_scaler = self.alpha_scaling(x)
        b_scaler = self.beta_scaling(x)

        # enforce the scaling factors to be between -1 and 1
        a_scaler = self.tanh(a_scaler)
        b_scaler = self.tanh(b_scaler)

        # combine the global offsets and scale factors with the individual ones
        alpha = self.alpha_offset + (self.alpha_scale * a_scaler)
        beta = self.beta_offset + (self.beta_scale * b_scaler)

        # put alpha on a positive range with exp, beta with softplus
        alpha = tt.exp(alpha)
        beta = self.softplus(beta)
        return alpha, beta
```
Essentially, this helps because the `tanh` activation squashes the individual/observation scaling factors to always lie between -1 and 1, so you don't have to worry about too-small or too-large outputs from your network. The `alpha_scale` and `beta_scale` parameters set the range that the -1 to 1 outputs are multiplied by. The offsets are nice as an intercept or centering mechanism.
If you set the initialization for the offsets and scaling factors to be low numbers (I start them at 1.0, for example), they will slowly creep up to their optimal values during fit. Here is some output from a recent fit of mine to show what I mean:
```
A off: 1.10000 A scale: 1.10000 B off: 0.90000 B scale: 1.10000
A off: 1.23279 A scale: 1.03885 B off: 0.90022 B scale: 0.89804
A off: 1.25786 A scale: 1.06547 B off: 0.90056 B scale: 0.89798
A off: 1.28466 A scale: 1.09343 B off: 0.90163 B scale: 0.89878
A off: 1.34988 A scale: 1.16266 B off: 0.90678 B scale: 0.90290
A off: 1.44370 A scale: 1.25528 B off: 0.93324 B scale: 0.93015
A off: 1.53040 A scale: 1.32979 B off: 0.98308 B scale: 0.98226
...[many epochs later]...
A off: 1.92782 A scale: 1.57879 B off: 2.97086 B scale: 2.55308
A off: 1.93340 A scale: 1.59249 B off: 3.01380 B scale: 2.59065
A off: 1.94988 A scale: 1.59956 B off: 3.01739 B scale: 2.54407
A off: 1.94464 A scale: 1.59733 B off: 3.03923 B scale: 2.55807
A off: 1.95629 A scale: 1.60365 B off: 3.06733 B scale: 2.58807
A off: 1.95865 A scale: 1.59092 B off: 3.09355 B scale: 2.60830
```
You could also easily enforce maximums on `alpha` and `beta` by adding `torch.clamp` calls around the outputs, but I have not found this to be necessary.
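For completeness, such a cap could look like the following (the bounds here are illustrative assumptions, not tuned values):

```python
import torch
import torch.nn.functional as F

raw_alpha = torch.tensor([0.5, 10.0])
raw_beta = torch.tensor([-2.0, 12.0])

# same positivity transforms as in forward(), plus hard safety caps
alpha = torch.clamp(torch.exp(raw_alpha), min=1e-4, max=1e4)  # exp(10) ~= 22026 gets capped
beta = torch.clamp(F.softplus(raw_beta), min=1e-4, max=1e2)
```

Note that `clamp` has zero gradient outside the bounds, so it should be a last-resort guard rather than the main mechanism.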
I have only tested this on my own data and so I can't make any claims that this will solve numerical instability issues for other people, but I figured it may help someone!