
Cannot train tfb.Glow with tf.keras?

Open khanx169 opened this issue 3 years ago • 1 comment

Hi, I am getting NaNs when I sample images after training the Glow bijector.

Is there some compatibility issue between tf and tfp, or am I doing something wrong?

tensorflow.__version__:               2.6.0
tensorflow_probability.__version__:   0.14.1

See the script and its output below:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# Load Data
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255

train_ds = tf.data.Dataset.from_tensor_slices((train_images[:2000], train_labels[:2000]))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

batch_size = 128
train_ds = train_ds.batch(batch_size)
test_ds = test_ds.batch(batch_size)
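
One thing I was unsure about: Glow-style models are usually trained on dequantized pixels (uniform noise added to the raw uint8 values before rescaling), and I have not checked whether skipping that matters here. A sketch of what I mean, with dequantize being my own helper and not part of the failing script:

# Hypothetical dequantization helper (my own addition): adds uniform noise
# so the model sees a continuous density over pixel values, as in the Glow paper.
def dequantize(raw_uint8_images):
    x = tf.cast(raw_uint8_images, tf.float32)
    x = x + tf.random.uniform(tf.shape(x))  # each pixel now lies in [0, 256)
    return x / 256.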


# Define Glow Bijector
glow = tfb.Glow(
    output_shape=(32, 32, 3),
    coupling_bijector_fn=tfb.GlowDefaultNetwork,
    exit_bijector_fn=tfb.GlowDefaultExitNetwork,
)

## Base Distribution
z_shape = glow.inverse_event_shape((32, 32, 3))
pz = tfd.Sample(tfd.Normal(0., 1.), z_shape)

## Transformed Distribution
sigm = tfb.Sigmoid(low=-10., high=10.)
# Chain applies right-to-left on forward: glow maps z to an image-shaped
# tensor first, then the sigmoid squashes values into (-10, 10).
output = tfb.Chain([sigm, glow])
px = tfd.TransformedDistribution(distribution=pz, bijector=output)
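
A sanity check I would run at this point (my own addition): print the latent shape, which I believe should be a flat vector of size 32 * 32 * 3 = 3072, and evaluate log_prob on one training image to confirm it is finite before any training:

print(z_shape)                        # I expect something like TensorShape([3072])
print(px.log_prob(train_images[:1]))  # should be a finite (negative) number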

# Train
optimizer = tf.keras.optimizers.Adam(1e-4)
EPOCHS = 1

@tf.function(autograph=False, jit_compile=False)
def train_step(x):
    with tf.GradientTape() as tape:
        loss = -px.log_prob(x) 
    variables = tape.watched_variables()  # all trainable variables touched inside the tape
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss


for epoch in range(EPOCHS):
    training_loss = 0
    for step, (x, y) in enumerate(train_ds):
        l = train_step(x)
        training_loss += tf.reduce_mean(l)
        
        if step % 10 == 0:
            print(step, tf.reduce_mean(l))
    print(training_loss)
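
For completeness, one stabilization I have seen recommended for flow training is clipping the global gradient norm, though I have not confirmed it changes anything here. A sketch (the 1.0 threshold is an arbitrary choice of mine):

@tf.function(autograph=False, jit_compile=False)
def train_step_clipped(x):
    with tf.GradientTape() as tape:
        loss = -px.log_prob(x)
    variables = tape.watched_variables()
    gradients = tape.gradient(loss, variables)
    # Clip the global gradient norm before applying updates.
    gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss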


# Inference
images = px.sample(9)
print(images[0])

Output:

0 tf.Tensor(8102.7705, shape=(), dtype=float32)
10 tf.Tensor(26.895271, shape=(), dtype=float32)
tf.Tensor(26641.363, shape=(), dtype=float32)

<tf.Tensor: shape=(32, 32, 3), dtype=float32, numpy=
array([[[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        ...,
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]],

       ...,

       [[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        ...,
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]]], dtype=float32)>
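
To narrow down where the non-finite values first appear, one can sample z from the base distribution and push it through each bijector separately (my own diagnostic, I have not attached its output):

# Push a base sample through each stage and check finiteness at each step.
z = pz.sample(9)
x_glow = glow.forward(z)
print(tf.reduce_all(tf.math.is_finite(x_glow)))  # False => glow itself emits NaNs/Infs
x_out = sigm.forward(x_glow)
print(tf.reduce_all(tf.math.is_finite(x_out)))

For what it's worth, samples from Glow models are often drawn at reduced temperature (e.g. scaling z by about 0.7, as in the Glow paper) to avoid the extreme tails of the base distribution; I do not know whether that is relevant to this failure.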

khanx169 avatar Oct 26 '21 18:10 khanx169

I am getting the same issue.

ivallesp avatar Jan 21 '22 17:01 ivallesp