im2latex icon indicating copy to clipboard operation
im2latex copied to clipboard

how to make predict?

Open code-learner opened this issue 6 years ago • 1 comments

i have saved the model and want to use the model to predict. but i always got error. my predict code as follows:

vocab_to_idx = dict([(vocab[i], i) for i in range(len(vocab))])
formulas = open('data/formulas.norm.lst').read().split('\n')
test = random.sample(open('data/test.lst').read().split('\n')[:-1], 100)
def formula_to_indices(formula):
    formula = formula.split(' ')
    res = [0]
    for token in formula:
        if token in vocab_to_idx:
            res.append(vocab_to_idx[token] + 4)
        else:
            res.append(2)
    res.append(1)
    return res
def get_session():
    cfg = tf.ConfigProto(log_device_placement=True)
    cfg.gpu_options.allow_growth = True
    return tf.Session(config=cfg)
formulas = map(formula_to_indices, formulas)
def import_images(datum):
    datum = datum.split(' ')
    img = np.array(Image.open('image_data/' + datum[0]).convert('L'))
    img1 = np.expand_dims(np.expand_dims(img, 0), 3)
    return img1, np.expand_dims(formulas[int(datum[1])], 0)
test = map(import_images, test)
checkpoint_file = tf.train.latest_checkpoint("model/")
with get_session() as sess:
    # Load the saved meta graph and restore variables
    saver = tf.train.import_meta_graph("model/model.meta")
    saver.restore(sess, checkpoint_file)
    graph = tf.get_default_graph()
    # Get the placeholders from the graph by name
    input_x = graph.get_operation_by_name("input").outputs[0]
    num_rows = graph.get_operation_by_name("nr").outputs[0]
    num_columns = graph.get_operation_by_name("nc").outputs[0]
    num_words = graph.get_operation_by_name("nw").outputs[0]
    # Tensors we want to evaluate
    predictions = graph.get_operation_by_name("predict")
    # Generate batches for one epoch
    for i in range(len(test)):
        images, labels = test[i]
        predict = sess.run(predictions, {input_x: images,
                                                   num_rows: images.shape[1],
                                                   num_columns: images.shape[2],
                                                   num_words: labels.shape[1]})
        print(predict)

i got the error as follows:

       Caused by op u'map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/Assert/Assert', defined at:
  File "predict.py", line 47, in <module>
    saver = tf.train.import_meta_graph("model/model.meta")
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1939, in import_meta_graph
    **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/meta_graph.py", line 744, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py", line 442, in import_graph_def
    _ProcessNewOps(graph)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py", line 234, in _ProcessNewOps
    for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3289, in _add_new_tf_operations
    for c_op in c_api_util.new_tf_operations(self)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3180, in _create_op_from_tf_operation
    ret = Operation(c_op, self)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): assertion failed: [Expected shape for Tensor map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/sequence_length:0 is ] [1] [ but saw shape: ] [512]
	 [[Node: map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/Assert/Assert = Assert[T=[DT_STRING, DT_INT32, DT_STRING, DT_INT32], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/All/_303, map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/Assert/Assert/data_0, map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/stack/_305, map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/Assert/Assert/data_2, map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/Shape_1/_277, ^_cloopmap/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/TensorArrayStack/range/start/_221)]]
	 [[Node: map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/Assert/Assert/_312 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_440_map/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/Assert/Assert", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](^_cloopmap/while/fun/encoder_rnn/bidirectional_rnn/fw/fw/TensorArrayUnstack/strided_slice/_226)]]

my batch_size is 512, and when construct the net, batch_size is a parameter, did the error caused by this?

code-learner avatar Apr 01 '19 08:04 code-learner

it should be same as when you trained the model.Hope the code below can help you img = Image.open('data/images_processed/7944775fc9.png').convert('L') img = np.array(img) img = np.expand_dims(img,3) img = np.array([img] * 20) model_files = tf.train.latest_checkpoint('saved_models/model-17-04-2019--17-32/') with tf.Session() as sess: saver.restore(sess,model_files) ans = output.eval(feed_dict={inp:img,num_rows:img.shape[1],num_columns:img.shape[2],num_words:32}) one = tf.argmax(ans,2) result = one.eval() print(result[0])

tomsay avatar Apr 24 '19 09:04 tomsay