CNN-Visualization icon indicating copy to clipboard operation
CNN-Visualization copied to clipboard

Visualize using ckpt files

Open harsh-agar opened this issue 6 years ago • 1 comment

Hi @conan7882, I wanted to know how we can visualize a custom-trained VGG19 model saved in ckpt (checkpoint) format. Thanks!

harsh-agar avatar Oct 04 '18 09:10 harsh-agar

You can (1) generate a frozen .pb file from the checkpoint files:

def generate_pb_file(model_dir, output_node_name, option='latest'):
    '''Freeze a TensorFlow 1.x checkpoint into a standalone .pb graph file.

    The graph structure is read from the single .meta file in `model_dir`.
    The weights are restored from the checkpoint selected by `option`.
    Every variable the output node(s) depend on is folded into a constant,
    and the resulting GraphDef is serialized into `model_dir`.

    parameters:
    model_dir: directory containing exactly one .meta file plus the
               checkpoint files; the .pb file is written here as well.
    output_node_name: name of the output node to keep (e.g. 'logits'), or a
               list/tuple of such names. These are node names, not tensor
               names, so no ':0' suffix.
    option: 'latest': restore from tf.train.latest_checkpoint(model_dir)
                      and write 'output_latest.pb'
            'min'   : restore from the checkpoint prefix
                      'min-validation_error' (minimum validation error)
                      and write 'output_min.pb'

    returns:
    The path of the written .pb file.

    raises:
    ValueError: `model_dir` does not contain exactly one .meta file, or
                option == 'latest' and no checkpoint is found.
    SystemExit: `option` is neither 'latest' nor 'min'.
    '''
    output_names_by_option = {
        'latest': 'output_latest.pb',
        'min': 'output_min.pb',
    }
    # Validate the option before doing any (potentially slow) graph work.
    if option not in output_names_by_option:
        import sys
        sys.exit('Do not have the specified checkpoint file')
    output_name = output_names_by_option[option]

    # convert_variables_to_constants expects a list of node names; a bare
    # string would otherwise be iterated character by character.
    if isinstance(output_node_name, str):
        output_node_names = [output_node_name]
    else:
        output_node_names = list(output_node_name)

    # Exactly one meta file is expected in each saved directory.
    meta_files = [s for s in os.listdir(model_dir) if s.endswith('.meta')]
    if len(meta_files) != 1:
        raise ValueError(
            'expected exactly one .meta file in %r, found %d: %s'
            % (model_dir, len(meta_files), meta_files))
    meta_path = os.path.join(model_dir, meta_files[0])

    if option == 'latest':
        checkpoint_path = tf.train.latest_checkpoint(model_dir)
        if checkpoint_path is None:
            raise ValueError('no checkpoint found in %r' % (model_dir,))
    else:
        checkpoint_path = os.path.join(model_dir, 'min-validation_error')

    tf.reset_default_graph()
    config = tf.ConfigProto(allow_soft_placement=True)

    with tf.Session(config=config) as sess:
        # clear_devices=True drops the device placements recorded at training
        # time, so the graph can be restored regardless of available GPUs.
        saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
        saver.restore(sess, checkpoint_path)

        # Replace the variables needed by the outputs with constants holding
        # their restored values.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names)

    # Save the frozen graph next to the checkpoint files.
    graph_pb_path = os.path.join(model_dir, output_name)
    with open(graph_pb_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

    print('Save model pb file to path, ', graph_pb_path)
    return graph_pb_path

(2) Load the frozen graph:

    # Parse the serialized, frozen GraphDef produced in step (1).
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_pb_path, "rb") as pb_file:
        graph_def.ParseFromString(pb_file.read())

    # Import it into a fresh Graph. name="" disables the default "import/"
    # prefix, so nodes keep the names they had at training time.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")

(3) Run the target layer, feeding the pre-processed image into the input tensor:

# NOTE(review): `channel`, `inputs` and `image_value` are not defined in this
# snippet. Presumably `inputs` and `channel` are tensors looked up in `graph`, e.g.
#   inputs = graph.get_tensor_by_name('<input node name>:0')
#   channel = graph.get_tensor_by_name('<target layer node name>:0')
# and `image_value` is the pre-processed image array; its shape/dtype must
# match the input tensor (TODO confirm for your model).
with tf.Session(graph=graph) as sess:
     # Evaluate the target layer's tensor with the image fed into the input tensor.
     channel_value = sess.run(channel, feed_dict = {inputs: image_value})

Sirius083 avatar Apr 01 '19 01:04 Sirius083