
Problem recreating the graph with trained weights (TensorFlow)

maosi-chen opened this issue Jul 08 '18 · 3 comments

I have a trained model (a 1-layer Bi-LSTM followed by 2 FC layers). Following the MNIST example, I tried to "recreate the network graph" under the DeepExplain context.
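
For reference, this is roughly the pattern I followed from the MNIST example (a sketch, not my actual code; model, X, xs, and ys stand in for my own graph-building function, input tensor, input batch, and target selector):

import tensorflow as tf
from deepexplain.tensorflow import DeepExplain

with tf.Session() as sess:
    # ... build the original graph and restore the trained weights into sess here ...
    with DeepExplain(session=sess) as de:
        # Recreating the graph inside the context is what lets DeepExplain
        # apply its gradient overrides to the new ops.
        logits = model(X)
        attributions = de.explain('grad*input', logits * ys, X, xs)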

The problem with the recreated logits = model(X) is that it builds a new graph of everything, with names that are similar to, but different from, those in the original graph. For example, the recreated graph has a tensor Tensor("Prediction_1/predicted:0", shape=(?,), dtype=float32) whose original counterpart is Tensor("Prediction/predicted:0", shape=(?,), dtype=float32). As a result, session.run didn't work; I suspect this is because the weights restored for the original graph are not picked up by the recreated graph. How can I solve this problem?

Thanks.

maosi-chen avatar Jul 08 '18 21:07 maosi-chen

Can you share the code you use to reload the weights?

marcoancona avatar Jul 16 '18 13:07 marcoancona

I have a class A whose __init__ receives the model parameters (e.g. the numbers of LSTM and FC layers, layer sizes, dropout, etc.) and builds the graph. The input tensor is extracted from a feedable tf.data pipeline, and the predicted tensor is the output of the LSTM + FC layers. The run member function of class A takes the running mode (i.e. TRAIN, EVAL, PREDICT, or ATTRIBUTION) and runs the graph defined in __init__ accordingly. For modes other than TRAIN and EVAL, the weights are restored by restore_checkpoint (a rough skeleton of the whole class follows the snippet below):

def restore_checkpoint(self, sess):
    # Restore the latest checkpoint into the given session.
    # Assumes self.saver (tf.train.Saver), self.FP_checkpoints (checkpoint
    # directory), and self.global_step were set up in __init__.
    try:
        ckpt = tf.train.get_checkpoint_state(self.FP_checkpoints)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(sess, ckpt.model_checkpoint_path)
            last_global_step = sess.run(self.global_step)
            return last_global_step
        else:
            raise Exception('Message: Checkpoint was not restored correctly.')
    except Exception as err:
        print(err.args)
        return -1
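
For completeness, a rough and heavily simplified skeleton of the class structure described above (names such as params['checkpoint_dir'] are illustrative, not my actual code):

import tensorflow as tf

class A(object):
    def __init__(self, params):
        # Build the input pipeline and the Bi-LSTM + FC graph here.
        self.FP_checkpoints = params['checkpoint_dir']
        self.global_step = tf.train.get_or_create_global_step()
        # ... layers producing self.predicted, losses, train op ...
        self.saver = tf.train.Saver()

    def run(self, mode):
        # mode is one of TRAIN, EVAL, PREDICT, ATTRIBUTION.
        with tf.Session() as sess:
            if mode not in ('TRAIN', 'EVAL'):
                self.restore_checkpoint(sess)
            # ... run the ops that correspond to the requested mode ...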

Currently, the workaround I found is to apply graph.gradient_override_map(...) in the original graph, around the part that goes from the extracted input tensor to the predicted tensor, before restoring the trained weights. In addition, I have to prepare the attribution tensors in my own graph instead of computing them in the get_symbolic_attribution function for methods whose attribution is [g * x] (this is not necessary when g alone is the attribution, because g is effectively part of the original graph already). I know this workaround is messy and moves substantial parts of your code into my original graph; could you help improve this? Thanks.
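
A rough illustration of that workaround (a sketch of the idea, not the library's code; the registered gradient name below is a placeholder for whatever DeepExplain actually registers):

import tensorflow as tf

g = tf.get_default_graph()
# Override the gradients of the nonlinearities while building the sub-graph
# from the extracted input tensor to the predicted tensor, so the modified
# backward pass applies to the original (to-be-restored) graph.
with g.gradient_override_map({'Relu': 'DeepExplainGrad',
                              'Sigmoid': 'DeepExplainGrad',
                              'Tanh': 'DeepExplainGrad'}):
    # ... build input -> Bi-LSTM -> FC -> self.predicted here ...
    pass
# Restore the trained weights afterwards, then take gradients of
# self.predicted w.r.t. the input on this same graph for the attributions.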

maosi-chen avatar Jul 16 '18 17:07 maosi-chen

The issue was solved by wrapping the graph-building components with the tf.make_template function. I set the "name" property of the model-class instance to a unique string, so that every time a method of that instance (with that unique "name") is called, the Variables in that method are reused rather than recreated. Note that Tensors are still duplicated with a suffix if the method is called multiple times, but that does not affect restoring the weights/biases from checkpoints (because those are Variables and their names stay unique). A minimal sketch of the pattern follows the reference below.

Reference: https://gist.github.com/danijar/720394a9071a03413be8a60852374aa4
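
For anyone running into the same problem, a minimal sketch of the pattern (function, scope, and placeholder names are illustrative):

import tensorflow as tf

def _build_model(inputs):
    # ... Bi-LSTM + FC layers producing the prediction ...
    return tf.layers.dense(inputs, 1, name='Prediction')

# Wrapping the builder in a template with a fixed, unique name means the first
# call creates the Variables and every later call reuses them.
model = tf.make_template('my_unique_model_name', _build_model)

X = tf.placeholder(tf.float32, [None, 16])  # stand-in for the real pipeline input

logits_train = model(X)    # creates Variables under 'my_unique_model_name/...'
logits_attrib = model(X)   # reuses the same Variables; only Tensor names get a '_1' suffix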

maosi-chen avatar Nov 17 '18 20:11 maosi-chen