probability
probability copied to clipboard
Cannot train tfb.Glow with tf.keras?
Hi, I am getting NaNs when I sample images after training the glow bijector.
Is there some compatibility issue between tf and tfp, or am I doing something wrong?
tensorflow.__version__: 2.6.0
tensorflow_probability.__version__: 0.14.1
See the script and its output below:
# Load Data
# CIFAR-10 images are cast to float32 and divided by 255.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.cifar10.load_data()
)
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255

# Only the first 2000 training examples are used; the test split is kept whole.
batch_size = 128
train_ds = tf.data.Dataset.from_tensor_slices(
    (train_images[:2000], train_labels[:2000])
).batch(batch_size)
test_ds = tf.data.Dataset.from_tensor_slices(
    (test_images, test_labels)
).batch(batch_size)
# Define Glow Bijector
# Uses the default coupling and exit networks shipped with the bijectors module.
glow = tfp.bijectors.Glow(
    output_shape=(32, 32, 3),
    coupling_bijector_fn=tfb.GlowDefaultNetwork,
    exit_bijector_fn=tfb.GlowDefaultExitNetwork,
)

## Base Distribution
# Standard normals sampled over the latent shape that glow maps to a
# (32, 32, 3) event.
z_shape = glow.inverse_event_shape((32, 32, 3))
pz = tfd.Sample(tfd.Normal(0., 1.), z_shape)

## Transformed Distribution
# The chain lists the bounded sigmoid (range -10..10) first and glow second.
sigm = tfb.sigmoid.Sigmoid(low=-10, high=10)
output = tfb.chain.Chain([sigm, glow])
px = tfd.TransformedDistribution(distribution=pz, bijector=output)
# Train
EPOCHS = 1
optimizer = Adam(1e-4)


@tf.function(autograph=False, jit_compile=False)
def train_step(x):
    """Run one optimizer update on batch `x` and return its loss.

    The loss is the negative log-probability of `x` under `px`. It is
    returned exactly as `log_prob` produced it; no mean or sum is applied
    here before differentiation.
    """
    with tf.GradientTape() as tape:
        loss = -px.log_prob(x)
    # Differentiate with respect to every variable the tape watched while
    # the loss was being computed.
    watched = tape.watched_variables()
    grads = tape.gradient(loss, watched)
    optimizer.apply_gradients(zip(grads, watched))
    return loss
# Iterate over the training batches, accumulating the mean loss of each batch.
for epoch in range(EPOCHS):
    training_loss = 0
    for step, (x, y) in enumerate(train_ds):
        batch_loss = train_step(x)
        mean_loss = tf.reduce_mean(batch_loss)
        training_loss += mean_loss
        # Report progress every 10th step.
        if step % 10 == 0:
            print(step, mean_loss)
    # Sum of the per-batch mean losses for this epoch.
    print(training_loss)
# Inference
# Draw 9 samples from the transformed distribution and print the first one.
images = px.sample(9)
print(images[0])
Output:
0 tf.Tensor(8102.7705, shape=(), dtype=float32)
10 tf.Tensor(26.895271, shape=(), dtype=float32)
tf.Tensor(26641.363, shape=(), dtype=float32)
<tf.Tensor: shape=(32, 32, 3), dtype=float32, numpy=
array([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan],
...,
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],
...,
[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan],
...,
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]], dtype=float32)>
I am getting the same issue