CNN-Visualization
CNN-Visualization copied to clipboard
Visualize using ckpt files
Hi @conan7882, I would like to know how we can visualize a custom-trained VGG19 model saved in checkpoint (.ckpt) format. Thanks!
You can (1) generate a frozen .pb file from the checkpoint files:
def generate_pb_file(model_dir, output_node_name, option='latest'):
    '''Freeze a trained checkpoint into a standalone .pb graph file.

    The single ``.meta`` file in ``model_dir`` is imported to rebuild the
    graph, the selected checkpoint's weights are restored into it, every
    variable is converted to a constant, and the resulting GraphDef is
    serialized into ``model_dir``.

    parameters:
        model_dir: the output model directory; must contain exactly one
            ``.meta`` file plus the checkpoint data files
        output_node_name: name of the output node in the graph, or a list
            of names when several outputs must be kept
        option: 'latest': generate pb file from the latest checkpoint
                          (written to ``output_latest.pb``)
                'min'   : generate pb file from the checkpoint saved under
                          the prefix ``min-validation_error`` (written to
                          ``output_min.pb``)

    returns:
        Path of the written .pb file.

    raises:
        ValueError: if ``option`` is not 'latest' or 'min', if no checkpoint
            is found for 'latest', or if ``model_dir`` does not contain
            exactly one .meta file.
    '''
    # Resolve which checkpoint to restore and the name of the output file
    # before doing any expensive graph work.
    if option == 'latest':  # restore from latest checkpoint
        ckpt_path = tf.train.latest_checkpoint(model_dir)
        if ckpt_path is None:
            # latest_checkpoint returns None rather than raising; fail with a
            # clear message instead of an opaque error inside saver.restore.
            raise ValueError('no checkpoint found in {}'.format(model_dir))
        output_name = 'output_latest.pb'
    elif option == 'min':  # restore from min validation error
        ckpt_path = os.path.join(model_dir, 'min-validation_error')
        output_name = 'output_min.pb'
    else:
        # Raise instead of sys.exit so callers can handle the error.
        raise ValueError(
            "Do not have the specified checkpoint file: option must be "
            "'latest' or 'min', got {!r}".format(option))

    # convert_variables_to_constants expects a list of node names; a bare
    # string would otherwise be iterated character by character.
    if isinstance(output_node_name, str):
        output_node_names = [output_node_name]
    else:
        output_node_names = list(output_node_name)

    # Exactly one meta file is expected in each saved directory.
    meta_files = [s for s in os.listdir(model_dir) if s.endswith('.meta')]
    if len(meta_files) != 1:
        raise ValueError(
            'expected exactly one .meta file in {}, found {}'.format(
                model_dir, len(meta_files)))
    meta_path = os.path.join(model_dir, meta_files[0])

    tf.reset_default_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        # Rebuild the graph. clear_devices drops the saved device placements,
        # so the graph can be restored regardless of which device trained it.
        saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
        # Load the weights from the selected checkpoint.
        saver.restore(sess, ckpt_path)
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names)

    # Write the frozen graph next to the checkpoint files.
    graph_pb_path = os.path.join(model_dir, output_name)
    with open(graph_pb_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
    print('Save model pb file to path, ', graph_pb_path)
    return graph_pb_path
(2) Load the frozen graph from the .pb file:
# Read the serialized frozen graph written in step (1). graph_pb_path must be
# set to that .pb file's path here; inside generate_pb_file it is only a local.
with tf.gfile.GFile(graph_pb_path, "rb") as pb_file:
    serialized_graph = pb_file.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized_graph)
# Import into a fresh Graph; name="" keeps the node names unprefixed so they
# match the names used when the model was built.
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
(3) Feed the pre-processed image to the input tensor and fetch the output of your target layer:
# NOTE(review): `inputs` and `channel` are assumed to be tensors of `graph`,
# presumably obtained via graph.get_tensor_by_name('<node name>:0'); confirm
# the node names for your model. `image_value` is the pre-processed image.
feed = {inputs: image_value}
with tf.Session(graph=graph) as sess:
    channel_value = sess.run(channel, feed_dict=feed)