Failed to load the frozen graph of GazeML

Open parai opened this issue 5 years ago • 6 comments

Hi, with the code below added to inference_generator in GazeML/src/core/model.py, I can successfully export GazeML to a frozen graph, gaze.pb, with the weights loaded by the saver:

    sess = self._tensorflow_session
    from tensorflow.python.framework import graph_util
    constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def,
            ['hourglass/hg_2/after/hmap/conv/BiasAdd', # heatmaps
             'upscale/mul', # landmarks
             'radius/out/fc/BiasAdd', # radius
             'Webcam/fifo_queue_DequeueMany', # frame_index, eye, eye_index
            ])
    with tf.gfile.FastGFile('./gaze.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

Then I try to load it with the code below:

    with tf.gfile.FastGFile('gaze.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

But loading fails with the errors below:

2019-05-16 22:49:42.059659: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
Traceback (most recent call last):
  File "C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\importer.py", line 489, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node radius/fc3/BatchNorm/cond_1/AssignMovingAvg_1/Switch was passed float from radius/fc3/BatchNorm/moving_variance:0 incompatible with expected float_ref.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gaze.py", line 47, in <module>
    eye,eye_index,frame_index,landmarks,radius = model()
  File "gaze.py", line 37, in model
    _ = tf.import_graph_def(graph_def, name='')
  File "C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\importer.py", line 493, in import_graph_def
    raise ValueError(str(e))
ValueError: Input 0 of node radius/fc3/BatchNorm/cond_1/AssignMovingAvg_1/Switch was passed float from radius/fc3/BatchNorm/moving_variance:0 incompatible with expected float_ref.

I googled it and found this page, https://stackoverflow.com/questions/34265768/what-is-a-tensorflow-float-ref. According to its comments, the cause may be the use of a tf.Variable instead of a tf.placeholder.
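One way to see where the mismatch comes from is to list the nodes in the frozen graph that still expect a variable reference (float_ref) as input. Freezing replaces each variable with a constant, which outputs a plain float, so an op that was built to update a variable in place no longer fits. The sketch below is only an illustration and not part of GazeML: `find_ref_ops` is a hypothetical helper, the set of op types is an assumption (RefSwitch and AssignSub come up in this thread, Assign and AssignAdd are added by analogy, and the list is not exhaustive), and the demo nodes are stand-ins. It relies only on each node having `name` and `op` fields, as nodes in a tf.GraphDef do.

```python
from types import SimpleNamespace

# Op types that take a variable reference as input. RefSwitch and AssignSub
# are the ones that come up in this thread; Assign and AssignAdd are added
# as an assumption. The set is not exhaustive.
REF_OPS = {"RefSwitch", "AssignSub", "Assign", "AssignAdd"}


def find_ref_ops(graph_def):
    """Return (name, op) for every node whose op type is in REF_OPS."""
    return [(node.name, node.op) for node in graph_def.node if node.op in REF_OPS]


# Stand-in for a parsed tf.GraphDef so the sketch runs without TensorFlow.
# The op types assigned here are assumptions for illustration only.
demo = SimpleNamespace(node=[
    SimpleNamespace(name="radius/fc3/BatchNorm/moving_variance", op="Const"),
    SimpleNamespace(name="radius/fc3/BatchNorm/cond_1/AssignMovingAvg_1/Switch",
                    op="RefSwitch"),
    SimpleNamespace(name="radius/out/fc/BiasAdd", op="BiasAdd"),
])

for name, op in find_ref_ops(demo):
    print(op, name)
```

With a real graph, the same call would take the graph_def filled in by ParseFromString() in the loading code above.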

But I am not familiar with the GazeML code. Could you help me fix this issue? Thanks a lot.

parai • May 16 '19 15:05

@parai I am getting the same issue, did you find a solution?

funkfuzz • Mar 11 '20 15:03

Hi, I found a solution. The middle part of the code block below, starting at the "Fix nodes" comment, solves the problem.

import glob
import os
import tensorflow as tf

def model():
    print("Trying to import Gaze Model.")
    model_dir = os.path.dirname(os.path.realpath(__file__)) + '/gaze'
    pb = glob.glob('%s/*.pb' % model_dir)[0]

    # Read graph definition
    with tf.gfile.FastGFile(pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

        # Fix nodes of the frozen model
        for node in graph_def.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index in range(len(node.input)):
                    if 'moving_' in node.input[index]:
                        node.input[index] = node.input[index] + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr: del node.attr['use_locking']

        # Export the fixed frozen model as a .pb file
        with tf.gfile.FastGFile('./gaze_better.pb', mode='wb') as model_fixed:
            model_fixed.write(graph_def.SerializeToString())

        # Import graph into session
        tf.import_graph_def(graph_def, name='')
  
    # Saving the important nodes of the Gaze model
    # (`sess` is assumed to be a tf.Session created elsewhere in the script)
    # Input nodes of the model
    frame_index = sess.graph.get_tensor_by_name('Webcam/fifo_queue_DequeueMany:0')
    eye = sess.graph.get_tensor_by_name('Webcam/fifo_queue_DequeueMany:1')
    eye_index = sess.graph.get_tensor_by_name('Webcam/fifo_queue_DequeueMany:2')
    
    # Output nodes of the model
    heatmaps = sess.graph.get_tensor_by_name('hourglass/hg_2/after/hmap/conv/BiasAdd:0')
    landmarks = sess.graph.get_tensor_by_name('upscale/mul:0')
    radius = sess.graph.get_tensor_by_name('radius/out/fc/BiasAdd:0')
    sess.run(tf.global_variables_initializer())
    return eye,heatmaps,landmarks,radius
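
The node rewrite in the middle of that function can also be pulled out into a helper that reports what it changed, which makes it easier to check that the patch actually touched something. This is a sketch under assumptions rather than the code used above: `patch_ref_nodes` is a hypothetical name, and it is written against any object with the same `node`, `op`, `input`, and `attr` fields as a tf.GraphDef, so the demo uses plain Python stand-ins with made-up node names instead of the real gaze.pb.

```python
from types import SimpleNamespace


def patch_ref_nodes(graph_def):
    """Apply the RefSwitch/AssignSub rewrite in place and count the changes."""
    changed = {"RefSwitch": 0, "AssignSub": 0}
    for node in graph_def.node:
        if node.op == "RefSwitch":
            node.op = "Switch"
            # Point inputs at the '/read' node of the moving averages,
            # as in the loop above.
            for i in range(len(node.input)):
                if "moving_" in node.input[i]:
                    node.input[i] = node.input[i] + "/read"
            changed["RefSwitch"] += 1
        elif node.op == "AssignSub":
            node.op = "Sub"
            if "use_locking" in node.attr:
                del node.attr["use_locking"]
            changed["AssignSub"] += 1
    return changed


# Stand-in nodes (names are illustrative, not read from the real gaze.pb).
demo = SimpleNamespace(node=[
    SimpleNamespace(name="bn/cond/Switch", op="RefSwitch",
                    input=["bn/moving_variance", "bn/cond/pred_id"], attr={}),
    SimpleNamespace(name="bn/AssignMovingAvg", op="AssignSub",
                    input=["bn/moving_mean", "bn/sub"],
                    attr={"use_locking": False}),
])

print(patch_ref_nodes(demo))
```

If both counts come back as zero on a real graph, the file was probably already patched or the failing nodes use different op types.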

Zeleni9 • Apr 21 '20 09:04

@Zeleni9 @parai has either of you managed to export a .pb model that can be successfully imported in OpenCV with readNetFromTensorFlow()?

funkfuzz • Apr 29 '20 10:04

I don't know about OpenCV, but I used the code above to get the frozen model and load it in TensorFlow for inference.

with tf.gfile.FastGFile('./gaze_better.pb', mode='wb') as model_fixed:
    model_fixed.write(graph_def.SerializeToString())

Zeleni9 • Apr 29 '20 13:04

@Zeleni9, how did you manage to export a frozen model with a node 'Webcam/fifo_queue_DequeueMany'? When I try to do it, by first loading the metagraph from the checkpoints and then using the code from @parai, TensorFlow tells me that there is no node named 'Webcam/fifo_queue_DequeueMany'. However, if I use 'UnityEyes/random_shuffle_queue_DequeueMany', the export works fine. Is it just a typo, or am I missing something?

Here is my code:

# This function exports a saved model

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util

# trained_checkpoint_prefix = 'checkpoints/dev'
trained_checkpoint_prefix = 'model-4672654'
export_dir = os.path.join('models', 'GazeML_010520') # IMPORTANT: each model folder must be named '0', '1', ... Otherwise it will fail!

# handle uninitialized variables
def initialize_uninitialized(sess):
    global_vars          = tf.compat.v1.global_variables()
    is_not_initialized   = sess.run([tf.compat.v1.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]

    print ([str(i.name) for i in not_initialized_vars]) # only for testing
    if len(not_initialized_vars):
        sess.run(tf.compat.v1.variables_initializer(not_initialized_vars))

loaded_graph = tf.Graph()
with tf.compat.v1.Session(graph=loaded_graph) as sess:
    # Restore from checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)
    initialize_uninitialized(sess)
    
    constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def,
            ['hourglass/hg_2/after/hmap/conv/BiasAdd', # heatmaps
             'upscale/Mean', # landmarks
             'radius/out/fc/BiasAdd', # radius
             'UnityEyes/random_shuffle_queue_DequeueMany', # frame_index, eye, eye_index
            ])
    with tf.gfile.FastGFile('./saved_GazeML.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

Any help would be much appreciated! :)

funkfuzz • May 01 '20 14:05

Well, I added the commented part of the code from this answer to the issue, https://github.com/swook/GazeML/issues/23#issuecomment-494861909, to the file https://github.com/swook/GazeML/blob/master/src/core/model.py, at the end of the method def inference_generator(self). It is the code on lines 9-19 at this link: https://github.com/parai/dms/blob/master/models/gaze.py.

The idea is that the model is exported when inference_generator is called, but it is enough to save it once and then stop the inference_generator.

I tried both the webcam and the video versions, so the names were 'Webcam/fifo_queue_DequeueMany' or 'Video/fifo_queue_DequeueMany'; the graph is exported with whichever name you put in the code. Hope this helps.
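
Since the queue node appears to be named after the data source, it can help to list the dequeue nodes a graph actually contains before choosing the output names for convert_variables_to_constants. The following is a minimal sketch, assuming only that each node has a `name` field as in a tf.GraphDef; `find_nodes` is a hypothetical helper, and the demo nodes are stand-ins built from names mentioned in this thread.

```python
from types import SimpleNamespace


def find_nodes(graph_def, substring):
    """Return the names of all nodes whose name contains `substring`."""
    return [node.name for node in graph_def.node if substring in node.name]


# Stand-in for sess.graph_def so the sketch runs without TensorFlow.
demo = SimpleNamespace(node=[
    SimpleNamespace(name="UnityEyes/random_shuffle_queue_DequeueMany"),
    SimpleNamespace(name="hourglass/hg_2/after/hmap/conv/BiasAdd"),
    SimpleNamespace(name="radius/out/fc/BiasAdd"),
])

print(find_nodes(demo, "DequeueMany"))
```

With a real session this would be called as find_nodes(sess.graph_def, 'DequeueMany'), and whatever name it returns is the one to pass to the export code.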

Zeleni9 • May 01 '20 16:05