
How to save and load trained trainable_variables in a TensorFlow Probability distribution?

Open moonman925 opened this issue 4 years ago • 5 comments

I want to build a flow model whose trained weights I can save and load later.

flow = tfd.TransformedDistribution(
    ...
    distribution=tfd.Normal(loc=0.0, scale=1.0),
    bijector=my_bijectors
)

for e in range(epochs):
    ...
    with tf.GradientTape() as tape:
         log_prob_loss = loss() # some loss func here
    grads = tape.gradient(log_prob_loss, flow.trainable_variables)
    optimizer.apply_gradients(zip(grads, flow.trainable_variables))
    ...

How can I save and load the flow here?

I was thinking about using pickle to save the trainable_variables, but I don't know how to apply the loaded values back to the original variables. Is there a way to update them?
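For what it's worth, the pickle approach can be made to work: dump the variable values as plain numpy arrays, then rebuild the same architecture and push the values back in with `Variable.assign()`. Here is a minimal sketch using a hypothetical `TinyModel` (any `tf.Module`, including a `tfd.TransformedDistribution`, exposes `trainable_variables` the same way, so the two loops apply unchanged):

```python
import pickle
import tensorflow as tf

# Hypothetical stand-in for the flow: any object exposing trainable_variables.
class TinyModel(tf.Module):
    def __init__(self):
        super().__init__()
        self.shift = tf.Variable(0.0, name="shift")
        self.scale = tf.Variable(1.0, name="scale")

model = TinyModel()
model.shift.assign(2.5)  # pretend this value was learned

# Save: convert each variable to a plain numpy array before pickling.
blob = pickle.dumps([v.numpy() for v in model.trainable_variables])

# Load: rebuild the model with the same architecture, then copy values back.
restored = TinyModel()
for var, val in zip(restored.trainable_variables, pickle.loads(blob)):
    var.assign(val)

print(float(restored.shift.numpy()))  # 2.5
```

This relies on both models producing their variables in the same order, which holds when the architecture is identical; `tf.train.Checkpoint` (mentioned later in the thread) is the more robust, name-based alternative.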

moonman925 avatar Dec 07 '19 03:12 moonman925

I am having the same issue. Have searched far and wide and cannot find any way to save my TransformedDistribution.

Did anyone ever figure this out?

georgestein avatar Feb 11 '21 18:02 georgestein

Use tf.train.Checkpoint.

flow = tfd.TransformedDistribution(...)
ckpt = tf.train.Checkpoint(flow=flow)
ckpt.save('.../yourpath/to/ckpt/...')   # use ckpt.restore() to restore
# or
ckpt.write('.../yourpath/to/ckpt/...')  # use ckpt.read() to restore

ref: https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint?version=nightly#save

gitlabspy avatar Feb 12 '21 02:02 gitlabspy

Hello, I have the same problem. Did you manage to solve this? I would be grateful for any helpful info!

kaamka avatar Sep 23 '21 14:09 kaamka

Hi kaamka,

I managed to save and load flow models a bit differently than gitlabspy above, although their answer probably works well too! Here is some pseudocode showing how I got it working by saving the weights to a checkpoint during training. When I want to use the model, I construct another flow with the same architecture and load the weights into it.

def construct_normalizing_flow(params, optimizer=tf.optimizers.Adam(1e-3)):
    # Construct flow
    flow = tfd.TransformedDistribution(distribution=..., bijector=...)

    # Construct flow model.
    z_ = tfkl.Input(shape=(u_latent_dim,), dtype=tf.float32)
    log_prob_ = flow.log_prob(z_)

    model = tfk.Model(inputs=z_, outputs=log_prob_)
    model.compile(optimizer=optimizer, loss=lambda _, log_prob: -log_prob)
    return model, flow  # return both so the flow can be used after training

Set up a checkpoint callback and train. Note that an integer save_freq saves every that-many batches (pass save_freq='epoch' to save per epoch):

model, flow = construct_normalizing_flow(params, optimizer=optimizer)

checkpoint_filepath = 'flow_checkpoint'
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    verbose=1,
    save_freq=min(100, params['epochs_flow']))

model.fit(x=data,
          y=tf.zeros((data.shape[0], 0), dtype=tf.float32),
          batch_size=params['batch_size'],
          epochs=params['epochs_flow'],
          steps_per_epoch=data.shape[0]//params['batch_size'],
          shuffle=True,
          verbose=verbose,
          callbacks=[cp_callback])

Now that the model is fit, I can load it back in by constructing another model with construct_normalizing_flow() and loading the weights into it:

model, flow = construct_normalizing_flow(params)
model.load_weights(checkpoint_filepath)
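Since the flow pieces above are pseudocode, here is a self-contained sketch of the same save_weights / load_weights round trip, with a plain Dense regression model standing in for the flow model. All names and paths are illustrative:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def build_model():
    # Same-architecture constructor, analogous to construct_normalizing_flow.
    inp = tf.keras.Input(shape=(4,))
    out = tf.keras.layers.Dense(1)(inp)
    model = tf.keras.Model(inp, out)
    model.compile(optimizer="adam", loss="mse")
    return model

rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.random((32, 1)).astype("float32")

m1 = build_model()
m1.fit(x, y, epochs=1, verbose=0)
path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
m1.save_weights(path)

m2 = build_model()      # same architecture, fresh random weights
m2.load_weights(path)

# The reloaded model reproduces the trained model's outputs.
assert np.allclose(m1.predict(x, verbose=0), m2.predict(x, verbose=0))
```

The key point is the same as in georgestein's answer: only the weights are saved, so the loading side must rebuild an identical architecture before calling load_weights.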

Hope this helps

georgestein avatar Sep 23 '21 16:09 georgestein

Hi @georgestein

I'm sorry for this late reply! I had completely forgotten to answer due to a ton of duties. However, I'd like to THANK YOU very much for your help! Your solution worked for me (as did the solution from @gitlabspy). I was finally able to work with my Normalizing Flows successfully.

Thank you guys again :) Happy coding!

kaamka avatar Jan 31 '22 19:01 kaamka