PocketFlow
ChannelPrunedLearner error: list index out of range
error info:
Traceback (most recent call last):
  File "/home//git/pocket_flow/compress_model_run.py", line 100, in compress_model
    learner = create_learner(sm_writer, model_helper)
  File "/home/git/pocket_flow/learners/learner_utils.py", line 49, in create_learner
    learner = ChannelPrunedLearner(sm_writer, model_helper)
  File "/home/git/pocket_flow/learners/channel_pruning/learner.py", line 120, in __init__
    self.__build(is_train=False)
  File "/home/git/pocket_flow/learners/channel_pruning/learner.py", line 220, in __build
    self.__build_pruned_evaluate_model()
  File "/home/git/pocket_flow/learners/channel_pruning/learner.py", line 294, in __build_pruned_evaluate_model
    eval_logits = tf.get_collection('logits')[0]
IndexError: list index out of range
This happens because my custom model doesn't have a logits collection. I read the code in channel_pruning/learner.py: in train mode, the __build function adds logits to the collection, but in eval mode, __build_pruned_evaluate_model restores the model from save_path, so adding logits to the collection has no effect.
To use your own model, you need to add a "logits" collection to the pre-trained model. The __build_pruned_evaluate_model() function restores a model from FLAGS.save_path, which is the path where FullPrecLearner's training graph is saved.
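The IndexError comes from indexing an empty list: in TensorFlow 1.x, tf.get_collection returns [] for a collection name that was never populated, so tf.get_collection('logits')[0] fails. The sketch below reproduces this with a small plain-Python stand-in; the GraphCollections class is hypothetical and is neither a TensorFlow nor a PocketFlow API. In a real TF 1.x graph, the corresponding step would presumably be calling tf.add_to_collection('logits', logits) on the output tensor before the pre-trained model is saved.

```python
class GraphCollections:
    """Hypothetical stand-in for a TF 1.x graph's named collections."""

    def __init__(self):
        self._collections = {}

    def add_to_collection(self, name, value):
        # Mimics tf.add_to_collection: append value under the given name.
        self._collections.setdefault(name, []).append(value)

    def get_collection(self, name):
        # Mimics tf.get_collection: an unknown name yields [], not an error.
        return list(self._collections.get(name, []))


# A custom pre-trained graph that never registered its output tensor.
custom_graph = GraphCollections()
try:
    eval_logits = custom_graph.get_collection('logits')[0]
except IndexError as err:
    # Same failure as the traceback above: indexing an empty list.
    print('lookup failed:', err)

# Registering the output before saving makes the lookup succeed.
custom_graph.add_to_collection('logits', 'logits_tensor')
eval_logits = custom_graph.get_collection('logits')[0]
print(eval_logits)
```

The point of the stand-in is only to show that a missing collection is silent until it is indexed; the real fix has to happen in the graph that gets saved to FLAGS.save_path.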
What is the purpose of the __build function adding logits to the collection?
It is used to locate input and output tensors when exporting a compressed model to *.pb and *.tflite models.
error info:
Traceback (most recent call last):
  File "/home/zgz/git/autovision/autovision/pocket_flow/compress_model_run.py", line 104, in compress_model
    learner = create_learner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/learner_utils.py", line 49, in create_learner
    learner = ChannelPrunedLearner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 120, in __init__
    self.__build(is_train=False)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 220, in __build
    self.__build_pruned_evaluate_model()
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 299, in __build_pruned_evaluate_model
    mem_images = tf.get_collection('mem_images')[0]
IndexError: list index out of range
My custom model uses a Dataset rather than placeholders, so it does not have the mem_images and mem_labels collections.
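Because these failures surface one at a time (first logits, then mem_images), a pre-flight check can report every missing collection at once. This is a minimal sketch, assuming the collections have already been gathered into a dict mapping each name to its list (in TF 1.x, for example, with tf.get_collection for each name). The function names are hypothetical, and the list of required names is taken only from the tracebacks and comments in this thread, not from PocketFlow's documentation.

```python
# Collection names looked up or reported missing in this thread; an assumption
# drawn from the tracebacks, not an authoritative PocketFlow list.
REQUIRED_COLLECTIONS = ('logits', 'mem_images', 'mem_labels')


def missing_collections(collections, required=REQUIRED_COLLECTIONS):
    """Return the required names whose collections are absent or empty.

    `collections` maps a collection name to its list of tensors.
    """
    return [name for name in required if not collections.get(name)]


def check_collections(collections, required=REQUIRED_COLLECTIONS):
    """Raise one descriptive error instead of a bare IndexError per name."""
    missing = missing_collections(collections, required)
    if missing:
        raise ValueError('pre-trained graph lacks collections: ' + ', '.join(missing))


# A Dataset-fed model that registered its logits but no mem_* tensors.
found = {'logits': ['logits_tensor'], 'mem_images': [], 'mem_labels': []}
print(missing_collections(found))  # ['mem_images', 'mem_labels']
```

Running such a check before constructing the learner would turn the sequence of IndexErrors into a single message naming everything the saved graph lacks.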
Why not use the model saved in cp_original_path when executing the __build_pruned_evaluate_model function?
@psyyz10 Could you take a look at this issue?
Now, the training and evaluation batch sizes differ in my custom model: the training batch size is 128 and the evaluation batch size is 1. The mem_images collection contains two tensors, and the first tensor's shape matches the training input.
error info:
Traceback (most recent call last):
  File "/home/zgz/git/autovision/autovision/pocket_flow/compress_model_run.py", line 104, in compress_model
    learner = create_learner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/learner_utils.py", line 49, in create_learner
    learner = ChannelPrunedLearner(sm_writer, model_helper)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 120, in __init__
    self.__build(is_train=False)
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 220, in __build
    self.__build_pruned_evaluate_model()
  File "/home/zgz/git/autovision/autovision/pocket_flow/learners/channel_pruning/learner.py", line 304, in __build_pruned_evaluate_model
    graph_editor.reroute_ts(eval_images, mem_images)
  File "/home/zgz/anaconda2/envs/tf-1.8-cp3/lib/python3.6/site-packages/tensorflow/contrib/graph_editor/reroute.py", line 256, in reroute_ts
    return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
  File "/home/zgz/anaconda2/envs/tf-1.8-cp3/lib/python3.6/site-packages/tensorflow/contrib/graph_editor/reroute.py", line 183, in _reroute_ts
    _check_ts_compatibility(ts0, ts1)
  File "/home/zgz/anaconda2/envs/tf-1.8-cp3/lib/python3.6/site-packages/tensorflow/contrib/graph_editor/reroute.py", line 66, in _check_ts_compatibility
    shape1))
ValueError: Shapes (1, 224, 224, 3) and (128, 224, 224, 3) are not compatible.
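The ValueError comes from a static shape check before rerouting: two shapes of the same rank are compatible only if every dimension pair is either equal or unknown. The plain-Python sketch below mirrors that rule for shapes of known rank (the function names are hypothetical; they imitate the behavior of TensorFlow's TensorShape.is_compatible_with rather than calling it). It shows why the batch-1 evaluation input cannot be rerouted onto the batch-128 mem_images tensor, and that an unknown (None) batch dimension would satisfy this particular check.

```python
def dims_compatible(a, b):
    # An unknown dimension (None) is compatible with any size.
    return a is None or b is None or a == b


def shapes_compatible(shape0, shape1):
    """Same rank, and every dimension pair equal or unknown (known ranks only)."""
    return len(shape0) == len(shape1) and all(
        dims_compatible(a, b) for a, b in zip(shape0, shape1))


eval_shape = (1, 224, 224, 3)    # evaluation input, batch size 1
mem_shape = (128, 224, 224, 3)   # first mem_images tensor, training batch size 128

print(shapes_compatible(eval_shape, mem_shape))           # False
print(shapes_compatible((None, 224, 224, 3), mem_shape))  # True
```

Passing the shape check is only a precondition for reroute_ts; whether a batch-agnostic input works with the rest of the pruned evaluation graph is not established here.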