DeepExplain
Problem recreating the graph with trained weights (TensorFlow)
I have a trained model (1-layer Bi-LSTM followed by a 2-layer FC). Following the MNIST example, I tried to "recreate the network graph" under the DeepExplain context. The problem with the recreated `logits = model(X)` is that it creates a new graph of everything, with names similar to but different from those in the original graph. For example, the recreated graph has a tensor `Tensor("Prediction_1/predicted:0", shape=(?,), dtype=float32)` whose original counterpart is `Tensor("Prediction/predicted:0", shape=(?,), dtype=float32)`. As a result, `session.run` didn't work, and I guess it is because the weights restored into the original graph are not recognized in the recreated graph. How can I solve this problem?
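A toy sketch (not my actual model) that reproduces the renaming: calling the model-building function a second time gets a uniquified name scope, so every recreated tensor picks up a `_1` suffix.

```python
import tensorflow as tf

def model(x):
    # toy stand-in for the Bi-LSTM + FC network
    with tf.name_scope("Prediction"):
        return tf.identity(x, name="predicted")

X = tf.placeholder(tf.float32, [None])
logits = model(X)        # Tensor("Prediction/predicted:0", shape=(?,), dtype=float32)
logits_again = model(X)  # Tensor("Prediction_1/predicted:0", ...) -- the scope is uniquified
```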
Thanks.
Can you share the code you use to reload the weights?
I have a class `A` whose `__init__` receives the model parameters (e.g. the numbers of LSTM and FC layers, `n`, dropout, etc.) and builds the graph. The input tensor comes from a feedable `tf.data` pipeline, and the predicted tensor is the output of the LSTM + FC. The `run` member function of class `A` takes the running mode (i.e. `TRAIN`, `EVAL`, `PREDICT`, or `ATTRIBUTION`) and runs the graph defined in `__init__` accordingly. For modes other than `TRAIN` and `EVAL`, the weights are restored by `restore_checkpoint`:
```python
def restore_checkpoint(self, sess):
    # restore the latest checkpoint status
    try:
        ckpt = tf.train.get_checkpoint_state(self.FP_checkpoints)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(sess, ckpt.model_checkpoint_path)
            last_global_step = sess.run(self.global_step)
            return last_global_step
        else:
            raise Exception('Message: Checkpoint was not restored correctly.')
    except Exception as err:
        print(err.args)
        return -1
```
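As a self-contained illustration of why the restore breaks when the names change (toy code with made-up paths and names, not my actual model): checkpoints are keyed by Variable names, so the `Saver` only finds Variables whose names match what was saved.

```python
import os
import tensorflow as tf

ckpt_dir = "/tmp/demo_ckpt"          # made-up path for this sketch
os.makedirs(ckpt_dir, exist_ok=True)

# --- save a tiny "model" ---
tf.reset_default_graph()
with tf.variable_scope("Prediction"):
    w = tf.get_variable("w", shape=[3], initializer=tf.ones_initializer())
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))

# --- restoring works only if the Variable names match the checkpoint keys ---
tf.reset_default_graph()
with tf.variable_scope("Prediction"):    # same scope -> name "Prediction/w" matches
    w = tf.get_variable("w", shape=[3])
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(sess.run(w))                   # [1. 1. 1.]
```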
Currently, the workaround I found is to apply `graph.gradient_override_map(...)` in the original graph, over the part from the extracted input tensor to the predicted tensor, before restoring the trained weights. Besides, I have to prepare the attribution tensors in my own graph instead of computing them in the `get_symbolic_attribution` function for methods whose attribution is `[g * x]` (this is not necessary when `g` alone is the attribution, because `g` is already part of the original graph anyway). I know my workaround is messy and moves substantial parts of your code into the original graph; could you help improve this? Thanks.
The issue was solved by wrapping the graph-building functions with `tf.make_template`. I set the `name` property of the model-class instance to a unique string, so that every time a method of that instance (with that unique `name`) is called, the Variables in that method are reused instead of being created again. Note that Tensors are still duplicated with a suffix if the method is called multiple times, but that doesn't affect the recovery of weights/biases from the checkpoint, because those are Variables and their names stay unique.
Reference: https://gist.github.com/danijar/720394a9071a03413be8a60852374aa4
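A minimal sketch of the pattern (the template and variable names here are just illustrative, not my actual code):

```python
import tensorflow as tf

def _build(x):
    # toy stand-in for the real model-building method
    w = tf.get_variable("w", shape=[4, 2], initializer=tf.ones_initializer())
    with tf.name_scope("Prediction"):
        return tf.matmul(x, w, name="predicted")

model = tf.make_template("my_model", _build)   # unique template name

x = tf.placeholder(tf.float32, [None, 4])
logits = model(x)        # first call creates Variable "my_model/w"
logits_again = model(x)  # second call reuses "my_model/w"; only tensor names get a suffix

print([v.name for v in tf.trainable_variables()])  # ['my_model/w:0'] -- shared across calls
```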