PocketFlow

ChannelPrunedLearner error: list index out of range

Open • as754770178 opened this issue 6 years ago • 6 comments

error info:

Traceback (most recent call last):
  File "/home//git/pocket_flow/compress_model_run.py", line 100, in compress_model
    learner = create_learner(sm_writer, model_helper)
  File "/home/git/pocket_flow/learners/learner_utils.py", line 49, in create_learner
    learner = ChannelPrunedLearner(sm_writer, model_helper)
  File "/home/git/pocket_flow/learners/channel_pruning/learner.py", line 120, in __init__
    self.__build(is_train=False)
  File "/home/git/pocket_flow/learners/channel_pruning/learner.py", line 220, in __build
    self.__build_pruned_evaluate_model()
  File "/home/git/pocket_flow/learners/channel_pruning/learner.py", line 294, in __build_pruned_evaluate_model
    eval_logits = tf.get_collection('logits')[0]
IndexError: list index out of range

This happens because my custom model does not have the "logits" collection. From reading channel_pruning/learner.py: in train mode, the __build function adds logits to the collection. In eval mode, however, __build_pruned_evaluate_model restores the model from save_path, so the step that adds logits to the collection does not take effect.

as754770178 · Dec 03 '18 12:12

To use your own model, you need to add a "logits" collection to the pre-trained model. The __build_pruned_evaluate_model() function restores the model from FLAGS.save_path, which is where FullPrecLearner saves its training graph.
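For reference, a minimal sketch of what adding the collection to the pre-trained model can look like, assuming TF 1.x graph mode and a toy model in place of a real network (save_model_with_logits and load_logits are hypothetical helpers, not PocketFlow APIs). Collections are stored in the .meta file, so a collection added before saving is visible again after tf.train.import_meta_graph, while one added to some other graph afterwards is not.

```python
import os
import tempfile

try:
    import tensorflow.compat.v1 as tf  # TF 1.13+ and TF 2.x
except ImportError:
    import tensorflow as tf  # older TF 1.x, e.g. the tf-1.8 env in this thread


def save_model_with_logits(save_path):
    """Build a toy model, register its logits in the 'logits' collection, save it."""
    graph = tf.Graph()
    with graph.as_default():
        images = tf.placeholder(tf.float32, [None, 4], name='images')
        weights = tf.get_variable('weights', [4, 3])
        logits = tf.matmul(images, weights, name='logits')
        # Collections are serialized into the .meta file written by Saver.save,
        # so they must be added *before* the pre-trained model is saved.
        tf.add_to_collection('logits', logits)
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, save_path)


def load_logits(save_path):
    """Restore the saved graph and return the tensors in its 'logits' collection."""
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(save_path + '.meta')
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, save_path)
        return tf.get_collection('logits')


save_path = os.path.join(tempfile.mkdtemp(), 'model.ckpt')
save_model_with_logits(save_path)
print([t.name for t in load_logits(save_path)])  # ['logits:0']
```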

jiaxiang-wu · Dec 04 '18 01:12

What is the purpose of the __build function adding logits to the collection?

as754770178 · Dec 04 '18 03:12

It is used to locate the input and output tensors when exporting the compressed model to *.pb and *.tflite files.
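A sketch of how collections can serve as named handles for the export endpoints, assuming TF 1.x graph mode. This is not PocketFlow's export code: export_frozen_graph and the inputs_for_export collection are hypothetical, and only the "logits" name comes from this thread.

```python
try:
    import tensorflow.compat.v1 as tf  # TF 1.13+ and TF 2.x
except ImportError:
    import tensorflow as tf  # older TF 1.x


def export_frozen_graph(graph, sess, input_collection, output_collection):
    """Find the I/O tensors through collections and freeze variables to constants."""
    with graph.as_default():
        input_tensor = tf.get_collection(input_collection)[0]
        output_tensor = tf.get_collection(output_collection)[0]
    # Exporters address graph nodes by op name ('logits'), not tensor name ('logits:0').
    input_name = input_tensor.op.name
    output_name = output_tensor.op.name
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), [output_name])
    return frozen_graph_def, input_name, output_name


graph = tf.Graph()
with graph.as_default():
    images = tf.placeholder(tf.float32, [None, 4], name='input_images')
    weights = tf.get_variable('weights', [4, 3])
    logits = tf.matmul(images, weights, name='logits')
    tf.add_to_collection('inputs_for_export', images)  # hypothetical collection name
    tf.add_to_collection('logits', logits)
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        frozen, in_name, out_name = export_frozen_graph(
            graph, sess, 'inputs_for_export', 'logits')

print(in_name, out_name)  # input_images logits
```

The serialized frozen graph (frozen.SerializeToString()) is what would be written to a *.pb file, and frozen-graph converters such as TFLiteConverter.from_frozen_graph take the input and output node names as lists.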

jiaxiang-wu · Dec 04 '18 03:12

error info:

Traceback (most recent call last):
  File "/home/zgz/git/autovision/autovision/pocket_flow/compress_model_run.py", line 104, in compress_model
    learner = create_learner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/learner_utils.py", line 49, in create_learner
    learner = ChannelPrunedLearner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 120, in __init__
    self.__build(is_train=False)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 220, in __build
    self.__build_pruned_evaluate_model()
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 299, in __build_pruned_evaluate_model
    mem_images = tf.get_collection('mem_images')[0]
IndexError: list index out of range

My custom model reads its input from a Dataset rather than from placeholders, so it does not have the mem_images and mem_labels collections.
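If the model reads its input from a Dataset, one possible workaround is to register the iterator outputs under the names the learner looks up. A minimal sketch, assuming TF 1.x graph mode and assuming the learner only needs the tensors that feed the model under those collection names (build_input_pipeline is hypothetical, and this has not been checked against the rest of ChannelPrunedLearner). Note that batch() without drop_remainder leaves the static batch dimension unknown (None).

```python
import numpy as np

try:
    import tensorflow.compat.v1 as tf  # TF 1.13+ and TF 2.x
except ImportError:
    import tensorflow as tf  # older TF 1.x


def build_input_pipeline(images, labels, batch_size):
    """Build a tf.data input and register its batch tensors by collection name."""
    dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_images, batch_labels = iterator.get_next()
    # Assumption: the learner only needs the tensors that feed the model under
    # these names, so the iterator outputs stand in for the missing placeholders.
    tf.add_to_collection('mem_images', batch_images)
    tf.add_to_collection('mem_labels', batch_labels)
    return batch_images, batch_labels


graph = tf.Graph()
with graph.as_default():
    build_input_pipeline(np.zeros((8, 32, 32, 3), np.float32),
                         np.zeros((8,), np.int64), batch_size=4)
    mem_images = tf.get_collection('mem_images')[0]
    mem_labels = tf.get_collection('mem_labels')[0]

print(mem_images.shape.as_list(), mem_labels.shape.as_list())  # [None, 32, 32, 3] [None]
```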

Why not use the model saved in cp_original_path when executing the __build_pruned_evaluate_model function?

as754770178 · Dec 06 '18 11:12

@psyyz10 Could you take a look at this issue?

jiaxiang-wu · Dec 06 '18 11:12

Now the batch sizes for training and evaluation differ in my custom model: the training batch size is 128 and the evaluation batch size is 1. The mem_images collection contains two tensors, and the first one has the same shape as the training input.

error info:

Traceback (most recent call last):
  File "/home/zgz/git/autovision/autovision/pocket_flow/compress_model_run.py", line 104, in compress_model
    learner = create_learner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/learner_utils.py", line 49, in create_learner
    learner = ChannelPrunedLearner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 120, in __init__
    self.__build(is_train=False)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 220, in __build
    self.__build_pruned_evaluate_model()
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 304, in __build_pruned_evaluate_model
    graph_editor.reroute_ts(eval_images, mem_images)
  File "/home/zgz/anaconda2/envs/tf-1.8-cp3/lib/python3.6/site-packages/tensorflow/contrib/graph_editor/reroute.py", line 256, in reroute_ts
    return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
  File "/home/zgz/anaconda2/envs/tf-1.8-cp3/lib/python3.6/site-packages/tensorflow/contrib/graph_editor/reroute.py", line 183, in _reroute_ts
    _check_ts_compatibility(ts0, ts1)
  File "/home/zgz/anaconda2/envs/tf-1.8-cp3/lib/python3.6/site-packages/tensorflow/contrib/graph_editor/reroute.py", line 66, in _check_ts_compatibility
    shape1))
ValueError: Shapes (1, 224, 224, 3) and (128, 224, 224, 3) are not compatible.
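Judging from the traceback, graph_editor's _check_ts_compatibility compares the static shapes of the two tensors, and two fully defined batch sizes such as 1 and 128 can never be compatible. A minimal sketch, assuming TF 1.x graph mode, of the rule and of one possible workaround: give the training input an unknown static batch dimension, here with tf.placeholder_with_default. The helper relax_batch_dim is hypothetical and has not been tested with ChannelPrunedLearner.

```python
try:
    import tensorflow.compat.v1 as tf  # TF 1.13+ and TF 2.x
except ImportError:
    import tensorflow as tf  # older TF 1.x

# The two shapes from the error message above.
eval_shape = tf.TensorShape([1, 224, 224, 3])
train_shape = tf.TensorShape([128, 224, 224, 3])
# An unknown (None) batch dimension is compatible with any concrete batch size.
flexible_shape = tf.TensorShape([None, 224, 224, 3])

print(eval_shape.is_compatible_with(train_shape))      # False
print(flexible_shape.is_compatible_with(eval_shape))   # True
print(flexible_shape.is_compatible_with(train_shape))  # True


def relax_batch_dim(batch_tensor):
    """Pass batch_tensor through unchanged, but with an unknown static batch size."""
    static_shape = [None] + batch_tensor.shape.as_list()[1:]
    return tf.placeholder_with_default(batch_tensor, shape=static_shape)
```

Applied to the tensor that feeds the model during training, this keeps the runtime batch of 128 while declaring the static shape as (None, 224, 224, 3).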

as754770178 · Dec 07 '18 03:12