Output is NaN when stacking multiple MAFs but normal when using only 1 MAF
Hi, I am having trouble understanding why I get NaN when stacking several conditional MAFs, but not when using a single conditional MAF. My dataset is a tuple of parameters x and signal data S:
x = [a, f, p], where a is drawn uniformly from (0.2, 1), f uniformly from (0.1, 0.25), and p uniformly from (0, 2*pi). The context (signal) is the function of time s(t) = a * sin(2*pi*f*t + p), sampled at 24 time steps, so an example input would be
[a,f,p], [s1....s24]
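For concreteness, here is a minimal sketch of how such a dataset can be generated (the `make_dataset` helper is illustrative, not my exact pipeline):

```python
import numpy as np

def make_dataset(n, t_steps=24, seed=0):
    """Draw (x, S) pairs: parameters x = [a, f, p] and the t_steps-long signal S."""
    rng = np.random.default_rng(seed)
    a = rng.uniform(0.2, 1.0, size=(n, 1))
    f = rng.uniform(0.1, 0.25, size=(n, 1))
    p = rng.uniform(0.0, 2.0 * np.pi, size=(n, 1))
    t = np.arange(t_steps)[None, :]          # time steps 0..t_steps-1
    S = a * np.sin(2.0 * np.pi * f * t + p)  # shape (n, t_steps)
    x = np.concatenate([a, f, p], axis=1)    # shape (n, 3)
    return x.astype('float32'), S.astype('float32')
```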
A single MAF solves this problem; however, I need to be able to scale it up to harder problems, so I am testing three stacked MAFs with the dimensions permuted between them.
My normalizing flow is as follows:
dims = 3
context_dim = 24
permute = tfp.bijectors.Permute(permutation=[1, 2, 0])
mafblock1 = tfb.MaskedAutoregressiveFlow(
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2,
        event_shape=(dims,),
        hidden_units=[128, 128],
        conditional=True,
        conditional_event_shape=(context_dim,),
        activation='relu'
    ), name='maf1'
)
mafblock2 = tfb.MaskedAutoregressiveFlow(
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2,
        event_shape=(dims,),
        hidden_units=[128, 128],
        conditional=True,
        conditional_event_shape=(context_dim,),
        activation='relu'
    ), name='maf2'
)
mafblock3 = tfb.MaskedAutoregressiveFlow(
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2,
        event_shape=(dims,),
        hidden_units=[128, 128],
        conditional=True,
        conditional_event_shape=(context_dim,),
        activation='relu'
    ), name='maf3'
)
maf = tfb.Chain([mafblock1, permute, mafblock2, permute, mafblock3, permute])
prior = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[dims])
NormalizingFlow = tfp.distributions.TransformedDistribution(
    distribution=prior,
    bijector=maf,
    name="NormalizingFlow"
)
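As a side note on the permutation scheme: applying [1, 2, 0] three times composes to the identity, so across the three blocks every dimension takes every position in the autoregressive ordering exactly once. A quick NumPy check (illustrative only; tfb.Permute forwards as a gather along the event axis):

```python
import numpy as np

perm = np.array([1, 2, 0])   # same permutation used between the MAF blocks
x = np.array([10, 20, 30])

y1 = x[perm]   # after the first Permute  -> [20, 30, 10]
y2 = y1[perm]  # after the second Permute -> [30, 10, 20]
y3 = y2[perm]  # after the third Permute  -> [10, 20, 30], back to the original

assert (y3 == x).all()
```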
I am optimizing with the following function:
def train_density_estimation(distribution, optimizer, batch):
    S = batch[0]
    x = batch[1]
    with tf.GradientTape() as tape:
        tape.watch(distribution.trainable_variables)
        log_prob = distribution.log_prob(
            x,
            bijector_kwargs=make_bijector_kwargs(
                distribution.bijector, {'maf.': {'conditional_input': S}}
            )
        )
        loss = -tf.reduce_mean(log_prob)  # negative log-likelihood
    gradients = tape.gradient(loss, distribution.trainable_variables)
    optimizer.apply_gradients(zip(gradients, distribution.trainable_variables))
    return loss
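For completeness, make_bijector_kwargs is the recursive helper adapted from the TFP conditional-flow examples (shown here so the post is self-contained): it walks the Chain and routes the kwargs to every bijector whose name matches the regex, so 'maf.' matches maf1, maf2, and maf3, while the Permute bijectors receive nothing:

```python
import re

def make_bijector_kwargs(bijector, name_to_kwargs):
    """Recursively route kwargs to the bijectors whose names match a regex."""
    if hasattr(bijector, 'bijectors'):
        # A Chain (or other composition): recurse into its parts.
        return {b.name: make_bijector_kwargs(b, name_to_kwargs)
                for b in bijector.bijectors}
    else:
        for name_regex, kwargs in name_to_kwargs.items():
            if re.match(name_regex, bijector.name):
                return kwargs
    return {}
```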
train_epochs = 100
train_loss = []
for epoch in range(train_epochs):
    epoch_loss = 0
    for (strains, params) in train_data:
        loss = train_density_estimation(NormalizingFlow, optimizer, (strains, params))
        epoch_loss += loss
    print("Train loss for epoch {}: {}".format(epoch, epoch_loss / (train_num_samples / batch_size)))
    train_loss.append((epoch_loss * batch_size) / train_num_samples)
Does anyone see why this is happening?